import os
import sys
import resource
+import gc

import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets, transforms
from utils import logger
-#import objgraph
-#from memory_profiler import profile
-

logger = logger.getLogger('main')

-
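
Note: ensure_shared_grads below hooks each worker's gradient tensors into the
shared model. After the first hookup, shared_param.grad aliases the local
param.grad, so later backward passes write to the shared gradients directly
and the early return avoids redundant re-linking.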
def ensure_shared_grads(model, shared_model):
    for param, shared_param in zip(model.parameters(), shared_model.parameters()):
        if shared_param.grad is not None:
            return
        shared_param._grad = param.grad

-#@profile
def train(rank, args, shared_model, optimizer=None):
    torch.manual_seed(args.seed + rank)

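The helper is consumed at the end of each rollout. A minimal sketch of the
usual A3C update, assuming optimizer was built over shared_model.parameters()
and total_loss is the summed policy and value loss (both names are
assumptions; that part of train() falls outside the hunks shown):

    optimizer.zero_grad()
    total_loss.backward()                     # grads land in the local model
    ensure_shared_grads(model, shared_model)  # alias them into shared_model
    optimizer.step()                          # update the shared weights
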
@@ -45,18 +41,24 @@ def train(rank, args, shared_model, optimizer=None):
    episode_length = 0

    iteration = 0
+
    while True:
+
+        values = []
+        log_probs = []
+        rewards = []
+        entropies = []
+
        if iteration == args.max_iters:
            logger.info('Max iteration {} reached..'.format(args.max_iters))
            break

-        if iteration % 100 == 0 and rank == 0:
-            # for debugging purposes
+        if iteration % 200 == 0 and rank == 0:
            mem_used = int(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
            mem_used_mb = mem_used / 1024
-
            logger.info('Memory usage of one proc: {} (mb)'.format(mem_used_mb))
-
+
+
        iteration += 1
        episode_length += 1

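Note: resource.getrusage reports ru_maxrss in kilobytes on Linux, so the
division by 1024 logs megabytes there; on macOS the field is in bytes, and
the logged figure would be off by a factor of 1024.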
@@ -69,10 +71,6 @@ def train(rank, args, shared_model, optimizer=None):
        cx = Variable(cx.data)
        hx = Variable(hx.data)

-        values = []
-        log_probs = []
-        rewards = []
-        entropies = []

        for step in range(args.num_steps):
            value, logit, (hx, cx) = model(
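
Taken together, the hunks move the re-initialization of the rollout buffers
(values, log_probs, rewards, entropies) from after the LSTM state reset to
the top of each while iteration. Rebinding the names to fresh lists every
iteration drops the references to the previous rollout's Variables, and with
them the retained autograd graph, which is the usual fix for steady
per-process memory growth in A3C workers (the new import gc points the same
way). For context, a sketch of how these buffers are typically consumed once
the num_steps loop has filled them; args.gamma and the 0.01 entropy
coefficient are assumptions, since the loss code falls outside the hunks
shown, and GAE is omitted for brevity:

    # after the rollout: one value/log_prob/reward/entropy appended per step
    R = Variable(torch.zeros(1, 1))  # bootstrap; real code would use V(s_T) if not done
    policy_loss, value_loss = 0, 0
    for i in reversed(range(len(rewards))):
        R = args.gamma * R + rewards[i]            # discounted return
        advantage = R - values[i]
        value_loss = value_loss + 0.5 * advantage.pow(2)
        policy_loss = policy_loss - log_probs[i] * advantage.detach() \
                      - 0.01 * entropies[i]        # entropy bonus

    # followed by backward() and ensure_shared_grads, as sketched above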