@@ -26,12 +26,15 @@ def _read_yaml(filename):
2626 "tag:yaml.org,2002:float" ,
2727 re .compile (
2828 """^(?:
29- [-+]?(?:[0-9][0-9_]*)\\ .[0-9_]*(?:[eE][-+]?[0-9]+)?
30- |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
31- |\\ .[0-9_]+(?:[eE][-+][0-9]+)?
32- |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\ .[0-9_]*
33- |[-+]?\\ .(?:inf|Inf|INF)
34- |\\ .(?:nan|NaN|NAN))$""" ,
29+
30+ [-+]?(?:[0-9][0-9_]*)\\ .[0-9_]*(?:[eE][-+]?[0-9]+)?
31+ |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
32+ |\\ .[0-9_]+(?:[eE][-+][0-9]+)?
33+ |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\ .[0-9_]*
34+ |[-+]?\\ .(?:inf|Inf|INF)
35+ |\\ .(?:nan|NaN|NAN))$
36+
37+ """ ,
3538 re .X ,
3639 ),
3740 list ("-+0123456789." ),
@@ -192,9 +195,9 @@ class DataConfig:
192195 )
193196
194197 def __post_init__ (self ):
195- assert (
196- len ( self . categorical_cols ) + len ( self . continuous_cols ) + len ( self . date_columns ) > 0
197- ), "There should be at-least one feature defined in categorical, continuous, or date columns"
198+ assert len ( self . categorical_cols ) + len ( self . continuous_cols ) + len ( self . date_columns ) > 0 , (
199+ "There should be at-least one feature defined in categorical, continuous, or date columns"
200+ )
198201 _validate_choices (self )
199202 if os .name == "nt" and self .num_workers != 0 :
200203 print ("Windows does not support num_workers > 0. Setting num_workers to 0" )
@@ -255,9 +258,9 @@ class InferredConfig:
255258
256259 def __post_init__ (self ):
257260 if self .embedding_dims is not None :
258- assert all (
259- ( isinstance ( t , Iterable ) and len ( t ) == 2 ) for t in self . embedding_dims
260- ), "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
261+ assert all (( isinstance ( t , Iterable ) and len ( t ) == 2 ) for t in self . embedding_dims ), (
262+ "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
263+ )
261264 self .embedded_cat_dim = sum ([t [1 ] for t in self .embedding_dims ])
262265 else :
263266 self .embedded_cat_dim = 0
@@ -581,24 +584,25 @@ def __post_init__(self):
581584
582585@dataclass
583586class ExperimentConfig :
584- """Experiment configuration. Experiment Tracking with WandB and Tensorboard.
587+ """Experiment configuration.
585588
586- Args:
587- project_name (str): The name of the project under which all runs will be logged. For Tensorboard
588- this defines the folder under which the logs will be saved and for W&B it defines the project name
589+ Experiment Tracking with WandB and Tensorboard.
590+ Args:
591+ project_name (str): The name of the project under which all runs will be logged. For Tensorboard
592+ this defines the folder under which the logs will be saved and for W&B it defines the project name
589593
590- run_name (Optional[str]): The name of the run; a specific identifier to recognize the run. If left
591- blank, will be assigned an auto-generated name
594+ run_name (Optional[str]): The name of the run; a specific identifier to recognize the run. If left
595+ blank, will be assigned an auto-generated name
592596
593- exp_watch (Optional[str]): The level of logging required. Can be `gradients`, `parameters`, `all`
594- or `None`. Defaults to None. Choices are: [`gradients`,`parameters`,`all`,`None`].
597+ exp_watch (Optional[str]): The level of logging required. Can be `gradients`, `parameters`, `all`
598+ or `None`. Defaults to None. Choices are: [`gradients`,`parameters`,`all`,`None`].
595599
596- log_target (str): Determines where logging happens - Tensorboard or W&B. Choices are:
597- [`wandb`,`tensorboard`].
600+ log_target (str): Determines where logging happens - Tensorboard or W&B. Choices are:
601+ [`wandb`,`tensorboard`].
598602
599- log_logits (bool): Turn this on to log the logits as a histogram in W&B
603+ log_logits (bool): Turn this on to log the logits as a histogram in W&B
600604
601- exp_log_freq (int): step count between logging of gradients and parameters.
605+ exp_log_freq (int): step count between logging of gradients and parameters.
602606
603607 """
604608
@@ -730,8 +734,8 @@ def __init__(
730734 self ,
731735 exp_version_manager : str = ".pt_tmp/exp_version_manager.yml" ,
732736 ) -> None :
733- """The manages the versions of the experiments based on the name. It is a simple dictionary(yaml) based lookup.
734- Primary purpose is to avoid overwriting of saved models while running the training without changing the
737+ """The manages the versions of the experiments based on the name. Primary purpose is to avoid overwriting of
738+ saved models while running the training without changing the It is a simple dictionary(yaml) based lookup.
735739 experiment name.
736740
737741 Args:
0 commit comments