From a3a85051dd3a4e4bc04bf2bd86223a90b48fe6d8 Mon Sep 17 00:00:00 2001 From: Manu Joseph Date: Mon, 25 Nov 2024 11:06:55 +0530 Subject: [PATCH] fix for ssl finetuning bug --- src/pytorch_tabular/tabular_datamodule.py | 2 +- src/pytorch_tabular/tabular_model.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py index 81f13a69..d4e7fc60 100644 --- a/src/pytorch_tabular/tabular_datamodule.py +++ b/src/pytorch_tabular/tabular_datamodule.py @@ -526,7 +526,7 @@ def setup(self, stage: Optional[str] = None) -> None: else: self.validation = self.validation.copy() # Preprocessing Train, Validation - self.train, _ = self.preprocess_data(self.train, stage="fit" if not is_ssl else "inference") + self.train, _ = self.preprocess_data(self.train, stage="inference" if is_ssl else "fit") self.validation, _ = self.preprocess_data(self.validation, stage="inference") self._fitted = True self._cache_dataset() diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py index 217e7b30..4b3fd36a 100644 --- a/src/pytorch_tabular/tabular_model.py +++ b/src/pytorch_tabular/tabular_model.py @@ -1001,13 +1001,15 @@ def create_finetune_model( logger.info("Renaming the experiment run for finetuning as" f" {config['run_name'] + '_finetuned'}") config["run_name"] = config["run_name"] + "_finetuned" + config_override = {"target": target} if target is not None else {} + config_override["task"] = task datamodule = self.datamodule.copy( train=train, validation=validation, target_transform=target_transform, train_sampler=train_sampler, seed=seed, - config_override={"target": target} if target is not None else {}, + config_override=config_override, ) model_callable = _GenericModel inferred_config = OmegaConf.structured(datamodule._inferred_config)