@@ -507,14 +507,14 @@ def _check_objects(objs: Mapping, attr: str) -> None:
                 raise TypeError(f"Object {type(obj)} should have `{attr}` method")

     @staticmethod
-    def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs: Any) -> None:
+    def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping], **kwargs: Any) -> None:
         """Helper method to apply ``load_state_dict`` on the objects from ``to_load`` using states from ``checkpoint``.

         Args:
             to_load: a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
-            checkpoint: a dictionary with state_dicts to load, e.g. `{"model": model_state_dict,
-                "optimizer": opt_state_dict}`. If `to_load` contains a single key, then checkpoint can contain directly
-                corresponding state_dict.
+            checkpoint: a string filepath or a dictionary with state_dicts to load, e.g. `{"model": model_state_dict,
+                "optimizer": opt_state_dict}`. If `to_load` contains a single key, then checkpoint can contain
+                directly corresponding state_dict.
             kwargs: Keyword arguments accepted for `nn.Module.load_state_dict()`. Passing `strict=False` enables
                 the user to load part of the pretrained model (useful for example, in Transfer Learning)

@@ -537,18 +537,29 @@ def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs: Any) -> None:
                 checkpoint = torch.load(checkpoint_fp)
                 Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

+                # or using a string for checkpoint filepath
+
+                to_load = to_save
+                checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
+                Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)
+
         Note:
             If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
             `DataParallel`_, method ``load_state_dict`` will be applied to their internal wrapped model (``obj.module``).

             .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
                 torch.nn.parallel.DistributedDataParallel.html
             .. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
-
         """
+
+        if isinstance(checkpoint, str):
+            checkpoint_obj = torch.load(checkpoint)
+        else:
+            checkpoint_obj = checkpoint
+
         Checkpoint._check_objects(to_load, "load_state_dict")
-        if not isinstance(checkpoint, collections.Mapping):
-            raise TypeError(f"Argument checkpoint should be a dictionary, but given {type(checkpoint)}")
+        if not isinstance(checkpoint, (collections.Mapping, str)):
+            raise TypeError(f"Argument checkpoint should be a string or a dictionary, but given {type(checkpoint)}")

         if len(kwargs) > 1 or any(k for k in kwargs if k not in ["strict"]):
             warnings.warn("kwargs contains keys other than strict and these will be ignored")
@@ -557,22 +568,22 @@ def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs: Any) -> None:
         if len(to_load) == 1:
             # single object and checkpoint is directly a state_dict
             key, obj = list(to_load.items())[0]
-            if key not in checkpoint:
+            if key not in checkpoint_obj:
                 if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
                     obj = obj.module
-                obj.load_state_dict(checkpoint, strict=is_state_dict_strict)
+                obj.load_state_dict(checkpoint_obj, strict=is_state_dict_strict)
                 return

         # multiple objects to load
         for k, obj in to_load.items():
-            if k not in checkpoint:
+            if k not in checkpoint_obj:
                 raise ValueError(f"Object labeled by '{k}' from `to_load` is not found in the checkpoint")
             if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
                 obj = obj.module
             if isinstance(obj, torch.nn.Module):
-                obj.load_state_dict(checkpoint[k], strict=is_state_dict_strict)
+                obj.load_state_dict(checkpoint_obj[k], strict=is_state_dict_strict)
             else:
-                obj.load_state_dict(checkpoint[k])
+                obj.load_state_dict(checkpoint_obj[k])

     def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
         """Method returns state dict with saved items: list of ``(priority, filename)`` pairs.
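A minimal usage sketch of `Checkpoint.load_objects` as it behaves with the patch above, assuming `ignite.handlers.Checkpoint` includes this change; the `model`/`optimizer` objects and the temporary file path are illustrative, not part of the diff:

import os
import tempfile

import torch
import torch.nn as nn

from ignite.handlers import Checkpoint

# illustrative objects to restore; any module/optimizer pair would work
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
to_load = {"model": model, "optimizer": optimizer}

# save a checkpoint dict of state_dicts to a temporary file
checkpoint_fp = os.path.join(tempfile.mkdtemp(), "myprefix_checkpoint_40.pth")
torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, checkpoint_fp)

# 1) checkpoint passed as an already-loaded mapping (behaviour unchanged by this diff)
Checkpoint.load_objects(to_load=to_load, checkpoint=torch.load(checkpoint_fp))

# 2) checkpoint passed directly as a filepath string (the branch added by this diff)
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)

# 3) single object with a bare state_dict; strict=False permits partial loading
Checkpoint.load_objects(to_load={"model": model}, checkpoint=model.state_dict(), strict=False)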