@@ -4,18 +4,23 @@
 import os
 import sys
 import math
+import time
 
 import torch
 import torch.optim as optim
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.nn.functional as F
+import tensorboard_logger as tb
+
+import my_optim
 from envs import create_atari_env
 from model import ActorCritic
 from train import train
 from test import test
 from utils import logger
-import my_optim
+from utils.shared_memory import SharedCounter
+
 
 logger = logger.getLogger('main')
 
@@ -41,14 +46,25 @@
                     help='environment to train on (default: PongDeterministic-v3)')
 parser.add_argument('--no-shared', default=False, metavar='O',
                     help='use an optimizer without shared momentum.')
-parser.add_argument('--max-iters', type=int, default=math.inf,
-                    help='maximum iterations per process.')
-
+parser.add_argument('--max-episode-count', type=int, default=math.inf,
+                    help='maximum number of episodes to run per process.')
 parser.add_argument('--debug', action='store_true', default=False,
                     help="run in a way that's easier to debug")
+parser.add_argument('--short-description', default='no_descr',
+                    help='Short description of the run params (used in tensorboard)')
+
+def setup_loggings(args):
+    logger.debug('CONFIGURATION: {}'.format(args))
+
+    cur_path = os.path.dirname(os.path.realpath(__file__))
+    args.summ_base_dir = (cur_path + '/runs/{}/{}({})').format(args.env_name,
+        time.strftime('%d.%m-%H.%M'), args.short_description)
+    logger.info('logging run logs to {}'.format(args.summ_base_dir))
+    tb.configure(args.summ_base_dir)
 
 if __name__ == '__main__':
     args = parser.parse_args()
+    setup_loggings(args)
 
     torch.manual_seed(args.seed)
     env = create_atari_env(args.env_name)
@@ -61,20 +77,23 @@
     else:
         optimizer = my_optim.SharedAdam(shared_model.parameters(), lr=args.lr)
         optimizer.share_memory()
-
+
+    gl_step_cnt = SharedCounter()
 
     if not args.debug:
         processes = []
 
-        p = mp.Process(target=test, args=(args.num_processes, args, shared_model))
+        p = mp.Process(target=test, args=(args.num_processes, args,
+                                          shared_model, gl_step_cnt))
         p.start()
         processes.append(p)
         for rank in range(0, args.num_processes):
-            p = mp.Process(target=train, args=(rank, args, shared_model, optimizer))
+            p = mp.Process(target=train, args=(rank, args, shared_model,
+                                               gl_step_cnt, optimizer))
             p.start()
             processes.append(p)
         for p in processes:
             p.join()
     else:  # debug is enabled
         # run only one process in main; easier to debug
         train(0, args, shared_model, gl_step_cnt, optimizer)
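
Note on SharedCounter: the diff threads a new gl_step_cnt (a SharedCounter from utils.shared_memory) into every train/test process so all workers advance one global episode count, which the new --max-episode-count limit and the TensorBoard logging can read. The counter's implementation is not part of this commit; the sketch below is a minimal plausible version built on multiprocessing shared memory, and its method names (increment, get) are assumptions rather than the repo's confirmed API.

import multiprocessing as mp

class SharedCounter(object):
    """Integer counter kept in shared memory, safe across processes."""

    def __init__(self, initval=0):
        # assumed interface, not the repo's confirmed code;
        # 'i' allocates a C signed int in shared memory
        self.val = mp.Value('i', initval)

    def increment(self, n=1):
        # Value's own lock makes the read-modify-write atomic
        with self.val.get_lock():
            self.val.value += n

    def get(self):
        with self.val.get_lock():
            return self.val.value

Under this interface, each worker would call gl_step_cnt.increment() once per finished episode, and the test process could read a consistent global count both for logging and for honoring --max-episode-count.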
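The setup_loggings helper points tensorboard_logger at a per-run directory (runs/<env-name>/<timestamp>(<short-description>)). How train/test actually emit values is outside this diff; a plausible pattern, assuming the SharedCounter interface sketched above, is to log episode statistics against the shared global step (the helper name and tag names here are hypothetical):

import tensorboard_logger as tb

def log_episode_stats(reward, episode_length, gl_step_cnt):
    # hypothetical helper -- tag names and the counter API are assumptions
    step = gl_step_cnt.get()
    tb.log_value('episode_reward', reward, step)
    tb.log_value('episode_length', episode_length, step)

Using the shared counter as the x-axis step keeps curves from the asynchronous workers on a single, comparable timeline in TensorBoard.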