Commit 7970abf

sorenmacbeth, pre-commit-ci[bot], and manujosephv authored
Optimizer lr scheduler interval (#545)
* Make tensor dtypes `np.float32` for MPS devices

  numpy defaults to numpy.float64 when they should be numpy.float32. This caused training to fail on MPS devices, but it works on my M1 with this change.

* Add lr scheduler interval config

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  For more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Manu Joseph V <manujosephv@gmail.com>
1 parent 452aa36 commit 7970abf

File tree

3 files changed: +19 −5 lines changed

src/pytorch_tabular/config/config.py

Lines changed: 8 additions & 0 deletions
@@ -677,6 +677,9 @@ class OptimizerConfig:
         lr_scheduler_monitor_metric (Optional[str]): Used with ReduceLROnPlateau, where the plateau is
             decided based on this metric
 
+        lr_scheduler_interval (Optional[str]): Interval at which to step the LR Scheduler, one of "epoch"
+            or "step". Defaults to `epoch`.
+
     """
 
     optimizer: str = field(
@@ -709,6 +712,11 @@ class OptimizerConfig:
         metadata={"help": "Used with ReduceLROnPlateau, where the plateau is decided based on this metric"},
     )
 
+    lr_scheduler_interval: Optional[str] = field(
+        default="epoch",
+        metadata={"help": "Interval at which to step the LR Scheduler, one of `epoch` or `step`. Defaults to `epoch`."},
+    )
+
     @staticmethod
     def read_from_yaml(filename: str = "config/optimizer_config.yml"):
         config = _read_yaml(filename)
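For context, a minimal sketch of how the new option could be used, assuming the `OptimizerConfig` fields shown in this diff plus the existing `lr_scheduler` field; the scheduler choice and its parameters are illustrative, not part of this commit:

from pytorch_tabular.config import OptimizerConfig

# OneCycleLR adjusts the learning rate on every optimizer step,
# so it should be stepped with interval="step" rather than the
# default "epoch".
optimizer_config = OptimizerConfig(
    optimizer="Adam",
    lr_scheduler="OneCycleLR",
    lr_scheduler_params={"max_lr": 1e-2, "total_steps": 1000},
    lr_scheduler_interval="step",
)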

src/pytorch_tabular/models/base_model.py

Lines changed: 5 additions & 2 deletions
@@ -588,8 +588,11 @@ def configure_optimizers(self):
                 }
             return {
                 "optimizer": opt,
-                "lr_scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params),
-                "monitor": self.hparams.lr_scheduler_monitor_metric,
+                "lr_scheduler": {
+                    "scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params),
+                    "monitor": self.hparams.lr_scheduler_monitor_metric,
+                    "interval": self.hparams.lr_scheduler_interval,
+                },
             }
         else:
             return opt
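This change moves the scheduler from a flat entry to PyTorch Lightning's nested scheduler-dictionary format, which is what makes the stepping interval configurable at all. A self-contained sketch of that contract, using standard torch and Lightning APIs (the model and the metric name are placeholders):

import torch
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(16, 1)

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr=1e-3)
        sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min")
        return {
            "optimizer": opt,
            "lr_scheduler": {
                "scheduler": sched,
                "monitor": "valid_loss",  # metric ReduceLROnPlateau watches
                "interval": "epoch",  # "step" steps the scheduler every batch instead
            },
        }

Lightning steps the scheduler once per epoch or once per optimizer step depending on "interval"; before this commit the interval could not be set and "monitor" sat beside a flat "lr_scheduler" entry.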

src/pytorch_tabular/ssl_models/base_model.py

Lines changed: 6 additions & 3 deletions
@@ -181,7 +181,7 @@ def test_step(self, batch, batch_idx):
     def on_validation_epoch_end(self) -> None:
         if hasattr(self.hparams, "log_logits") and self.hparams.log_logits:
             warnings.warn(
-                "Logging Logits is disabled for SSL tasks. Set `log_logits` to False" " to turn off this warning"
+                "Logging Logits is disabled for SSL tasks. Set `log_logits` to False to turn off this warning"
             )
         super().on_validation_epoch_end()
 
@@ -219,8 +219,11 @@ def configure_optimizers(self):
                 }
             return {
                 "optimizer": opt,
-                "lr_scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params),
-                "monitor": self.hparams.lr_scheduler_monitor_metric,
+                "lr_scheduler": {
+                    "scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params),
+                    "monitor": self.hparams.lr_scheduler_monitor_metric,
+                    "interval": self.hparams.lr_scheduler_interval,
+                },
             }
         else:
             return opt
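Since `OptimizerConfig.read_from_yaml` (see the first file above) builds the config from a YAML file, the new key should be settable there as well. A hedged sketch, writing a hypothetical config/optimizer_config.yml inline so the example is self-contained; the scheduler and its `T_max` are illustrative, and it is assumed that `read_from_yaml` constructs an `OptimizerConfig` from the parsed mapping:

from pathlib import Path
from pytorch_tabular.config import OptimizerConfig

# Hypothetical optimizer config, written inline for the demo.
Path("config").mkdir(exist_ok=True)
Path("config/optimizer_config.yml").write_text(
    "optimizer: Adam\n"
    "lr_scheduler: CosineAnnealingLR\n"
    "lr_scheduler_params:\n"
    "  T_max: 10\n"
    "lr_scheduler_interval: epoch\n"
)
optimizer_config = OptimizerConfig.read_from_yaml("config/optimizer_config.yml")

Both the supervised (`models/base_model.py`) and SSL (`ssl_models/base_model.py`) base models then pass the configured interval straight through to Lightning, as the two diffs above show.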
