Skip to content

Commit 98ac171

Browse files
authored
Torch load issue fix with pytorch 2.6 (#543)
* Update python_utils.py for torch.load * Update tabular_datamodule.py for torch.load
1 parent 6cb5373 commit 98ac171

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -758,7 +758,7 @@ def _load_dataset_from_cache(self, tag: str = "train"):
758758
)
759759
elif self.cache_mode is self.CACHE_MODES.DISK:
760760
try:
761-
dataset = torch.load(self.cache_dir / f"{tag}_dataset")
761+
dataset = torch.load(self.cache_dir / f"{tag}_dataset", weights_only=False)
762762
except FileNotFoundError:
763763
raise FileNotFoundError(
764764
f"{tag}_dataset not found in {self.cache_dir}. Please provide the" f" data for {tag} dataloader"

src/pytorch_tabular/utils/python_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,15 @@ def pl_load(
7474
"""
7575
if not isinstance(path_or_url, (str, Path)):
7676
# any sort of BytesIO or similar
77-
return torch.load(path_or_url, map_location=map_location)
77+
return torch.load(path_or_url, map_location=map_location, weights_only=False)
7878
if str(path_or_url).startswith("http"):
7979
return torch.hub.load_state_dict_from_url(
8080
str(path_or_url),
8181
map_location=map_location, # type: ignore[arg-type] # upstream annotation is not correct
8282
)
8383
fs = get_filesystem(path_or_url)
8484
with fs.open(path_or_url, "rb") as f:
85-
return torch.load(f, map_location=map_location)
85+
return torch.load(f, map_location=map_location, weights_only=False)
8686

8787

8888
def check_numpy(x):

0 commit comments

Comments
 (0)