Skip to content

Commit dd6fe53

Browse files
committed
add resume option
1 parent 6c074c1 commit dd6fe53

File tree

2 files changed

+26
-8
lines changed

2 files changed

+26
-8
lines changed

examples/contrib/detection/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# FasterRCNN Example with Ignite
1+
# Image Detection Example with Ignite
22

33
In this example, we show how to use _Ignite_ to train an image detection model with PyTorch's built-in Faster RCNN implementation.
44

examples/contrib/detection/main.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def run(
117117
lr: int = 0.01,
118118
download: bool = False,
119119
image_size: int = 256,
120+
resume_from: dict = None,
120121
) -> None:
121122
bbox_params = A.BboxParams(format="pascal_voc")
122123
train_transform = A.Compose(
@@ -125,21 +126,22 @@ def run(
125126
)
126127
val_transform = A.Compose([A.Resize(image_size, image_size), ToTensorV2()], bbox_params=bbox_params)
127128

129+
download = local_rank == 0 and download
128130
train_dataset = Dataset(root=dataset_root, download=download, image_set="train", transforms=train_transform)
129131
val_dataset = Dataset(root=dataset_root, download=download, image_set="val", transforms=val_transform)
130132
vis_dataset = Subset(val_dataset, random.sample(range(len(val_dataset)), k=16))
131133

134+
# for testing
135+
train_dataset = Subset(train_dataset, range(100))
136+
val_dataset = Subset(train_dataset, range(100))
137+
132138
train_dataloader = idist.auto_dataloader(
133139
train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=4
134140
)
135141
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, num_workers=4)
136142
vis_dataloader = DataLoader(vis_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, num_workers=4)
137143

138-
model = getattr(detection, model)(pretrained=True)
139-
in_features = model.roi_heads.box_predictor.cls_score.in_features
140-
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 21)
141144
model = idist.auto_model(model)
142-
143145
scaler = GradScaler()
144146
optimizer = SGD(lr=lr, params=model.parameters())
145147
optimizer = idist.auto_optim(optimizer)
@@ -227,7 +229,13 @@ def submit_vis_images(engine):
227229
ProgressBar().attach(trainer, losses)
228230
ProgressBar().attach(evaluator)
229231

230-
objects_to_checkpoint = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": scheduler}
232+
objects_to_checkpoint = {
233+
"trainer": trainer,
234+
"model": model,
235+
"optimizer": optimizer,
236+
"lr_scheduler": scheduler,
237+
"scaler": scaler,
238+
}
231239
checkpoint = Checkpoint(
232240
to_save=objects_to_checkpoint,
233241
save_handler=DiskSaver(log_dir, require_empty=False),
@@ -236,6 +244,8 @@ def submit_vis_images(engine):
236244
global_step_transform=lambda *_: trainer.state.epoch,
237245
)
238246
evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpoint)
247+
if resume_from:
248+
Checkpoint.load_objects(objects_to_checkpoint, torch.load(resume_from))
239249

240250
aim_logger.log_params(
241251
{
@@ -261,7 +271,7 @@ def submit_vis_images(engine):
261271

262272
def main(
263273
experiment_name: str,
264-
gpus: Union[int, List[int], str] = "auto",
274+
gpus: Union[str, List[str], str] = "auto",
265275
nproc_per_node: Union[int, str] = "auto",
266276
dataset_root: str = "./dataset",
267277
log_dir: str = "./log",
@@ -271,6 +281,7 @@ def main(
271281
lr: int = 0.01,
272282
download: bool = False,
273283
image_size: int = 256,
284+
resume_from: str = None,
274285
) -> None:
275286
"""
276287
Args:
@@ -288,14 +299,15 @@ def main(
288299
download: whether to automatically download dataset
289300
device: either cuda or cpu
290301
image_size: image size for training and validation
302+
resume_from: path of checkpoint to resume from
291303
"""
292304
if model not in AVAILABLE_MODELS:
293305
raise RuntimeError(f"Invalid model name: {model}")
294306

295307
if isinstance(gpus, int):
296308
gpus = (gpus,)
297309
if isinstance(gpus, tuple):
298-
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(gpus)
310+
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(gpu) for gpu in gpus])
299311
elif gpus == "auto":
300312
gpus = tuple(range(torch.cuda.device_count()))
301313
elif gpus == "none":
@@ -308,6 +320,11 @@ def main(
308320
if nproc_per_node == "auto":
309321
nproc_per_node = ngpu if ngpu > 0 else max(multiprocessing.cpu_count() // 2, 1)
310322

323+
# to prevent multiple downloads of the pretrained checkpoint, create the model in the main process
324+
model = getattr(detection, model)(pretrained=True)
325+
in_features = model.roi_heads.box_predictor.cls_score.in_features
326+
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 21)
327+
311328
with idist.Parallel(backend=backend, nproc_per_node=nproc_per_node) as parallel:
312329
parallel.run(
313330
run,
@@ -322,6 +339,7 @@ def main(
322339
lr,
323340
download,
324341
image_size,
342+
resume_from,
325343
)
326344

327345

0 commit comments

Comments
 (0)