@@ -192,9 +192,9 @@ class DataConfig:
192192 )
193193
194194 def __post_init__ (self ):
195- assert (
196- len ( self . categorical_cols ) + len ( self . continuous_cols ) + len ( self . date_columns ) > 0
197- ), "There should be at-least one feature defined in categorical, continuous, or date columns"
195+ assert len ( self . categorical_cols ) + len ( self . continuous_cols ) + len ( self . date_columns ) > 0 , (
196+ "There should be at-least one feature defined in categorical, continuous, or date columns"
197+ )
198198 _validate_choices (self )
199199 if os .name == "nt" and self .num_workers != 0 :
200200 print ("Windows does not support num_workers > 0. Setting num_workers to 0" )
@@ -255,9 +255,9 @@ class InferredConfig:
255255
256256 def __post_init__ (self ):
257257 if self .embedding_dims is not None :
258- assert all (
259- ( isinstance ( t , Iterable ) and len ( t ) == 2 ) for t in self . embedding_dims
260- ), "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
258+ assert all (( isinstance ( t , Iterable ) and len ( t ) == 2 ) for t in self . embedding_dims ), (
259+ "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
260+ )
261261 self .embedded_cat_dim = sum ([t [1 ] for t in self .embedding_dims ])
262262 else :
263263 self .embedded_cat_dim = 0
@@ -677,6 +677,9 @@ class OptimizerConfig:
677677 lr_scheduler_monitor_metric (Optional[str]): Used with ReduceLROnPlateau, where the plateau is
678678 decided based on this metric
679679
680+ lr_scheduler_interval (Optional[str]): Interval at which to step the LR Scheduler, one of "epoch"
681+ or "step". Defaults to `epoch`.
682+
680683 """
681684
682685 optimizer : str = field (
@@ -709,6 +712,11 @@ class OptimizerConfig:
709712 metadata = {"help" : "Used with ReduceLROnPlateau, where the plateau is decided based on this metric" },
710713 )
711714
715+ lr_scheduler_interval : Optional [str ] = field (
716+ default = "epoch" ,
717+ metadata = {"help" : "Interval at which to step the LR Scheduler, one of `epoch` or `step`. Defaults to `epoch`." },
718+ )
719+
712720 @staticmethod
713721 def read_from_yaml (filename : str = "config/optimizer_config.yml" ):
714722 config = _read_yaml (filename )
0 commit comments