Skip to content

Commit 155f29b

Browse files
torch.load fix for all pytorch versions. (#554)
* torch.load fix for all pytorch versions. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update python_utils.py to fix minor bug --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 98ac171 commit 155f29b

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -758,7 +758,14 @@ def _load_dataset_from_cache(self, tag: str = "train"):
758758
)
759759
elif self.cache_mode is self.CACHE_MODES.DISK:
760760
try:
761-
dataset = torch.load(self.cache_dir / f"{tag}_dataset", weights_only=False)
761+
# get the torch version
762+
torch_version = torch.__version__
763+
if torch_version < "2.6":
764+
dataset = torch.load(
765+
self.cache_dir / f"{tag}_dataset"
766+
) # fix for torch version change of torch.load
767+
elif torch_version >= "2.6":
768+
dataset = torch.load(self.cache_dir / f"{tag}_dataset", weights_only=False)
762769
except FileNotFoundError:
763770
raise FileNotFoundError(
764771
f"{tag}_dataset not found in {self.cache_dir}. Please provide the" f" data for {tag} dataloader"

src/pytorch_tabular/utils/python_utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,24 @@ def pl_load(
7474
"""
7575
if not isinstance(path_or_url, (str, Path)):
7676
# any sort of BytesIO or similar
77-
return torch.load(path_or_url, map_location=map_location, weights_only=False)
77+
# get the torch version
78+
torch_version = torch.__version__
79+
if torch_version < "2.6":
80+
return torch.load(path_or_url, map_location=map_location) # for torch version < 2.6
81+
elif torch_version >= "2.6":
82+
return torch.load(path_or_url, map_location=map_location, weights_only=False)
7883
if str(path_or_url).startswith("http"):
7984
return torch.hub.load_state_dict_from_url(
8085
str(path_or_url),
8186
map_location=map_location, # type: ignore[arg-type] # upstream annotation is not correct
8287
)
8388
fs = get_filesystem(path_or_url)
8489
with fs.open(path_or_url, "rb") as f:
85-
return torch.load(f, map_location=map_location, weights_only=False)
90+
torch_version = torch.__version__
91+
if torch_version < "2.6":
92+
return torch.load(f, map_location=map_location) # for torch version < 2.6
93+
elif torch_version >= "2.6":
94+
return torch.load(f, map_location=map_location, weights_only=False)
8695

8796

8897
def check_numpy(x):

0 commit comments

Comments
 (0)