Skip to content

Commit e8ee018

Browse files
committed
-- added map_location & strict to load_checkpoint
1 parent 9ac8cc4 commit e8ee018

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

pytorch_tabular/tabular_model.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -710,11 +710,16 @@ def save_model(self, dir: str):
710710
)
711711

712712
@classmethod
713-
def load_from_checkpoint(cls, dir: str):
713+
def load_from_checkpoint(cls, dir: str, map_location = None, strict=True):
714714
"""Loads a saved model from the directory
715715
716716
Args:
717717
dir (str): The directory where the model wa saved, along with the checkpoints
718+
map_location (Union[Dict[str, str], str, device, int, Callable, None]) – If your checkpoint
719+
saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map
720+
to the new setup. The behaviour is the same as in torch.load()
721+
strict (bool) – Whether to strictly enforce that the keys in checkpoint_path match the keys
722+
returned by this module’s state dict. Default: True.
718723
719724
Returns:
720725
TabularModel: The saved TabularModel
@@ -755,7 +760,7 @@ def load_from_checkpoint(cls, dir: str):
755760

756761
# Initializing with default metrics, losses, and optimizers. Will revert once initialized
757762
model = model_callable.load_from_checkpoint(
758-
checkpoint_path=os.path.join(dir, "model.ckpt"), **model_args
763+
checkpoint_path=os.path.join(dir, "model.ckpt"),map_location=map_location, strict=strict, **model_args
759764
)
760765
# else:
761766
# # Initializing with default values

0 commit comments

Comments
 (0)