Skip to content

Commit 922f385

Browse files
committed
Make tensor dtypes np.float32 for MPS devices
numpy defaults to numpy.float64 when it should be numpy.float32. This caused training to fail on MPS devices, but with this change it works on my M1.
1 parent 6cb5373 commit 922f385

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

src/pytorch_tabular/ssl_models/common/noise_generators.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ class SwapNoiseCorrupter(nn.Module):
1818

1919
def __init__(self, probas):
2020
super().__init__()
21-
self.probas = torch.from_numpy(np.array(probas))
21+
self.probas = torch.from_numpy(np.array(probas, dtype=np.float32))
22+
print(f"666 {self.probas.dtype} 666")
2223

2324
def forward(self, x):
2425
should_swap = torch.bernoulli(self.probas.to(x.device) * torch.ones(x.shape).to(x.device))

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(
6767
if isinstance(target, str):
6868
self.y = self.y.reshape(-1, 1) # .astype(np.int64)
6969
else:
70-
self.y = np.zeros((self.n, 1)) # .astype(np.int64)
70+
self.y = np.zeros((self.n, 1), dtype=np.float32) # .astype(np.int64)
7171

7272
if task == "classification":
7373
self.y = self.y.astype(np.int64)
@@ -502,7 +502,7 @@ def _cache_dataset(self):
502502

503503
def split_train_val(self, train):
504504
logger.debug(
505-
"No validation data provided." f" Using {self.config.validation_split*100}% of train data as validation"
505+
f"No validation data provided. Using {self.config.validation_split * 100}% of train data as validation"
506506
)
507507
val_idx = train.sample(
508508
int(self.config.validation_split * len(train)),
@@ -753,18 +753,16 @@ def _load_dataset_from_cache(self, tag: str = "train"):
753753
try:
754754
dataset = getattr(self, f"_{tag}_dataset")
755755
except AttributeError:
756-
raise AttributeError(
757-
f"{tag}_dataset not found in memory. Please provide the data for" f" {tag} dataloader"
758-
)
756+
raise AttributeError(f"{tag}_dataset not found in memory. Please provide the data for {tag} dataloader")
759757
elif self.cache_mode is self.CACHE_MODES.DISK:
760758
try:
761759
dataset = torch.load(self.cache_dir / f"{tag}_dataset")
762760
except FileNotFoundError:
763761
raise FileNotFoundError(
764-
f"{tag}_dataset not found in {self.cache_dir}. Please provide the" f" data for {tag} dataloader"
762+
f"{tag}_dataset not found in {self.cache_dir}. Please provide the data for {tag} dataloader"
765763
)
766764
elif self.cache_mode is self.CACHE_MODES.INFERENCE:
767-
raise RuntimeError("Cannot load dataset in inference mode. Use" " `prepare_inference_dataloader` instead")
765+
raise RuntimeError("Cannot load dataset in inference mode. Use `prepare_inference_dataloader` instead")
768766
else:
769767
raise ValueError(f"{self.cache_mode} is not a valid cache mode")
770768
return dataset

0 commit comments

Comments
 (0)