Commit 5d9b07d

Fix a problem with shared adam
1 parent e19ac39 commit 5d9b07d

2 files changed: +42 -4 lines changed

2 files changed

+42
-4
lines changed

my_optim.py

Lines changed: 42 additions & 1 deletion
@@ -1,4 +1,5 @@
 import math
+import torch
 import torch.optim as optim
 
 class SharedAdam(optim.Adam):
@@ -12,13 +13,53 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
         for group in self.param_groups:
             for p in group['params']:
                 state = self.state[p]
-                state['step'] = 0
+                state['step'] = torch.zeros(1)
                 state['exp_avg'] = p.data.new().resize_as_(p.data).zero_()
                 state['exp_avg_sq'] = p.data.new().resize_as_(p.data).zero_()
 
     def share_memory(self):
         for group in self.param_groups:
             for p in group['params']:
                 state = self.state[p]
+                state['step'].share_memory_()
                 state['exp_avg'].share_memory_()
                 state['exp_avg_sq'].share_memory_()
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                state = self.state[p]
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                state['step'] += 1
+
+                if group['weight_decay'] != 0:
+                    grad = grad.add(group['weight_decay'], p.data)
+
+                # Decay the first and second moment running average coefficient
+                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+
+                denom = exp_avg_sq.sqrt().add_(group['eps'])
+
+                bias_correction1 = 1 - beta1 ** state['step'][0]
+                bias_correction2 = 1 - beta2 ** state['step'][0]
+                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+
+                p.data.addcdiv_(-step_size, exp_avg, denom)
+
+        return loss
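
The step counter is changed to a torch.zeros(1) tensor so that share_memory() can place every piece of optimizer state, not only the moment estimates, in shared memory; the custom step() then reads that shared counter when computing the bias correction, so workers in separate processes advance the same Adam statistics. Below is a minimal usage sketch, not part of this commit: the Net module and worker function are hypothetical stand-ins, it assumes the PyTorch API this repository targets, and it optimizes the shared model directly (the real train.py keeps a local model and copies gradients to the shared one).

import torch
import torch.nn as nn
import torch.multiprocessing as mp

from my_optim import SharedAdam


class Net(nn.Module):
    # Hypothetical stand-in for the repository's actual model.
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


def worker(rank, shared_model, optimizer):
    # Each process computes a loss against the shared parameters and calls
    # optimizer.step(); because state['step'], state['exp_avg'] and
    # state['exp_avg_sq'] are shared tensors, every worker advances the
    # same Adam statistics.
    out = shared_model(torch.randn(1, 4))
    loss = out.sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


if __name__ == '__main__':
    shared_model = Net()
    shared_model.share_memory()        # model parameters in shared memory

    optimizer = SharedAdam(shared_model.parameters(), lr=1e-4)
    optimizer.share_memory()           # now also shares the 'step' tensor

    processes = []
    for rank in range(4):
        p = mp.Process(target=worker, args=(rank, shared_model, optimizer))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()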

train.py

Lines changed: 0 additions & 3 deletions
@@ -31,9 +31,6 @@ def train(rank, args, shared_model, optimizer=None):
 
     model.train()
 
-    values = []
-    log_probs = []
-
     state = env.reset()
     state = torch.from_numpy(state)
     done = True
