
Commit 6eecf56

tiny refactor
1 parent ead61fe commit 6eecf56


3 files changed (+25 -27 lines)


main.py

Lines changed: 11 additions & 13 deletions
@@ -15,7 +15,6 @@
 from test import test
 from utils import logger
 import my_optim
-import objgraph
 
 logger = logger.getLogger('main')
 
@@ -62,20 +61,19 @@
 optimizer = my_optim.SharedAdam(shared_model.parameters(), lr=args.lr)
 optimizer.share_memory()
 
-processes = []
-
-p = mp.Process(target=test, args=(args.num_processes, args, shared_model))
-p.start()
-processes.append(p)
 
-if args.debug:
-    # run only one process in a main, easier to debug
-    train(0, args, shared_model, optimizer)
-else:
+if not args.debug:
+    processes = []
+
+    p = mp.Process(target=test, args=(args.num_processes, args, shared_model))
+    p.start()
+    processes.append(p)
     for rank in range(0, args.num_processes):
         p = mp.Process(target=train, args=(rank, args, shared_model, optimizer))
         p.start()
         processes.append(p)
-
-for p in processes:
-    p.join()
+    for p in processes:
+        p.join()
+else:  ## debug is enabled
+    # run only one process in a main, easier to debug
+    train(0, args, shared_model, optimizer)
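The net effect of this hunk: previously the test process was spawned unconditionally (even in debug mode), and the join loop ran over a list the debug path never filled; now spawning and joining happen only when args.debug is off, and the debug path simply calls train(0, ...) in the main process. A minimal self-contained sketch of the resulting pattern; worker, debug, and num_processes are placeholders for the real train/test targets and args fields:

import torch.multiprocessing as mp


def worker(rank):
    # placeholder for the real train()/test() targets
    print('worker {} running'.format(rank))


if __name__ == '__main__':
    debug = False       # stands in for args.debug
    num_processes = 4   # stands in for args.num_processes
    if not debug:
        processes = []
        for rank in range(num_processes):
            p = mp.Process(target=worker, args=(rank,))
            p.start()
            processes.append(p)
        # join in the same branch that spawned, so the debug path
        # never touches a process list it did not build
        for p in processes:
            p.join()
    else:  # debug: single process in main, easier to step through
        worker(0)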

model.py

Lines changed: 3 additions & 1 deletion
@@ -44,8 +44,11 @@ def __init__(self, num_inputs, action_space):
         self.lstm = nn.LSTMCell(32 * 3 * 3, 256)
 
         num_outputs = action_space.n
+
         self.critic_linear = nn.Linear(256, 1)
         self.actor_linear = nn.Linear(256, num_outputs)
+        #self.critic_linear = nn.Linear(288, 1)
+        #self.actor_linear = nn.Linear(288, num_outputs)
 
         self.apply(weights_init)
         self.actor_linear.weight.data = normalized_columns_initializer(
@@ -66,7 +69,6 @@ def forward(self, inputs):
         x = F.elu(self.conv2(x))
         x = F.elu(self.conv3(x))
         x = F.elu(self.conv4(x))
-
         x = x.view(-1, 32 * 3 * 3)
         hx, cx = self.lstm(x, (hx, cx))
         x = hx
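Incidentally, 32 * 3 * 3 = 288 is the flattened conv-output width that forward() feeds to the LSTMCell, which reduces it to 256, so the live heads take 256 inputs. The commented-out Linear(288, ...) lines look like a parked experiment whose heads would read the 288 conv features directly, skipping the LSTM. A quick shape check under those assumptions (batch size 4 and action count 6 are made up):

import torch
import torch.nn as nn

x = torch.randn(4, 32 * 3 * 3)        # flattened conv features: (4, 288)
lstm = nn.LSTMCell(32 * 3 * 3, 256)   # same sizes as in the diff
hx, cx = lstm(x, (torch.zeros(4, 256), torch.zeros(4, 256)))
critic = nn.Linear(256, 1)            # the active head sizes
actor = nn.Linear(256, 6)             # 6 stands in for action_space.n
print(critic(hx).shape, actor(hx).shape)  # torch.Size([4, 1]) torch.Size([4, 6])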

train.py

Lines changed: 11 additions & 13 deletions
@@ -2,6 +2,7 @@
 import os
 import sys
 import resource
+import gc
 
 import torch
 import torch.nn.functional as F
@@ -11,20 +12,15 @@
 from torch.autograd import Variable
 from torchvision import datasets, transforms
 from utils import logger
-#import objgraph
-#from memory_profiler import profile
-
 
 logger = logger.getLogger('main')
 
-
 def ensure_shared_grads(model, shared_model):
     for param, shared_param in zip(model.parameters(), shared_model.parameters()):
         if shared_param.grad is not None:
             return
         shared_param._grad = param.grad
 
-#@profile
 def train(rank, args, shared_model, optimizer=None):
     torch.manual_seed(args.seed + rank)
 
@@ -45,18 +41,24 @@ def train(rank, args, shared_model, optimizer=None):
     episode_length = 0
 
     iteration = 0
+
     while True:
+
+        values = []
+        log_probs = []
+        rewards = []
+        entropies = []
+
         if iteration == args.max_iters:
             logger.info('Max iteration {} reached..'.format(args.max_iters))
             break
 
-        if iteration % 100 == 0 and rank == 0:
-            # for debugging purposes
+        if iteration % 200 == 0 and rank == 0:
             mem_used = int(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
             mem_used_mb = mem_used / 1024
-
             logger.info('Memory usage of one proc: {} (mb)'.format(mem_used_mb))
-
+
+
         iteration += 1
         episode_length += 1
 
@@ -69,10 +71,6 @@
         cx = Variable(cx.data)
         hx = Variable(hx.data)
 
-        values = []
-        log_probs = []
-        rewards = []
-        entropies = []
 
        for step in range(args.num_steps):
             value, logit, (hx, cx) = model(
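The substantive moves in train.py: the rollout buffers (values, log_probs, rewards, entropies) shift from mid-loop to the top of the while body (they were already rebuilt every iteration, so behaviour should be unchanged), the memory report drops from every 100 to every 200 iterations, and gc is newly imported with no call site visible in these hunks. A minimal sketch of the resulting loop shape, with placeholder step logic and an assumed explicit collection:

import gc


def rollout_step(step):
    # placeholder for one environment/model step
    return float(step), -0.1, 1.0, 0.01


for iteration in range(3):
    # Fresh buffers at the top of each iteration: the previous lists lose
    # their last reference here, so whatever they held (e.g. autograd graph
    # nodes) becomes collectable instead of accumulating.
    values, log_probs, rewards, entropies = [], [], [], []
    for step in range(4):
        v, lp, r, e = rollout_step(step)
        values.append(v)
        log_probs.append(lp)
        rewards.append(r)
        entropies.append(e)
    gc.collect()  # assumed use of the newly imported gc; not shown in the diff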
