import math

import torch
import torch.optim as optim


class SharedAdam(optim.Adam):
    """Adam optimizer whose per-parameter state lives in shared memory."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0):
        super(SharedAdam, self).__init__(params, lr, betas, eps, weight_decay)

        # Eagerly create the state for every parameter so it can be moved to
        # shared memory before worker processes start. 'step' is stored as a
        # tensor rather than a Python int for the same reason.
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = torch.zeros(1)
                state['exp_avg'] = torch.zeros_like(p.data)
                state['exp_avg_sq'] = torch.zeros_like(p.data)

    def share_memory(self):
        # Move every state tensor into shared memory so that all worker
        # processes read and update the same Adam statistics.
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'].share_memory_()
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # 'step' is a shared tensor, so the in-place increment is
                # visible to every process.
                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step'].item()
                bias_correction2 = 1 - beta2 ** state['step'].item()
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
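

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal example of how a
# SharedAdam instance can be shared across worker processes. The tiny
# nn.Linear model, the `_example_worker` function, and the synthetic loss
# below are illustrative assumptions, not code from this repository.
# ---------------------------------------------------------------------------
def _example_worker(shared_model, optimizer):
    import torch.nn.functional as F  # local import keeps the sketch self-contained

    # Each worker computes gradients locally; optimizer.step() then updates
    # the shared parameters and the shared Adam statistics in place.
    x = torch.randn(8, 4)
    target = torch.zeros(8, 2)
    loss = F.mse_loss(shared_model(x), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


if __name__ == '__main__':
    import torch.multiprocessing as mp
    import torch.nn as nn

    shared_model = nn.Linear(4, 2)
    shared_model.share_memory()           # parameters in shared memory

    optimizer = SharedAdam(shared_model.parameters(), lr=1e-3)
    optimizer.share_memory()              # Adam state in shared memory

    processes = [mp.Process(target=_example_worker, args=(shared_model, optimizer))
                 for _ in range(2)]
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()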