Skip to content

Commit f4b0221

Browse files
Make tensor dtypes np.float32 for MPS devices (#540)
numpy defaults to numpy.float64 where these tensors should be numpy.float32. This caused training to fail on MPS devices; with this change it works on my M1. Co-authored-by: Manu Joseph V <manujosephv@gmail.com>
1 parent 155f29b commit f4b0221

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

src/pytorch_tabular/ssl_models/common/noise_generators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class SwapNoiseCorrupter(nn.Module):
1818

1919
def __init__(self, probas):
2020
super().__init__()
21-
self.probas = torch.from_numpy(np.array(probas))
21+
self.probas = torch.from_numpy(np.array(probas, dtype=np.float32))
2222

2323
def forward(self, x):
2424
should_swap = torch.bernoulli(self.probas.to(x.device) * torch.ones(x.shape).to(x.device))

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(
6767
if isinstance(target, str):
6868
self.y = self.y.reshape(-1, 1) # .astype(np.int64)
6969
else:
70-
self.y = np.zeros((self.n, 1)) # .astype(np.int64)
70+
self.y = np.zeros((self.n, 1), dtype=np.float32) # .astype(np.int64)
7171

7272
if task == "classification":
7373
self.y = self.y.astype(np.int64)
@@ -502,7 +502,7 @@ def _cache_dataset(self):
502502

503503
def split_train_val(self, train):
504504
logger.debug(
505-
"No validation data provided." f" Using {self.config.validation_split*100}% of train data as validation"
505+
f"No validation data provided. Using {self.config.validation_split * 100}% of train data as validation"
506506
)
507507
val_idx = train.sample(
508508
int(self.config.validation_split * len(train)),
@@ -753,9 +753,7 @@ def _load_dataset_from_cache(self, tag: str = "train"):
753753
try:
754754
dataset = getattr(self, f"_{tag}_dataset")
755755
except AttributeError:
756-
raise AttributeError(
757-
f"{tag}_dataset not found in memory. Please provide the data for" f" {tag} dataloader"
758-
)
756+
raise AttributeError(f"{tag}_dataset not found in memory. Please provide the data for {tag} dataloader")
759757
elif self.cache_mode is self.CACHE_MODES.DISK:
760758
try:
761759
# get the torch version
@@ -768,10 +766,10 @@ def _load_dataset_from_cache(self, tag: str = "train"):
768766
dataset = torch.load(self.cache_dir / f"{tag}_dataset", weights_only=False)
769767
except FileNotFoundError:
770768
raise FileNotFoundError(
771-
f"{tag}_dataset not found in {self.cache_dir}. Please provide the" f" data for {tag} dataloader"
769+
f"{tag}_dataset not found in {self.cache_dir}. Please provide the data for {tag} dataloader"
772770
)
773771
elif self.cache_mode is self.CACHE_MODES.INFERENCE:
774-
raise RuntimeError("Cannot load dataset in inference mode. Use" " `prepare_inference_dataloader` instead")
772+
raise RuntimeError("Cannot load dataset in inference mode. Use `prepare_inference_dataloader` instead")
775773
else:
776774
raise ValueError(f"{self.cache_mode} is not a valid cache mode")
777775
return dataset

0 commit comments

Comments
 (0)