|
15 | 15 | from test import test |
16 | 16 | from utils import logger |
17 | 17 | import my_optim |
| 18 | +import objgraph |
18 | 19 |
|
19 | 20 | logger = logger.getLogger('main') |
20 | 21 |
|
|
40 | 41 | help='environment to train on (default: PongDeterministic-v3)') |
41 | 42 | parser.add_argument('--no-shared', default=False, metavar='O', |
42 | 43 | help='use an optimizer without shared momentum.') |
| 44 | +parser.add_argument('--max-iters', type=int, default=50000, |
| 45 | + help='maximum iterations per process.') |
43 | 46 |
|
| 47 | +parser.add_argument('--debug', action='store_true', default=False, |
| 48 | + help='run in a way its easier to debug') |
44 | 49 |
|
45 | 50 | if __name__ == '__main__': |
46 | 51 | args = parser.parse_args() |
47 | 52 |
|
48 | 53 | torch.manual_seed(args.seed) |
49 | | - #import ipdb; ipdb.set_trace() |
50 | 54 | env = create_atari_env(args.env_name) |
51 | 55 | shared_model = ActorCritic( |
52 | 56 | env.observation_space.shape[0], env.action_space) |
|
63 | 67 | p = mp.Process(target=test, args=(args.num_processes, args, shared_model)) |
64 | 68 | p.start() |
65 | 69 | processes.append(p) |
| 70 | + |
| 71 | + if args.debug: |
| 72 | + # run only one process in a main, easier to debug |
| 73 | + train(0, args, shared_model, optimizer) |
| 74 | + else: |
| 75 | + for rank in range(0, args.num_processes): |
| 76 | + p = mp.Process(target=train, args=(rank, args, shared_model, optimizer)) |
| 77 | + p.start() |
| 78 | + processes.append(p) |
66 | 79 |
|
67 | | - for rank in range(0, args.num_processes): |
68 | | - p = mp.Process(target=train, args=(rank, args, shared_model, optimizer)) |
69 | | - p.start() |
70 | | - processes.append(p) |
71 | 80 | for p in processes: |
72 | 81 | p.join() |
0 commit comments