Skip to content

Commit 0001522

Browse files
committed
Re-order decay/no-decay groups to match old order and pass existing test. Change end lr decay test to min scale.
1 parent bfd490e commit 0001522

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

tests/test_optim.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ def test_lookahead_radam(optimizer):
552552
)
553553

554554

555-
def test_param_groups_layer_decay_with_end_decay():
555+
def test_param_groups_layer_decay_with_min():
556556
model = torch.nn.Sequential(
557557
torch.nn.Linear(10, 5),
558558
torch.nn.ReLU(),
@@ -563,12 +563,12 @@ def test_param_groups_layer_decay_with_end_decay():
563563
model,
564564
weight_decay=0.05,
565565
layer_decay=0.75,
566-
end_layer_decay=0.5,
566+
min_scale=0.5,
567567
verbose=True
568568
)
569569

570570
assert len(param_groups) > 0
571-
# Verify layer scaling is applied with end decay
571+
# Verify layer scaling is applied with a min scale
572572
for group in param_groups:
573573
assert 'lr_scale' in group
574574
assert group['lr_scale'] <= 1.0

timm/optim/_param_groups.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,14 @@ def param_groups_weight_decay(
4949
decay.append(param)
5050

5151
groups = []
52-
if decay:
53-
groups.append({'params': decay, 'weight_decay': weight_decay})
54-
if decay_simple:
55-
groups.append({'params': decay_simple, 'weight_decay': weight_decay, 'simple': True})
5652
if no_decay:
5753
groups.append({'params': no_decay, 'weight_decay': 0.})
54+
if decay:
55+
groups.append({'params': decay, 'weight_decay': weight_decay})
5856
if no_decay_simple:
5957
groups.append({'params': no_decay_simple, 'weight_decay': 0., 'simple': True})
58+
if decay_simple:
59+
groups.append({'params': decay_simple, 'weight_decay': weight_decay, 'simple': True})
6060

6161
return groups
6262

0 commit comments

Comments
 (0)