@@ -710,11 +710,16 @@ def save_model(self, dir: str):
710710 )
711711
712712 @classmethod
713- def load_from_checkpoint (cls , dir : str ):
713+ def load_from_checkpoint (cls , dir : str , map_location = None , strict = True ):
714714 """Loads a saved model from the directory
715715
716716 Args:
717717 dir (str): The directory where the model was saved, along with the checkpoints
718+ map_location (Union[Dict[str, str], str, device, int, Callable, None]) – If your checkpoint
719+ saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map
720+ to the new setup. The behaviour is the same as in torch.load()
721+ strict (bool) – Whether to strictly enforce that the keys in checkpoint_path match the keys
722+ returned by this module’s state dict. Default: True.
718723
719724 Returns:
720725 TabularModel: The saved TabularModel
@@ -755,7 +760,7 @@ def load_from_checkpoint(cls, dir: str):
755760
756761 # Initializing with default metrics, losses, and optimizers. Will revert once initialized
757762 model = model_callable .load_from_checkpoint (
758- checkpoint_path = os .path .join (dir , "model.ckpt" ), ** model_args
763+ checkpoint_path = os .path .join (dir , "model.ckpt" ),map_location = map_location , strict = strict , ** model_args
759764 )
760765 # else:
761766 # # Initializing with default values
0 commit comments