diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py index 7fd68dbf..7b1e89f8 100644 --- a/src/pytorch_tabular/tabular_datamodule.py +++ b/src/pytorch_tabular/tabular_datamodule.py @@ -758,7 +758,14 @@ def _load_dataset_from_cache(self, tag: str = "train"): ) elif self.cache_mode is self.CACHE_MODES.DISK: try: - dataset = torch.load(self.cache_dir / f"{tag}_dataset", weights_only=False) + # get the torch version + torch_version = torch.__version__ + if torch_version < "2.6": + dataset = torch.load( + self.cache_dir / f"{tag}_dataset" + ) # fix for torch version change of torch.load + elif torch_version >= "2.6": + dataset = torch.load(self.cache_dir / f"{tag}_dataset", weights_only=False) except FileNotFoundError: raise FileNotFoundError( f"{tag}_dataset not found in {self.cache_dir}. Please provide the" f" data for {tag} dataloader" diff --git a/src/pytorch_tabular/utils/python_utils.py b/src/pytorch_tabular/utils/python_utils.py index 57176fdc..e08503ed 100644 --- a/src/pytorch_tabular/utils/python_utils.py +++ b/src/pytorch_tabular/utils/python_utils.py @@ -74,7 +74,12 @@ def pl_load( """ if not isinstance(path_or_url, (str, Path)): # any sort of BytesIO or similar - return torch.load(path_or_url, map_location=map_location, weights_only=False) + # get the torch version + torch_version = torch.__version__ + if torch_version < "2.6": + return torch.load(path_or_url, map_location=map_location) # for torch version < 2.6 + elif torch_version >= "2.6": + return torch.load(path_or_url, map_location=map_location, weights_only=False) if str(path_or_url).startswith("http"): return torch.hub.load_state_dict_from_url( str(path_or_url), @@ -82,7 +87,11 @@ def pl_load( ) fs = get_filesystem(path_or_url) with fs.open(path_or_url, "rb") as f: - return torch.load(f, map_location=map_location, weights_only=False) + torch_version = torch.__version__ + if torch_version < "2.6": + return torch.load(f, map_location=map_location) # for torch version < 2.6 + elif torch_version >= "2.6": + return torch.load(f, map_location=map_location, weights_only=False) def check_numpy(x):