Skip to content

Commit d5e7171

Browse files
authored
Merge branch 'main' into fix-training-on-mps
2 parents 0e1c996 + 155f29b commit d5e7171

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -756,7 +756,14 @@ def _load_dataset_from_cache(self, tag: str = "train"):
756756
raise AttributeError(f"{tag}_dataset not found in memory. Please provide the data for {tag} dataloader")
757757
elif self.cache_mode is self.CACHE_MODES.DISK:
758758
try:
759-
dataset = torch.load(self.cache_dir / f"{tag}_dataset", weights_only=False)
759+
# get the torch version
760+
torch_version = torch.__version__
761+
if torch_version < "2.6":
762+
dataset = torch.load(
763+
self.cache_dir / f"{tag}_dataset"
764+
) # fix for torch version change of torch.load
765+
elif torch_version >= "2.6":
766+
dataset = torch.load(self.cache_dir / f"{tag}_dataset", weights_only=False)
760767
except FileNotFoundError:
761768
raise FileNotFoundError(
762769
f"{tag}_dataset not found in {self.cache_dir}. Please provide the data for {tag} dataloader"

src/pytorch_tabular/utils/python_utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,24 @@ def pl_load(
7474
"""
7575
if not isinstance(path_or_url, (str, Path)):
7676
# any sort of BytesIO or similar
77-
return torch.load(path_or_url, map_location=map_location, weights_only=False)
77+
# get the torch version
78+
torch_version = torch.__version__
79+
if torch_version < "2.6":
80+
return torch.load(path_or_url, map_location=map_location) # for torch version < 2.6
81+
elif torch_version >= "2.6":
82+
return torch.load(path_or_url, map_location=map_location, weights_only=False)
7883
if str(path_or_url).startswith("http"):
7984
return torch.hub.load_state_dict_from_url(
8085
str(path_or_url),
8186
map_location=map_location, # type: ignore[arg-type] # upstream annotation is not correct
8287
)
8388
fs = get_filesystem(path_or_url)
8489
with fs.open(path_or_url, "rb") as f:
85-
return torch.load(f, map_location=map_location, weights_only=False)
90+
torch_version = torch.__version__
91+
if torch_version < "2.6":
92+
return torch.load(f, map_location=map_location) # for torch version < 2.6
93+
elif torch_version >= "2.6":
94+
return torch.load(f, map_location=map_location, weights_only=False)
8695

8796

8897
def check_numpy(x):

0 commit comments

Comments
 (0)