diff --git a/src/pytorch_tabular/categorical_encoders.py b/src/pytorch_tabular/categorical_encoders.py
index 35b771fe..b3d7a1ee 100644
--- a/src/pytorch_tabular/categorical_encoders.py
+++ b/src/pytorch_tabular/categorical_encoders.py
@@ -68,7 +68,7 @@ def transform(self, X):
             X_encoded[col] = X_encoded[col].fillna(NAN_CATEGORY).map(mapping["value"])
 
             if self.handle_unseen == "impute":
-                X_encoded[col].fillna(self._imputed, inplace=True)
+                X_encoded[col] = X_encoded[col].fillna(self._imputed)
             elif self.handle_unseen == "error":
                 if np.unique(X_encoded[col]).shape[0] > mapping.shape[0]:
                     raise ValueError(f"Unseen categories found in `{col}` column.")
diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py
index 2e410170..55aa500b 100644
--- a/src/pytorch_tabular/config/config.py
+++ b/src/pytorch_tabular/config/config.py
@@ -96,6 +96,8 @@ class DataConfig:
         handle_missing_values (bool): Whether to handle missing values in categorical columns as
             unknown
 
+        pickle_protocol (int): pickle protocol version passed to `torch.save` for dataset caching to disk
+
         dataloader_kwargs (Dict[str, Any]): Additional kwargs to be passed to PyTorch DataLoader. See
             https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
 
@@ -179,6 +181,11 @@ class DataConfig:
         metadata={"help": "Whether or not to handle missing values in categorical columns as unknown"},
     )
+    pickle_protocol: int = field(
+        default=2,
+        metadata={"help": "pickle protocol version passed to `torch.save` for dataset caching to disk"},
+    )
+
     dataloader_kwargs: Dict[str, Any] = field(
         default_factory=dict,
         metadata={"help": "Additional kwargs to be passed to PyTorch DataLoader."},
@@ -351,8 +358,8 @@ class TrainerConfig:
         progress_bar (str): Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `rich`.
 
-        precision (int): Precision of the model. Can be one of: `32`, `16`, `64`. Defaults to `32`..
-            Choices are: [`32`,`16`,`64`].
+        precision (str): Precision of the model. Defaults to `32`. See
+            https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision
 
         seed (int): Seed for random number generators. Defaults to 42
 
@@ -536,11 +543,10 @@ class TrainerConfig:
         default="rich",
         metadata={"help": "Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `rich`."},
     )
-    precision: int = field(
-        default=32,
+    precision: str = field(
+        default="32",
         metadata={
-            "help": "Precision of the model. Can be one of: `32`, `16`, `64`. Defaults to `32`.",
-            "choices": [32, 16, 64],
+            "help": "Precision of the model. Defaults to `32`.",
         },
     )
     seed: int = field(
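For context, a minimal sketch of how the two config changes above would be used. The column names below are hypothetical; only `pickle_protocol` and the now string-typed `precision` come from this diff.

```python
import pickle

from pytorch_tabular.config import DataConfig, TrainerConfig

# `pickle_protocol` is forwarded to `torch.save` when datasets are cached to
# disk; a higher protocol gives smaller/faster pickles but is less portable
# across older Python versions.
data_config = DataConfig(
    target=["target"],           # hypothetical target column
    continuous_cols=["feat_a"],  # hypothetical feature column
    pickle_protocol=pickle.HIGHEST_PROTOCOL,
)

# `precision` is now a string, so Lightning's mixed-precision aliases
# (e.g. "16-mixed", "bf16-mixed") pass straight through to the Trainer.
trainer_config = TrainerConfig(precision="16-mixed")
```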
         },
     )
     seed: int = field(
diff --git a/src/pytorch_tabular/feature_extractor.py b/src/pytorch_tabular/feature_extractor.py
index 33f84f09..424a03e4 100644
--- a/src/pytorch_tabular/feature_extractor.py
+++ b/src/pytorch_tabular/feature_extractor.py
@@ -79,15 +79,21 @@ def transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame:
                     if k in ret_value.keys():
                         logits_predictions[k].append(ret_value[k].detach().cpu())
 
+        logits_dfs = []
         for k, v in logits_predictions.items():
             v = torch.cat(v, dim=0).numpy()
             if v.ndim == 1:
                 v = v.reshape(-1, 1)
-            for i in range(v.shape[-1]):
-                if v.shape[-1] > 1:
-                    X_encoded[f"{k}_{i}"] = v[:, i]
-                else:
-                    X_encoded[f"{k}"] = v[:, i]
+            if v.shape[-1] > 1:
+                temp_df = pd.DataFrame({f"{k}_{i}": v[:, i] for i in range(v.shape[-1])})
+            else:
+                temp_df = pd.DataFrame({f"{k}": v[:, 0]})
+
+            # Append the temp DataFrame to the list
+            logits_dfs.append(temp_df)
+
+        preds = pd.concat(logits_dfs, axis=1)
+        X_encoded = pd.concat([X_encoded, preds], axis=1)
 
         if self.drop_original:
             X_encoded.drop(columns=orig_features, inplace=True)
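The feature-extractor refactor above replaces per-column assignment with a single `pd.concat`: inserting columns one at a time fragments the DataFrame and triggers pandas' `PerformanceWarning`, while one concat allocates the result once. A self-contained sketch of the same pattern, with made-up tensor names:

```python
import numpy as np
import pandas as pd

# Fake extracted outputs: a 2-D feature block and a 1-D logit vector.
logits_predictions = {"backbone_features": np.random.rand(4, 3), "logits": np.random.rand(4)}

frames = []
for name, v in logits_predictions.items():
    v = v.reshape(-1, 1) if v.ndim == 1 else v
    if v.shape[-1] > 1:
        # One column per output dimension, suffixed with its index.
        frames.append(pd.DataFrame({f"{name}_{i}": v[:, i] for i in range(v.shape[-1])}))
    else:
        frames.append(pd.DataFrame({name: v[:, 0]}))

# Single concat instead of repeated column insertion.
X_encoded = pd.concat(frames, axis=1)
print(X_encoded.columns.tolist())  # ['backbone_features_0', ..., 'logits']
```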
diff --git a/src/pytorch_tabular/models/base_model.py b/src/pytorch_tabular/models/base_model.py
index 12b5518c..824eb710 100644
--- a/src/pytorch_tabular/models/base_model.py
+++ b/src/pytorch_tabular/models/base_model.py
@@ -244,13 +244,14 @@ def _setup_metrics(self):
         else:
             self.metrics = self.custom_metrics
 
-    def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tensor:
+    def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str, sync_dist: bool = False) -> torch.Tensor:
         """Calculates the loss for the model.
 
         Args:
             output (Dict): The output dictionary from the model
             y (torch.Tensor): The target tensor
             tag (str): The tag to use for logging
+            sync_dist (bool): enable distributed sync of logs
 
         Returns:
             torch.Tensor: The loss value
@@ -270,6 +271,7 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tenso
                 on_step=False,
                 logger=True,
                 prog_bar=False,
+                sync_dist=sync_dist,
             )
             if self.hparams.task == "regression":
                 computed_loss = reg_loss
@@ -284,6 +286,7 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tenso
                 on_step=False,
                 logger=True,
                 prog_bar=False,
+                sync_dist=sync_dist,
             )
         else:
             # TODO loss fails with batch size of 1?
@@ -301,6 +304,7 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tenso
                 on_step=False,
                 logger=True,
                 prog_bar=False,
+                sync_dist=sync_dist,
             )
             start_index = end_index
         self.log(
@@ -311,10 +315,13 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tenso
             # on_step=False,
             logger=True,
             prog_bar=True,
+            sync_dist=sync_dist,
         )
         return computed_loss
 
-    def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> List[torch.Tensor]:
+    def calculate_metrics(
+        self, y: torch.Tensor, y_hat: torch.Tensor, tag: str, sync_dist: bool = False
+    ) -> List[torch.Tensor]:
         """Calculates the metrics for the model.
 
         Args:
@@ -324,6 +331,8 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L
             tag (str): The tag to use for logging
 
+            sync_dist (bool): enable distributed sync of logs
+
         Returns:
             List[torch.Tensor]: The list of metric values
 
@@ -356,6 +365,7 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L
                     on_step=False,
                     logger=True,
                     prog_bar=False,
+                    sync_dist=sync_dist,
                 )
                 _metrics.append(_metric)
             avg_metric = torch.stack(_metrics, dim=0).sum()
@@ -379,6 +389,7 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L
                     on_step=False,
                     logger=True,
                     prog_bar=False,
+                    sync_dist=sync_dist,
                 )
                 _metrics.append(_metric)
             start_index = end_index
@@ -391,6 +402,7 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L
             on_step=False,
             logger=True,
             prog_bar=True,
+            sync_dist=sync_dist,
         )
         return metrics
 
@@ -523,19 +535,19 @@ def validation_step(self, batch, batch_idx):
             # fetched from the batch
             y = batch["target"] if y is None else y
             y_hat = output["logits"]
-            self.calculate_loss(output, y, tag="valid")
-            self.calculate_metrics(y, y_hat, tag="valid")
+            self.calculate_loss(output, y, tag="valid", sync_dist=True)
+            self.calculate_metrics(y, y_hat, tag="valid", sync_dist=True)
         return y_hat, y
 
     def test_step(self, batch, batch_idx):
         with torch.no_grad():
             output, y = self.forward_pass(batch)
-            # y is not None for SSL task.Rest of the tasks target is
+            # y is not None for SSL task. Rest of the tasks target is
             # fetched from the batch
             y = batch["target"] if y is None else y
             y_hat = output["logits"]
-            self.calculate_loss(output, y, tag="test")
-            self.calculate_metrics(y, y_hat, tag="test")
+            self.calculate_loss(output, y, tag="test", sync_dist=True)
+            self.calculate_metrics(y, y_hat, tag="test", sync_dist=True)
         return y_hat, y
 
     def configure_optimizers(self):
diff --git a/src/pytorch_tabular/ssl_models/base_model.py b/src/pytorch_tabular/ssl_models/base_model.py
index 7db2b226..03b31313 100644
--- a/src/pytorch_tabular/ssl_models/base_model.py
+++ b/src/pytorch_tabular/ssl_models/base_model.py
@@ -136,11 +136,11 @@ def _setup_metrics(self):
         pass
 
     @abstractmethod
-    def calculate_loss(self, output, tag):
+    def calculate_loss(self, output, tag, sync_dist):
         pass
 
     @abstractmethod
-    def calculate_metrics(self, output, tag):
+    def calculate_metrics(self, output, tag, sync_dist):
         pass
 
     @abstractmethod
@@ -167,15 +167,15 @@ def training_step(self, batch, batch_idx):
     def validation_step(self, batch, batch_idx):
         with torch.no_grad():
             output = self.forward(batch)
-            self.calculate_loss(output, tag="valid")
-            self.calculate_metrics(output, tag="valid")
+            self.calculate_loss(output, tag="valid", sync_dist=True)
+            self.calculate_metrics(output, tag="valid", sync_dist=True)
         return output
 
     def test_step(self, batch, batch_idx):
         with torch.no_grad():
             output = self.forward(batch)
-            self.calculate_loss(output, tag="test")
-            self.calculate_metrics(output, tag="test")
+            self.calculate_loss(output, tag="test", sync_dist=True)
+            self.calculate_metrics(output, tag="test", sync_dist=True)
         return output
 
     def on_validation_epoch_end(self) -> None:
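For context on the `sync_dist=True` plumbing above: under DDP each process computes its own loss and metrics, and Lightning's `self.log(..., sync_dist=True)` reduces the value across ranks before logging, so checkpointing and early-stopping monitors see a global average rather than one shard's number. A minimal standalone sketch (not this repo's code; module and tensor shapes are made up):

```python
import pytorch_lightning as pl
import torch

class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # Without sync_dist, each rank would log only its local shard's loss.
        self.log("valid_loss", loss, on_epoch=True, sync_dist=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())
```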
diff --git a/src/pytorch_tabular/ssl_models/dae/dae.py b/src/pytorch_tabular/ssl_models/dae/dae.py
index 172586c4..6550018f 100644
--- a/src/pytorch_tabular/ssl_models/dae/dae.py
+++ b/src/pytorch_tabular/ssl_models/dae/dae.py
@@ -200,7 +200,7 @@ def forward(self, x: Dict):
         else:
             return z.features
 
-    def calculate_loss(self, output, tag):
+    def calculate_loss(self, output, tag, sync_dist=False):
         total_loss = 0
         for type_, out in output.items():
             if type_ == "categorical":
@@ -220,6 +220,7 @@ def calculate_loss(self, output, tag):
                 on_step=False,
                 logger=True,
                 prog_bar=False,
+                sync_dist=sync_dist,
             )
             total_loss += loss
         self.log(
@@ -230,10 +231,11 @@ def calculate_loss(self, output, tag):
             # on_step=False,
             logger=True,
             prog_bar=True,
+            sync_dist=sync_dist,
         )
         return total_loss
 
-    def calculate_metrics(self, output, tag):
+    def calculate_metrics(self, output, tag, sync_dist=False):
         pass
 
     def featurize(self, x: Dict):
diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py
index d4e7fc60..3d09bb2e 100644
--- a/src/pytorch_tabular/tabular_datamodule.py
+++ b/src/pytorch_tabular/tabular_datamodule.py
@@ -61,6 +61,7 @@ def __init__(
         self.task = task
         self.n = data.shape[0]
         self.target = target
+        self.index = data.index
        if target:
             self.y = data[target].astype(np.float32).values
             if isinstance(target, str):
@@ -87,11 +88,12 @@ def data(self):
             data = pd.DataFrame(
                 np.concatenate([self.categorical_X, self.continuous_X], axis=1),
                 columns=self.categorical_cols + self.continuous_cols,
+                index=self.index,
             )
         elif self.continuous_cols:
-            data = pd.DataFrame(self.continuous_X, columns=self.continuous_cols)
+            data = pd.DataFrame(self.continuous_X, columns=self.continuous_cols, index=self.index)
         elif self.categorical_cols:
-            data = pd.DataFrame(self.categorical_X, columns=self.categorical_cols)
+            data = pd.DataFrame(self.categorical_X, columns=self.categorical_cols, index=self.index)
         else:
             data = pd.DataFrame()
         for i, t in enumerate(self.target):
@@ -474,6 +476,7 @@ def _cache_dataset(self):
             target=self.target,
         )
         self.train = None
+
         validation_dataset = TabularDataset(
             task=self.config.task,
             data=self.validation,
@@ -484,8 +487,10 @@ def _cache_dataset(self):
         self.validation = None
 
         if self.cache_mode is self.CACHE_MODES.DISK:
-            torch.save(train_dataset, self.cache_dir / "train_dataset")
-            torch.save(validation_dataset, self.cache_dir / "validation_dataset")
+            torch.save(train_dataset, self.cache_dir / "train_dataset", pickle_protocol=self.config.pickle_protocol)
+            torch.save(
+                validation_dataset, self.cache_dir / "validation_dataset", pickle_protocol=self.config.pickle_protocol
+            )
         elif self.cache_mode is self.CACHE_MODES.MEMORY:
             self.train_dataset = train_dataset
             self.validation_dataset = validation_dataset
diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py
index 4b3fd36a..0b34adf4 100644
--- a/src/pytorch_tabular/tabular_model.py
+++ b/src/pytorch_tabular/tabular_model.py
@@ -31,6 +31,7 @@
 )
 from pytorch_lightning.tuner.tuning import Tuner
 from pytorch_lightning.utilities.model_summary import summarize
+from pytorch_lightning.utilities.rank_zero import rank_zero_only
 from rich import print as rich_print
 from rich.pretty import pprint
 from sklearn.base import TransformerMixin
@@ -685,6 +686,8 @@ def train(
                     "/n" + "Original Error: " + oom_handler.oom_msg
                 )
         self._is_fitted = True
+        if self.track_experiment and self.config.log_target == "wandb":
+            self.logger.experiment.unwatch(self.model)
         if self.verbose:
             logger.info("Training the model completed")
         if self.config.load_best:
@@ -1522,6 +1525,7 @@ def add_noise(module, input, output):
         )
         return pred_df
 
+    @rank_zero_only
     def load_best_model(self) -> None:
         """Loads the best model after training is done."""
         if self.trainer.checkpoint_callback is not None:
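Finally, the `rank_zero_only` decorator applied to `load_best_model` makes the method run on a single process under DDP, so the best checkpoint is restored once instead of by every rank. A tiny illustrative example (the function name here is hypothetical; the import is the same one this diff adds):

```python
from pytorch_lightning.utilities.rank_zero import rank_zero_only

@rank_zero_only
def report(msg: str) -> None:
    # Executes on global rank 0 only; on other ranks the call returns None.
    print(msg)

report("best model restored")
```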