Commit a196f80

Add lr scheduler interval config
1 parent 6cc6da1 commit a196f80

3 files changed: +28 −14 lines

src/pytorch_tabular/config/config.py

Lines changed: 14 additions & 6 deletions
```diff
@@ -192,9 +192,9 @@ class DataConfig:
     )
 
     def __post_init__(self):
-        assert (
-            len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0
-        ), "There should be at-least one feature defined in categorical, continuous, or date columns"
+        assert len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0, (
+            "There should be at-least one feature defined in categorical, continuous, or date columns"
+        )
         _validate_choices(self)
         if os.name == "nt" and self.num_workers != 0:
             print("Windows does not support num_workers > 0. Setting num_workers to 0")
@@ -255,9 +255,9 @@ class InferredConfig:
 
     def __post_init__(self):
         if self.embedding_dims is not None:
-            assert all(
-                (isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims
-            ), "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
+            assert all((isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims), (
+                "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
+            )
             self.embedded_cat_dim = sum([t[1] for t in self.embedding_dims])
         else:
             self.embedded_cat_dim = 0
@@ -677,6 +677,9 @@ class OptimizerConfig:
         lr_scheduler_monitor_metric (Optional[str]): Used with ReduceLROnPlateau, where the plateau is
             decided based on this metric
 
+        lr_scheduler_interval (Optional[str]): Interval at which to step the LR Scheduler, one of "epoch"
+            or "step". Defaults to `epoch`.
+
     """
 
     optimizer: str = field(
@@ -709,6 +712,11 @@ class OptimizerConfig:
         metadata={"help": "Used with ReduceLROnPlateau, where the plateau is decided based on this metric"},
     )
 
+    lr_scheduler_interval: Optional[str] = field(
+        default="epoch",
+        metadata={"help": "Interval at which to step the LR Scheduler, one of `epoch` or `step`. Defaults to `epoch`."},
+    )
+
     @staticmethod
     def read_from_yaml(filename: str = "config/optimizer_config.yml"):
         config = _read_yaml(filename)
```
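With this change, the new field sits in `OptimizerConfig` next to the existing scheduler options. A minimal sketch of opting into per-batch scheduling; the `OneCycleLR` choice and its parameters below are illustrative placeholders, not part of this commit:

```python
from pytorch_tabular.config import OptimizerConfig

# Sketch: "OneCycleLR" and its params are illustrative placeholders.
optimizer_config = OptimizerConfig(
    optimizer="Adam",
    lr_scheduler="OneCycleLR",
    lr_scheduler_params={"max_lr": 1e-2, "total_steps": 1000},
    lr_scheduler_interval="step",  # new field: step per batch instead of per epoch
)
```

Schedulers like `OneCycleLR` expect to be stepped every batch, which is exactly what the previous hard-coded per-epoch behavior could not express.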

src/pytorch_tabular/models/base_model.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -588,8 +588,11 @@ def configure_optimizers(self):
             }
             return {
                 "optimizer": opt,
-                "lr_scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params),
-                "monitor": self.hparams.lr_scheduler_monitor_metric,
+                "lr_scheduler": {
+                    "scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params),
+                    "monitor": self.hparams.lr_scheduler_monitor_metric,
+                    "interval": self.hparams.lr_scheduler_interval,
+                },
             }
         else:
             return opt
```
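This restructuring follows PyTorch Lightning's scheduler-config contract: instead of a bare scheduler plus a top-level `monitor` key, `configure_optimizers` now returns a nested dict whose `interval` key tells Lightning whether to call `scheduler.step()` once per epoch or once per batch. A standalone sketch of that contract; the model and scheduler choices are illustrative:

```python
import torch
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)  # placeholder model

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr=1e-3)
        sched = torch.optim.lr_scheduler.StepLR(opt, step_size=10)
        return {
            "optimizer": opt,
            "lr_scheduler": {
                "scheduler": sched,
                "interval": "step",  # "epoch" would step once per epoch instead
                # "monitor" is only needed for ReduceLROnPlateau-style schedulers
            },
        }
```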

src/pytorch_tabular/ssl_models/base_model.py

Lines changed: 9 additions & 6 deletions
```diff
@@ -85,9 +85,9 @@ def __init__(
         self._setup_metrics()
 
     def _setup_encoder_decoder(self, encoder, encoder_config, decoder, decoder_config, inferred_config):
-        assert (encoder is not None) or (
-            encoder_config is not None
-        ), "Either encoder or encoder_config must be provided"
+        assert (encoder is not None) or (encoder_config is not None), (
+            "Either encoder or encoder_config must be provided"
+        )
         # assert (decoder is not None) or (decoder_config is not None),
         # "Either decoder or decoder_config must be provided"
         if encoder is not None:
@@ -181,7 +181,7 @@ def test_step(self, batch, batch_idx):
     def on_validation_epoch_end(self) -> None:
         if hasattr(self.hparams, "log_logits") and self.hparams.log_logits:
             warnings.warn(
-                "Logging Logits is disabled for SSL tasks. Set `log_logits` to False" " to turn off this warning"
+                "Logging Logits is disabled for SSL tasks. Set `log_logits` to False to turn off this warning"
             )
         super().on_validation_epoch_end()
 
@@ -219,8 +219,11 @@ def configure_optimizers(self):
             }
             return {
                 "optimizer": opt,
-                "lr_scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params),
-                "monitor": self.hparams.lr_scheduler_monitor_metric,
+                "lr_scheduler": {
+                    "scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params),
+                    "monitor": self.hparams.lr_scheduler_monitor_metric,
+                    "interval": self.hparams.lr_scheduler_interval,
+                },
             }
         else:
             return opt
```
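Both the supervised and SSL base models read the same `lr_scheduler_interval` hparam, so the option flows through the standard `TabularModel` workflow unchanged. A hedged end-to-end sketch, assuming the documented pytorch_tabular API; column names, model choice, and hyperparameters are placeholders:

```python
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import CategoryEmbeddingModelConfig

tabular_model = TabularModel(
    data_config=DataConfig(target=["target"], continuous_cols=["x1", "x2"]),
    model_config=CategoryEmbeddingModelConfig(task="regression"),
    optimizer_config=OptimizerConfig(
        lr_scheduler="OneCycleLR",
        lr_scheduler_params={"max_lr": 1e-2, "total_steps": 1000},
        lr_scheduler_interval="step",  # exercises the new option end to end
    ),
    trainer_config=TrainerConfig(max_epochs=5),
)
tabular_model.fit(train=train_df)  # train_df: a pandas DataFrame (placeholder)
```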
