2626from pandas import DataFrame
2727from pytorch_lightning import seed_everything
2828from pytorch_lightning .callbacks import RichProgressBar
29- from pytorch_lightning .callbacks .gradient_accumulation_scheduler import (
30- GradientAccumulationScheduler ,
31- )
29+ from pytorch_lightning .callbacks .gradient_accumulation_scheduler import GradientAccumulationScheduler
3230from pytorch_lightning .tuner .tuning import Tuner
3331from pytorch_lightning .utilities .model_summary import summarize
3432from pytorch_lightning .utilities .rank_zero import rank_zero_only
4846)
4947from pytorch_tabular .config .config import InferredConfig
5048from pytorch_tabular .models .base_model import BaseModel , _CaptumModel , _GenericModel
51- from pytorch_tabular .models .common .layers .embeddings import (
52- Embedding1dLayer ,
53- Embedding2dLayer ,
54- PreEncoded1dLayer ,
55- )
49+ from pytorch_tabular .models .common .layers .embeddings import Embedding1dLayer , Embedding2dLayer , PreEncoded1dLayer
5650from pytorch_tabular .tabular_datamodule import TabularDatamodule
5751from pytorch_tabular .utils import (
5852 OOMException ,
@@ -140,7 +134,7 @@ def __init__(
140134 optimizer_config = self ._read_parse_config (optimizer_config , OptimizerConfig )
141135 if model_config .task != "ssl" :
142136 assert data_config .target is not None , (
143- "`target` in data_config should not be None for" f" { model_config .task } task"
137+ f "`target` in data_config should not be None for { model_config .task } task"
144138 )
145139 if experiment_config is None :
146140 if self .verbose :
@@ -284,9 +278,7 @@ def _setup_experiment_tracking(self):
284278 offline = False ,
285279 )
286280 else :
287- raise NotImplementedError (
288- f"{ self .config .log_target } is not implemented. Try one of [wandb," " tensorboard]"
289- )
281+ raise NotImplementedError (f"{ self .config .log_target } is not implemented. Try one of [wandb, tensorboard]" )
290282
291283 def _prepare_callbacks (self , callbacks = None ) -> List :
292284 """Prepares the necesary callbacks to the Trainer based on the configuration.
@@ -374,11 +366,9 @@ def _check_and_set_target_transform(self, target_transform):
374366 elif isinstance (target_transform , TransformerMixin ):
375367 pass
376368 else :
377- raise ValueError (
378- "`target_transform` should wither be an sklearn Transformer or a" " tuple of callables."
379- )
369+ raise ValueError ("`target_transform` should wither be an sklearn Transformer or a tuple of callables." )
380370 if self .config .task == "classification" and target_transform is not None :
381- logger .warning ("For classification task, target transform is not used. Ignoring the" " parameter" )
371+ logger .warning ("For classification task, target transform is not used. Ignoring the parameter" )
382372 target_transform = None
383373 return target_transform
384374
@@ -674,6 +664,8 @@ def train(
674664 self .model .reset_weights ()
675665 # Parameters in models needs to be initialized again after LR find
676666 self .model .data_aware_initialization (self .datamodule )
667+ # Update the Trainer to use the suggested LR
668+ self ._prepare_for_training (self .model , self .datamodule , callbacks , max_epochs , min_epochs )
677669 self .model .train ()
678670 if self .verbose :
679671 logger .info ("Training Started" )
@@ -772,12 +764,12 @@ def fit(
772764
773765 """
774766 assert self .config .task != "ssl" , (
775- "`fit` is not valid for SSL task. Please use `pretrain` for" " semi-supervised learning"
767+ "`fit` is not valid for SSL task. Please use `pretrain` for semi-supervised learning"
776768 )
777769 if metrics is not None :
778- assert len (metrics ) == len (
779- metrics_prob_inputs or []
780- ), "The length of `metrics` and `metrics_prob_inputs` should be equal"
770+ assert len (metrics ) == len (metrics_prob_inputs or []), (
771+ "The length of `metrics` and ` metrics_prob_inputs` should be equal"
772+ )
781773 seed = seed or self .config .seed
782774 if seed :
783775 seed_everything (seed )
@@ -855,7 +847,7 @@ def pretrain(
855847
856848 """
857849 assert self .config .task == "ssl" , (
858- f"`pretrain` is not valid for { self .config .task } task. Please use `fit`" " instead."
850+ f"`pretrain` is not valid for { self .config .task } task. Please use `fit` instead."
859851 )
860852 seed = seed or self .config .seed
861853 if seed :
@@ -976,9 +968,9 @@ def create_finetune_model(
976968 config = self .config
977969 optimizer_params = optimizer_params or {}
978970 if target is None :
979- assert (
980- hasattr ( config , " target" ) and config . target is not None
981- ), "`target` cannot be None if it was not set in the initial `DataConfig`"
971+ assert hasattr ( config , "target" ) and config . target is not None , (
972+ "` target` cannot be None if it was not set in the initial `DataConfig`"
973+ )
982974 else :
983975 assert isinstance (target , list ), "`target` should be a list of strings"
984976 config .target = target
@@ -1001,7 +993,7 @@ def create_finetune_model(
1001993 if self .track_experiment :
1002994 # Renaming the experiment run so that a different log is created for finetuning
1003995 if self .verbose :
1004- logger .info ("Renaming the experiment run for finetuning as" f" { config ['run_name' ] + '_finetuned' } " )
996+ logger .info (f "Renaming the experiment run for finetuning as { config ['run_name' ] + '_finetuned' } " )
1005997 config ["run_name" ] = config ["run_name" ] + "_finetuned"
1006998
1007999 config_override = {"target" : target } if target is not None else {}
@@ -1106,7 +1098,7 @@ def finetune(
11061098
11071099 """
11081100 assert self ._is_finetune_model , (
1109- "finetune() can only be called on a finetune model created using" " `TabularModel.create_finetune_model()`"
1101+ "finetune() can only be called on a finetune model created using `TabularModel.create_finetune_model()`"
11101102 )
11111103 seed_everything (self .config .seed )
11121104 if freeze_backbone :
@@ -1294,15 +1286,15 @@ def _format_predicitons(
12941286 )
12951287 if is_probabilistic :
12961288 for j , q in enumerate (quantiles ):
1297- col_ = f"{ target_col } _q{ int (q * 100 )} "
1289+ col_ = f"{ target_col } _q{ int (q * 100 )} "
12981290 pred_df [col_ ] = self .datamodule .target_transforms [i ].inverse_transform (
12991291 quantile_predictions [:, j , i ].reshape (- 1 , 1 )
13001292 )
13011293 else :
13021294 pred_df [f"{ target_col } _prediction" ] = point_predictions [:, i ]
13031295 if is_probabilistic :
13041296 for j , q in enumerate (quantiles ):
1305- pred_df [f"{ target_col } _q{ int (q * 100 )} " ] = quantile_predictions [:, j , i ].reshape (- 1 , 1 )
1297+ pred_df [f"{ target_col } _q{ int (q * 100 )} " ] = quantile_predictions [:, j , i ].reshape (- 1 , 1 )
13061298
13071299 elif self .config .task == "classification" :
13081300 start_index = 0
@@ -1483,7 +1475,7 @@ def predict(
14831475 "min" ,
14841476 "max" ,
14851477 "hard_voting" ,
1486- ], "aggregate should be one of 'mean', 'median', 'min', 'max', or" " 'hard_voting'"
1478+ ], "aggregate should be one of 'mean', 'median', 'min', 'max', or 'hard_voting'"
14871479 if self .config .task == "regression" :
14881480 assert aggregate_tta != "hard_voting" , "hard_voting is only available for classification"
14891481
@@ -1538,11 +1530,9 @@ def load_best_model(self) -> None:
15381530 ckpt = pl_load (ckpt_path , map_location = lambda storage , loc : storage )
15391531 self .model .load_state_dict (ckpt ["state_dict" ])
15401532 else :
1541- logger .warning ("No best model available to load. Did you run it more than 1" " epoch?..." )
1533+ logger .warning ("No best model available to load. Did you run it more than 1 epoch?..." )
15421534 else :
1543- logger .warning (
1544- "No best model available to load. Checkpoint Callback needs to be" " enabled for this to work"
1545- )
1535+ logger .warning ("No best model available to load. Checkpoint Callback needs to be enabled for this to work" )
15461536
15471537 def save_datamodule (self , dir : str , inference_only : bool = False ) -> None :
15481538 """Saves the datamodule in the specified directory.
@@ -1707,7 +1697,7 @@ def ret_summary(self, model=None, max_depth: int = -1) -> str:
17071697 summary_str += "Config\n "
17081698 summary_str += "-" * 100 + "\n "
17091699 summary_str += pformat (self .config .__dict__ ["_content" ], indent = 4 , width = 80 , compact = True )
1710- summary_str += "\n Full Model Summary once model has been " " initialized or passed in as an argument"
1700+ summary_str += "\n Full Model Summary once model has been initialized or passed in as an argument"
17111701 return summary_str
17121702
17131703 def __str__ (self ) -> str :
@@ -1936,9 +1926,7 @@ def _prepare_baselines_captum(
19361926 else :
19371927 baselines = baselines .mean (dim = 0 , keepdim = True )
19381928 else :
1939- raise ValueError (
1940- "Invalid value for `baselines`. Please refer to the documentation" " for more details."
1941- )
1929+ raise ValueError ("Invalid value for `baselines`. Please refer to the documentation for more details." )
19421930 return baselines
19431931
19441932 def _handle_categorical_embeddings_attributions (
@@ -2061,9 +2049,7 @@ def explain(
20612049 hasattr (self .model .hparams , "embedding_dims" ) and self .model .hparams .embedding_dims is not None
20622050 )
20632051 if (not is_embedding1d ) and (not is_embedding2d ):
2064- raise NotImplementedError (
2065- "Attributions are not implemented for models with this type of" " embedding layer"
2066- )
2052+ raise NotImplementedError ("Attributions are not implemented for models with this type of embedding layer" )
20672053 test_dl = self .datamodule .prepare_inference_dataloader (data )
20682054 self .model .eval ()
20692055 # prepare import for Captum
@@ -2095,7 +2081,7 @@ def explain(
20952081 "Something went wrong. The number of features in the attributions"
20962082 f" ({ attributions .shape [1 ]} ) does not match the number of features in"
20972083 " the model"
2098- f" ({ self .model .hparams .continuous_dim + self .model .hparams .categorical_dim } )"
2084+ f" ({ self .model .hparams .continuous_dim + self .model .hparams .categorical_dim } )"
20992085 )
21002086 return pd .DataFrame (
21012087 attributions .detach ().cpu ().numpy (),
@@ -2215,7 +2201,7 @@ def cross_validate(
22152201 oof_preds = []
22162202 for fold , (train_idx , val_idx ) in it :
22172203 if verbose :
2218- logger .info (f"Running Fold { fold + 1 } /{ cv .get_n_splits ()} " )
2204+ logger .info (f"Running Fold { fold + 1 } /{ cv .get_n_splits ()} " )
22192205 # train_fold = train.iloc[train_idx]
22202206 # val_fold = train.iloc[val_idx]
22212207 if reset_datamodule :
@@ -2247,7 +2233,7 @@ def cross_validate(
22472233 result = self .evaluate (train .iloc [val_idx ], verbose = False )
22482234 cv_metrics .append (result [0 ][metric ])
22492235 if verbose :
2250- logger .info (f"Fold { fold + 1 } /{ cv .get_n_splits ()} score: { cv_metrics [- 1 ]} " )
2236+ logger .info (f"Fold { fold + 1 } /{ cv .get_n_splits ()} score: { cv_metrics [- 1 ]} " )
22512237 self .model .reset_weights ()
22522238 return cv_metrics , oof_preds
22532239
@@ -2376,7 +2362,7 @@ def bagging_predict(
23762362 ], "Bagging is only available for classification and regression"
23772363 if not callable (aggregate ):
23782364 assert aggregate in ["mean" , "median" , "min" , "max" , "hard_voting" ], (
2379- "aggregate should be one of 'mean', 'median', 'min', 'max', or" " 'hard_voting'"
2365+ "aggregate should be one of 'mean', 'median', 'min', 'max', or 'hard_voting'"
23802366 )
23812367 if self .config .task == "regression" :
23822368 assert aggregate != "hard_voting" , "hard_voting is only available for classification"
@@ -2387,7 +2373,7 @@ def bagging_predict(
23872373 model = None
23882374 for fold , (train_idx , val_idx ) in enumerate (cv .split (train , y = train [self .config .target ], groups = groups )):
23892375 if verbose :
2390- logger .info (f"Running Fold { fold + 1 } /{ cv .get_n_splits ()} " )
2376+ logger .info (f"Running Fold { fold + 1 } /{ cv .get_n_splits ()} " )
23912377 train_fold = train .iloc [train_idx ]
23922378 val_fold = train .iloc [val_idx ]
23932379 if reset_datamodule :
@@ -2412,7 +2398,7 @@ def bagging_predict(
24122398 elif self .config .task == "regression" :
24132399 pred_prob_l .append (fold_preds .values )
24142400 if verbose :
2415- logger .info (f"Fold { fold + 1 } /{ cv .get_n_splits ()} prediction done" )
2401+ logger .info (f"Fold { fold + 1 } /{ cv .get_n_splits ()} prediction done" )
24162402 self .model .reset_weights ()
24172403 pred_df = self ._combine_predictions (pred_prob_l , pred_idx , aggregate , weights )
24182404 if return_raw_predictions :
0 commit comments