
Commit fb6ba0a

Make Checkpoint.load_objects accept str and load internally (#2303) (#2305)
* Make Checkpoint.load_objects accept str and load internally (#2303)
* Modify the error message
* Add a test for Checkpoint.load_objects
* Fix test messages to match the function's error message
Parent: 7c596fc · Commit: fb6ba0a
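In short, `Checkpoint.load_objects` can now be handed a checkpoint filepath and will call `torch.load` on it internally, instead of requiring the caller to load the dictionary first. A minimal usage sketch; the model, optimizer, and filepath are illustrative, only the call pattern comes from this commit:

    import torch
    import torch.nn as nn
    from ignite.handlers import Checkpoint

    model = nn.Linear(3, 3)  # illustrative model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    to_load = {"model": model, "optimizer": optimizer}

    # Before this change: load the state dicts yourself, then pass the mapping
    # Checkpoint.load_objects(to_load=to_load, checkpoint=torch.load(checkpoint_fp))

    # After this change: a filepath string is accepted and loaded internally
    Checkpoint.load_objects(to_load=to_load, checkpoint="/tmp/models/myprefix_checkpoint_40.pth")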

File tree: 2 files changed (+36, -13 lines)

ignite/handlers/checkpoint.py (23 additions, 12 deletions)

@@ -507,14 +507,14 @@ def _check_objects(objs: Mapping, attr: str) -> None:
                 raise TypeError(f"Object {type(obj)} should have `{attr}` method")
 
     @staticmethod
-    def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs: Any) -> None:
+    def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping], **kwargs: Any) -> None:
         """Helper method to apply ``load_state_dict`` on the objects from ``to_load`` using states from ``checkpoint``.
 
         Args:
             to_load: a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
-            checkpoint: a dictionary with state_dicts to load, e.g. `{"model": model_state_dict,
-                "optimizer": opt_state_dict}`. If `to_load` contains a single key, then checkpoint can contain directly
-                corresponding state_dict.
+            checkpoint: a string filepath or a dictionary with state_dicts to load, e.g. `{"model": model_state_dict,
+                "optimizer": opt_state_dict}`. If `to_load` contains a single key, then checkpoint can contain
+                directly corresponding state_dict.
             kwargs: Keyword arguments accepted for `nn.Module.load_state_dict()`. Passing `strict=False` enables
                 the user to load part of the pretrained model (useful for example, in Transfer Learning)
 
@@ -537,18 +537,29 @@ def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs: Any) -> None:
             checkpoint = torch.load(checkpoint_fp)
             Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
 
+            # or using a string for checkpoint filepath
+
+            to_load = to_save
+            checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
+            Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)
+
         Note:
             If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
             `DataParallel`_, method ``load_state_dict`` will be applied to their internal wrapped model (``obj.module``).
 
         .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
             torch.nn.parallel.DistributedDataParallel.html
         .. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
-
         """
+
+        if isinstance(checkpoint, str):
+            checkpoint_obj = torch.load(checkpoint)
+        else:
+            checkpoint_obj = checkpoint
+
         Checkpoint._check_objects(to_load, "load_state_dict")
-        if not isinstance(checkpoint, collections.Mapping):
-            raise TypeError(f"Argument checkpoint should be a dictionary, but given {type(checkpoint)}")
+        if not isinstance(checkpoint, (collections.Mapping, str)):
+            raise TypeError(f"Argument checkpoint should be a string or a dictionary, but given {type(checkpoint)}")
 
         if len(kwargs) > 1 or any(k for k in kwargs if k not in ["strict"]):
             warnings.warn("kwargs contains keys other than strict and these will be ignored")
@@ -557,22 +568,22 @@ def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs: Any) -> None:
         if len(to_load) == 1:
             # single object and checkpoint is directly a state_dict
             key, obj = list(to_load.items())[0]
-            if key not in checkpoint:
+            if key not in checkpoint_obj:
                 if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
                     obj = obj.module
-                obj.load_state_dict(checkpoint, strict=is_state_dict_strict)
+                obj.load_state_dict(checkpoint_obj, strict=is_state_dict_strict)
                 return
 
         # multiple objects to load
         for k, obj in to_load.items():
-            if k not in checkpoint:
+            if k not in checkpoint_obj:
                 raise ValueError(f"Object labeled by '{k}' from `to_load` is not found in the checkpoint")
             if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
                 obj = obj.module
             if isinstance(obj, torch.nn.Module):
-                obj.load_state_dict(checkpoint[k], strict=is_state_dict_strict)
+                obj.load_state_dict(checkpoint_obj[k], strict=is_state_dict_strict)
             else:
-                obj.load_state_dict(checkpoint[k])
+                obj.load_state_dict(checkpoint_obj[k])
 
     def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
         """Method returns state dict with saved items: list of ``(priority, filename)`` pairs.

tests/ignite/handlers/test_checkpoint.py (13 additions, 1 deletion)

@@ -1070,7 +1070,7 @@ def test_save_model_optimizer_lr_scheduler_with_validation(dirname):
 
 def test_checkpoint_load_objects():
 
-    with pytest.raises(TypeError, match=r"Argument checkpoint should be a dictionary"):
+    with pytest.raises(TypeError, match=r"Argument checkpoint should be a string or a dictionary"):
         Checkpoint.load_objects({}, [])
 
     with pytest.raises(TypeError, match=r"should have `load_state_dict` method"):
@@ -1107,6 +1107,17 @@ def _get_multiple_objs_to_save():
     trainer = Engine(lambda e, b: None)
     trainer.state = State(epoch=0, iteration=0)
 
+    # case: load from filepath
+    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
+    to_save = _get_multiple_objs_to_save()
+    handler(trainer, to_save)
+    fname = handler.last_checkpoint
+    assert isinstance(fname, str)
+    assert os.path.join(dirname, _PREFIX) in fname
+    assert os.path.exists(fname)
+    Checkpoint.load_objects(to_save, fname)
+    os.remove(fname)
+
     # case: multiple objects
     handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
     to_save = _get_multiple_objs_to_save()
@@ -1142,6 +1153,7 @@ def _get_multiple_objs_to_save():
     assert os.path.exists(fname)
     loaded_objects = torch.load(fname)
     Checkpoint.load_objects(to_save, loaded_objects)
+    os.remove(fname)
 
 
 def test_load_checkpoint_with_different_num_classes(dirname):
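The new test case mirrors the user-facing pattern: save with a `ModelCheckpoint` handler, then resume straight from `handler.last_checkpoint` without an explicit `torch.load`. A self-contained sketch under the same assumptions as the test; the model, directory, and prefix are illustrative, and `require_empty=False` is used so the sketch runs outside a fresh pytest tmpdir:

    import torch.nn as nn
    from ignite.engine import Engine, State
    from ignite.handlers import Checkpoint, ModelCheckpoint

    model = nn.Linear(2, 2)  # illustrative model
    to_save = {"model": model}

    trainer = Engine(lambda engine, batch: None)
    trainer.state = State(epoch=0, iteration=0)

    handler = ModelCheckpoint("/tmp/models", "myprefix", create_dir=True, n_saved=1, require_empty=False)
    handler(trainer, to_save)  # writes a checkpoint file under /tmp/models

    # Resume directly from the saved filepath, as the new test case does:
    Checkpoint.load_objects(to_load=to_save, checkpoint=handler.last_checkpoint)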
