Skip to content

Commit ce7e98c

Browse files
committed
add memory log
1 parent 5d9b07d commit ce7e98c

File tree

4 files changed

+70
-2
lines changed

4 files changed

+70
-2
lines changed

main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
from model import ActorCritic
1414
from train import train
1515
from test import test
16+
from utils import logger
1617
import my_optim
1718

19+
logger = logger.getLogger('main')
20+
1821
# Based on
1922
# https://github.com/pytorch/examples/tree/master/mnist_hogwild
2023
# Training settings
@@ -43,7 +46,7 @@
4346
args = parser.parse_args()
4447

4548
torch.manual_seed(args.seed)
46-
49+
#import ipdb; ipdb.set_trace()
4750
env = create_atari_env(args.env_name)
4851
shared_model = ActorCritic(
4952
env.observation_space.shape[0], env.action_space)

test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
from torchvision import datasets, transforms
1212
import time
1313
from collections import deque
14+
from utils import logger
1415

16+
logger = logger.getLogger('test')
1517

1618
def test(rank, args, shared_model):
1719
torch.manual_seed(args.seed + rank)
@@ -59,7 +61,7 @@ def test(rank, args, shared_model):
5961
done = True
6062

6163
if done:
62-
print("Time {}, episode reward {}, episode length {}".format(
64+
logger.info("Time {}, episode reward {}, episode length {}".format(
6365
time.strftime("%Hh %Mm %Ss",
6466
time.gmtime(time.time() - start_time)),
6567
reward_sum, episode_length))

train.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import math
22
import os
33
import sys
4+
import resource
45

56
import torch
67
import torch.nn.functional as F
@@ -9,6 +10,10 @@
910
from model import ActorCritic
1011
from torch.autograd import Variable
1112
from torchvision import datasets, transforms
13+
from utils import logger
14+
15+
16+
logger = logger.getLogger('main')
1217

1318

1419
def ensure_shared_grads(model, shared_model):
@@ -36,8 +41,17 @@ def train(rank, args, shared_model, optimizer=None):
3641
done = True
3742

3843
episode_length = 0
44+
45+
iteration = 0
3946
while True:
47+
48+
iteration += 1 #TODO remove later
49+
if iteration % 100 == 0 and rank == 0:
50+
logger.info('Memory usage of one proc: {} (mb)'.format(int(resource.getrusage(
51+
resource.RUSAGE_SELF).ru_maxrss) / 1024 / 1024))
52+
4053
episode_length += 1
54+
4155
# Sync with the shared model
4256
model.load_state_dict(shared_model.state_dict())
4357
if done:

utils/logger.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# -*- coding: utf-8 -*-
2+
import os
3+
import logging
4+
import logging.config
5+
6+
7+
LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO')
8+
LOGGING = {
9+
'version': 1,
10+
'disable_existing_loggers': True,
11+
'formatters': {
12+
'verbose': {
13+
'format': "[%(asctime)s] %(levelname)s " \
14+
"[%(threadName)s:%(lineno)s] %(message)s",
15+
'datefmt': "%Y-%m-%d %H:%M:%S"
16+
},
17+
'simple': {
18+
'format': '%(levelname)s %(message)s'
19+
},
20+
},
21+
'handlers': {
22+
'console': {
23+
'level': LOG_LEVEL,
24+
'class': 'logging.StreamHandler',
25+
'formatter': 'verbose'
26+
},
27+
'file': {
28+
'level': LOG_LEVEL,
29+
'class': 'logging.handlers.RotatingFileHandler',
30+
'formatter': 'verbose',
31+
'filename': 'rl.log',
32+
'maxBytes': 10*10**6,
33+
'backupCount': 3
34+
}
35+
},
36+
'loggers': {
37+
'': {
38+
'handlers': ['console', 'file'],
39+
'level': LOG_LEVEL,
40+
},
41+
}
42+
}
43+
44+
45+
logging.config.dictConfig(LOGGING)
46+
47+
def getLogger(name):
48+
49+
return logging.getLogger(name)

0 commit comments

Comments
 (0)