Skip to content

Commit 89aefc7

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent df244cc commit 89aefc7

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -761,9 +761,11 @@ def _load_dataset_from_cache(self, tag: str = "train"):
761761
# get the torch version
762762
torch_version = torch.__version__
763763
if torch_version < "2.6":
764-
dataset = torch.load(self.cache_dir / f"{tag}_dataset") # fix for torch version change of torch.load
764+
dataset = torch.load(
765+
self.cache_dir / f"{tag}_dataset"
766+
) # fix for torch version change of torch.load
765767
elif torch_version >= "2.6":
766-
dataset = torch.load(self.cache_dir / f"{tag}_dataset", weights_only=False)
768+
dataset = torch.load(self.cache_dir / f"{tag}_dataset", weights_only=False)
767769
except FileNotFoundError:
768770
raise FileNotFoundError(
769771
f"{tag}_dataset not found in {self.cache_dir}. Please provide the" f" data for {tag} dataloader"

src/pytorch_tabular/utils/python_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def pl_load(
7777
# get the torch version
7878
torch_version = torch.__version__
7979
if torch_version < "2.6":
80-
return torch.load(path_or_url, map_location=map_location) # for torch version < 2.6
80+
return torch.load(path_or_url, map_location=map_location) # for torch version < 2.6
8181
elif torch_version >= "2.6":
8282
return torch.load(path_or_url, map_location=map_location, weights_only=False)
8383
if str(path_or_url).startswith("http"):
@@ -88,10 +88,9 @@ def pl_load(
8888
fs = get_filesystem(path_or_url)
8989
with fs.open(path_or_url, "rb") as f:
9090
if torch_version < "2.6":
91-
return torch.load(f, map_location=map_location) # for torch version < 2.6
91+
return torch.load(f, map_location=map_location) # for torch version < 2.6
9292
elif torch_version >= "2.6":
9393
return torch.load(f, map_location=map_location, weights_only=False)
94-
9594

9695

9796
def check_numpy(x):

0 commit comments

Comments
 (0)