File tree Expand file tree Collapse file tree 2 files changed +7
-7
lines changed Expand file tree Collapse file tree 2 files changed +7
-7
lines changed Original file line number Diff line number Diff line change @@ -552,7 +552,7 @@ def test_lookahead_radam(optimizer):
552552 )
553553
554554
555- def test_param_groups_layer_decay_with_end_decay ():
555+ def test_param_groups_layer_decay_with_min ():
556556 model = torch .nn .Sequential (
557557 torch .nn .Linear (10 , 5 ),
558558 torch .nn .ReLU (),
@@ -563,12 +563,12 @@ def test_param_groups_layer_decay_with_end_decay():
563563 model ,
564564 weight_decay = 0.05 ,
565565 layer_decay = 0.75 ,
566- end_layer_decay = 0.5 ,
566+ min_scale = 0.5 ,
567567 verbose = True
568568 )
569569
570570 assert len (param_groups ) > 0
571- # Verify layer scaling is applied with end decay
571+ # Verify layer scaling is applied with a min scale
572572 for group in param_groups :
573573 assert 'lr_scale' in group
574574 assert group ['lr_scale' ] <= 1.0
Original file line number Diff line number Diff line change @@ -49,14 +49,14 @@ def param_groups_weight_decay(
4949 decay .append (param )
5050
5151 groups = []
52- if decay :
53- groups .append ({'params' : decay , 'weight_decay' : weight_decay })
54- if decay_simple :
55- groups .append ({'params' : decay_simple , 'weight_decay' : weight_decay , 'simple' : True })
5652 if no_decay :
5753 groups .append ({'params' : no_decay , 'weight_decay' : 0. })
54+ if decay :
55+ groups .append ({'params' : decay , 'weight_decay' : weight_decay })
5856 if no_decay_simple :
5957 groups .append ({'params' : no_decay_simple , 'weight_decay' : 0. , 'simple' : True })
58+ if decay_simple :
59+ groups .append ({'params' : decay_simple , 'weight_decay' : weight_decay , 'simple' : True })
6060
6161 return groups
6262
You can’t perform that action at this time.
0 commit comments