Commit 85c7efd

Add an optimizer with shared statistics

1 parent b0c1560 commit 85c7efd

4 files changed, 40 insertions(+), 4 deletions(-)

README.md

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 This is a PyTorch implementation of Asynchronous Advantage Actor Critic (A3C) from ["Asynchronous Methods for Deep Reinforcement Learning"](https://arxiv.org/pdf/1602.01783v1.pdf).
 
 This implementation is inspired by [Universe Starter Agent](https://github.com/openai/universe-starter-agent).
-As in the starter agent, I don't share parameters of the optimizers between threads. If you want to have the same optimizer as in the original paper by DeepMind, you might want to check [this implementation.](https://github.com/rarilurelo/pytorch_a3c)
+In contrast to the starter agent, it uses an optimizer with shared statistics as in the original paper.
 
 ## Contibutions
 
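For context on what "shared statistics" means here: a tensor whose storage has been moved into shared memory is visible to every worker process, so an in-place update made by one worker is immediately seen by all the others. A minimal sketch of that mechanism (the toy tensor and worker function below are illustrative, not part of the repo):

import torch
import torch.multiprocessing as mp

def worker(t):
    t.add_(1.0)  # in-place update on the shared storage

if __name__ == '__main__':
    t = torch.zeros(3)
    t.share_memory_()  # move the tensor's storage into shared memory
    p = mp.Process(target=worker, args=(t,))
    p.start()
    p.join()
    print(t)  # the child's update is visible: tensor of ones

SharedAdam (added below in my_optim.py) applies exactly this to Adam's first and second moment estimates, so every worker contributes to, and benefits from, the same running statistics.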

main.py

Lines changed: 12 additions & 1 deletion

@@ -5,13 +5,16 @@
 import sys
 
 import torch
+import torch.optim as optim
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.nn.functional as F
 from envs import create_atari_env
 from model import ActorCritic
 from train import train
 from test import test
+import my_optim
+
 # Based on
 # https://github.com/pytorch/examples/tree/master/mnist_hogwild
 # Training settings
@@ -32,6 +35,8 @@
                     help='maximum length of an episode (default: 10000)')
 parser.add_argument('--env-name', default='PongDeterministic-v3', metavar='ENV',
                     help='environment to train on (default: PongDeterministic-v3)')
+parser.add_argument('--no-shared', default=False, metavar='O',
+                    help='use an optimizer without shared momentum.')
 
 
 if __name__ == '__main__':
@@ -44,14 +49,20 @@
         env.observation_space.shape[0], env.action_space)
     shared_model.share_memory()
 
+    if args.no_shared:
+        optimizer = None
+    else:
+        optimizer = my_optim.SharedAdam(shared_model.parameters(), lr=args.lr)
+        optimizer.share_memory()
+
     processes = []
 
     p = mp.Process(target=test, args=(args.num_processes, args, shared_model))
     p.start()
     processes.append(p)
 
     for rank in range(0, args.num_processes):
-        p = mp.Process(target=train, args=(rank, args, shared_model))
+        p = mp.Process(target=train, args=(rank, args, shared_model, optimizer))
         p.start()
         processes.append(p)
     for p in processes:
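One usage note the diff does not spell out: --no-shared is declared without action='store_true', so argparse treats it as an option that takes a value and stores that value as a string. Any non-empty string is truthy in Python, which means even --no-shared False turns the shared optimizer off. A small demonstration of the quirk:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-shared', default=False, metavar='O',
                    help='use an optimizer without shared momentum.')

# Default: the Python value False, so SharedAdam is used.
print(bool(parser.parse_args([]).no_shared))                        # False
# Any supplied value arrives as a string, and non-empty strings are truthy:
print(bool(parser.parse_args(['--no-shared', 'True']).no_shared))   # True
print(bool(parser.parse_args(['--no-shared', 'False']).no_shared))  # True as well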

my_optim.py

Lines changed: 24 additions & 0 deletions

@@ -0,0 +1,24 @@
+import math
+import torch.optim as optim
+
+class SharedAdam(optim.Adam):
+    """Implements Adam algorithm with shared states.
+    """
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+                 weight_decay=0):
+        super(SharedAdam, self).__init__(params, lr, betas, eps, weight_decay)
+
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+                state['step'] = 0
+                state['exp_avg'] = p.data.new().resize_as_(p.data).zero_()
+                state['exp_avg_sq'] = p.data.new().resize_as_(p.data).zero_()
+
+    def share_memory(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+                state['exp_avg'].share_memory_()
+                state['exp_avg_sq'].share_memory_()
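Two design points are worth noting. Stock Adam creates its per-parameter state lazily on the first step(); SharedAdam materializes exp_avg and exp_avg_sq eagerly in __init__ so that share_memory() has tensors to move into shared memory before the worker processes are forked. Also, 'step' stays a plain Python int, so each process ends up with its own copy of the step counter; only the two moment tensors are actually shared. A minimal usage sketch (the linear model is a stand-in, not from the repo):

import torch.nn as nn

import my_optim

model = nn.Linear(4, 2)   # stand-in for the ActorCritic model
model.share_memory()      # parameters into shared memory

optimizer = my_optim.SharedAdam(model.parameters(), lr=1e-4)
optimizer.share_memory()  # exp_avg / exp_avg_sq into shared memory

# From here the optimizer can be passed to mp.Process workers; every
# worker's optimizer.step() reads and writes the same moment estimates.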

train.py

Lines changed: 3 additions & 2 deletions

@@ -18,15 +18,16 @@ def ensure_shared_grads(model, shared_model):
         shared_param._grad = param.grad
 
 
-def train(rank, args, shared_model):
+def train(rank, args, shared_model, optimizer=None):
     torch.manual_seed(args.seed + rank)
 
     env = create_atari_env(args.env_name)
     env.seed(args.seed + rank)
 
     model = ActorCritic(env.observation_space.shape[0], env.action_space)
 
-    optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)
+    if optimizer is None:
+        optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)
 
     model.train()
 