@@ -117,6 +117,7 @@ def run(
117117 lr : int = 0.01 ,
118118 download : bool = False ,
119119 image_size : int = 256 ,
120+ resume_from : dict = None ,
120121) -> None :
121122 bbox_params = A .BboxParams (format = "pascal_voc" )
122123 train_transform = A .Compose (
@@ -125,21 +126,22 @@ def run(
125126 )
126127 val_transform = A .Compose ([A .Resize (image_size , image_size ), ToTensorV2 ()], bbox_params = bbox_params )
127128
129+ download = local_rank == 0 and download
128130 train_dataset = Dataset (root = dataset_root , download = download , image_set = "train" , transforms = train_transform )
129131 val_dataset = Dataset (root = dataset_root , download = download , image_set = "val" , transforms = val_transform )
130132 vis_dataset = Subset (val_dataset , random .sample (range (len (val_dataset )), k = 16 ))
131133
134+ # for testing
135+ train_dataset = Subset (train_dataset , range (100 ))
136+ val_dataset = Subset (train_dataset , range (100 ))
137+
132138 train_dataloader = idist .auto_dataloader (
133139 train_dataset , batch_size = batch_size , shuffle = True , collate_fn = collate_fn , num_workers = 4
134140 )
135141 val_dataloader = DataLoader (val_dataset , batch_size = batch_size , shuffle = False , collate_fn = collate_fn , num_workers = 4 )
136142 vis_dataloader = DataLoader (vis_dataset , batch_size = batch_size , shuffle = False , collate_fn = collate_fn , num_workers = 4 )
137143
138- model = getattr (detection , model )(pretrained = True )
139- in_features = model .roi_heads .box_predictor .cls_score .in_features
140- model .roi_heads .box_predictor = FastRCNNPredictor (in_features , 21 )
141144 model = idist .auto_model (model )
142-
143145 scaler = GradScaler ()
144146 optimizer = SGD (lr = lr , params = model .parameters ())
145147 optimizer = idist .auto_optim (optimizer )
@@ -227,7 +229,13 @@ def submit_vis_images(engine):
227229 ProgressBar ().attach (trainer , losses )
228230 ProgressBar ().attach (evaluator )
229231
230- objects_to_checkpoint = {"trainer" : trainer , "model" : model , "optimizer" : optimizer , "lr_scheduler" : scheduler }
232+ objects_to_checkpoint = {
233+ "trainer" : trainer ,
234+ "model" : model ,
235+ "optimizer" : optimizer ,
236+ "lr_scheduler" : scheduler ,
237+ "scaler" : scaler ,
238+ }
231239 checkpoint = Checkpoint (
232240 to_save = objects_to_checkpoint ,
233241 save_handler = DiskSaver (log_dir , require_empty = False ),
@@ -236,6 +244,8 @@ def submit_vis_images(engine):
236244 global_step_transform = lambda * _ : trainer .state .epoch ,
237245 )
238246 evaluator .add_event_handler (Events .EPOCH_COMPLETED , checkpoint )
247+ if resume_from :
248+ Checkpoint .load_objects (objects_to_checkpoint , torch .load (resume_from ))
239249
240250 aim_logger .log_params (
241251 {
@@ -261,7 +271,7 @@ def submit_vis_images(engine):
261271
262272def main (
263273 experiment_name : str ,
264- gpus : Union [int , List [int ], str ] = "auto" ,
274+ gpus : Union [str , List [str ], str ] = "auto" ,
265275 nproc_per_node : Union [int , str ] = "auto" ,
266276 dataset_root : str = "./dataset" ,
267277 log_dir : str = "./log" ,
@@ -271,6 +281,7 @@ def main(
271281 lr : int = 0.01 ,
272282 download : bool = False ,
273283 image_size : int = 256 ,
284+ resume_from : str = None ,
274285) -> None :
275286 """
276287 Args:
@@ -288,14 +299,15 @@ def main(
288299 download: whether to automatically download dataset
289300 device: either cuda or cpu
290301 image_size: image size for training and validation
302+ resume_from: path of checkpoint to resume from
291303 """
292304 if model not in AVAILABLE_MODELS :
293305 raise RuntimeError (f"Invalid model name: { model } " )
294306
295307 if isinstance (gpus , int ):
296308 gpus = (gpus ,)
297309 if isinstance (gpus , tuple ):
298- os .environ ["CUDA_VISIBLE_DEVICES" ] = "," .join (gpus )
310+ os .environ ["CUDA_VISIBLE_DEVICES" ] = "," .join ([ str ( gpu ) for gpu in gpus ] )
299311 elif gpus == "auto" :
300312 gpus = tuple (range (torch .cuda .device_count ()))
301313 elif gpus == "none" :
@@ -308,6 +320,11 @@ def main(
308320 if nproc_per_node == "auto" :
309321 nproc_per_node = ngpu if ngpu > 0 else max (multiprocessing .cpu_count () // 2 , 1 )
310322
323+ # to precent multiple download for preatrined checkpoint, create model in the main process
324+ model = getattr (detection , model )(pretrained = True )
325+ in_features = model .roi_heads .box_predictor .cls_score .in_features
326+ model .roi_heads .box_predictor = FastRCNNPredictor (in_features , 21 )
327+
311328 with idist .Parallel (backend = backend , nproc_per_node = nproc_per_node ) as parallel :
312329 parallel .run (
313330 run ,
@@ -322,6 +339,7 @@ def main(
322339 lr ,
323340 download ,
324341 image_size ,
342+ resume_from ,
325343 )
326344
327345
0 commit comments