Skip to content

Commit ead61fe

Browse files
committed
add debug
1 parent ce7e98c commit ead61fe

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

main.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from test import test
1616
from utils import logger
1717
import my_optim
18+
import objgraph
1819

1920
logger = logger.getLogger('main')
2021

@@ -40,13 +41,16 @@
4041
help='environment to train on (default: PongDeterministic-v3)')
4142
parser.add_argument('--no-shared', default=False, metavar='O',
4243
help='use an optimizer without shared momentum.')
44+
parser.add_argument('--max-iters', type=int, default=50000,
45+
help='maximum iterations per process.')
4346

47+
parser.add_argument('--debug', action='store_true', default=False,
48+
help='run in a way that is easier to debug')
4449

4550
if __name__ == '__main__':
4651
args = parser.parse_args()
4752

4853
torch.manual_seed(args.seed)
49-
#import ipdb; ipdb.set_trace()
5054
env = create_atari_env(args.env_name)
5155
shared_model = ActorCritic(
5256
env.observation_space.shape[0], env.action_space)
@@ -63,10 +67,15 @@
6367
p = mp.Process(target=test, args=(args.num_processes, args, shared_model))
6468
p.start()
6569
processes.append(p)
70+
71+
if args.debug:
72+
# run training in the main process only, which is easier to debug
73+
train(0, args, shared_model, optimizer)
74+
else:
75+
for rank in range(0, args.num_processes):
76+
p = mp.Process(target=train, args=(rank, args, shared_model, optimizer))
77+
p.start()
78+
processes.append(p)
6679

67-
for rank in range(0, args.num_processes):
68-
p = mp.Process(target=train, args=(rank, args, shared_model, optimizer))
69-
p.start()
70-
processes.append(p)
7180
for p in processes:
7281
p.join()

train.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from torch.autograd import Variable
1212
from torchvision import datasets, transforms
1313
from utils import logger
14+
#import objgraph
15+
#from memory_profiler import profile
1416

1517

1618
logger = logger.getLogger('main')
@@ -22,7 +24,7 @@ def ensure_shared_grads(model, shared_model):
2224
return
2325
shared_param._grad = param.grad
2426

25-
27+
#@profile
2628
def train(rank, args, shared_model, optimizer=None):
2729
torch.manual_seed(args.seed + rank)
2830

@@ -44,12 +46,18 @@ def train(rank, args, shared_model, optimizer=None):
4446

4547
iteration = 0
4648
while True:
49+
if iteration == args.max_iters:
50+
logger.info('Max iterations ({}) reached.'.format(args.max_iters))
51+
break
4752

48-
iteration += 1 #TODO remove later
4953
if iteration % 100 == 0 and rank == 0:
50-
logger.info('Memory usage of one proc: {} (mb)'.format(int(resource.getrusage(
51-
resource.RUSAGE_SELF).ru_maxrss) / 1024 / 1024))
54+
# for debugging purposes
55+
mem_used = int(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
56+
mem_used_mb = mem_used / 1024
5257

58+
logger.info('Memory usage of one proc: {} (mb)'.format(mem_used_mb))
59+
60+
iteration += 1
5361
episode_length += 1
5462

5563
# Sync with the shared model

0 commit comments

Comments
 (0)