
Commit 74bae26

update Trainer to use the suggested LR after auto_lr_find
1 parent 6cc6da1 commit 74bae26

2 files changed: 33 additions & 47 deletions

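The functional change in this commit is in TabularModel.train: when auto_lr_find is enabled, the model weights are reset and re-initialized after the LR finder runs, and the Trainer is now re-prepared via _prepare_for_training so that training actually uses the suggested learning rate rather than the originally configured one. Below is a minimal usage sketch, not taken from this repository; it assumes the public pytorch_tabular API (DataConfig, TrainerConfig with auto_lr_find=True, TabularModel.fit) and an illustrative synthetic DataFrame.

# Minimal sketch (assumed API usage, not from this repo's tests or docs):
# enable auto_lr_find so fit() runs the LR finder before training.
import numpy as np
import pandas as pd

from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import CategoryEmbeddingModelConfig

rng = np.random.default_rng(42)
train_df = pd.DataFrame(
    {
        "num_a": rng.normal(size=1000),
        "num_b": rng.normal(size=1000),
        "target": rng.normal(size=1000),
    }
)

tabular_model = TabularModel(
    data_config=DataConfig(target=["target"], continuous_cols=["num_a", "num_b"]),
    model_config=CategoryEmbeddingModelConfig(task="regression", layers="32-16"),
    optimizer_config=OptimizerConfig(),
    trainer_config=TrainerConfig(auto_lr_find=True, max_epochs=5, batch_size=256),
)
# With auto_lr_find=True, fit() first runs the LR finder; after this commit the
# Trainer is rebuilt via _prepare_for_training so the suggested LR is actually used.
tabular_model.fit(train=train_df)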

src/pytorch_tabular/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@

 __author__ = """Manu Joseph"""
 __email__ = "manujosephv@gmail.com"
-__version__ = "1.1.0"
+__version__ = "1.1.1"

 from . import models, ssl_models
 from .categorical_encoders import CategoricalEmbeddingTransformer

src/pytorch_tabular/tabular_model.py

Lines changed: 32 additions & 46 deletions
@@ -26,9 +26,7 @@
 from pandas import DataFrame
 from pytorch_lightning import seed_everything
 from pytorch_lightning.callbacks import RichProgressBar
-from pytorch_lightning.callbacks.gradient_accumulation_scheduler import (
-    GradientAccumulationScheduler,
-)
+from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
 from pytorch_lightning.tuner.tuning import Tuner
 from pytorch_lightning.utilities.model_summary import summarize
 from pytorch_lightning.utilities.rank_zero import rank_zero_only
@@ -48,11 +46,7 @@
 )
 from pytorch_tabular.config.config import InferredConfig
 from pytorch_tabular.models.base_model import BaseModel, _CaptumModel, _GenericModel
-from pytorch_tabular.models.common.layers.embeddings import (
-    Embedding1dLayer,
-    Embedding2dLayer,
-    PreEncoded1dLayer,
-)
+from pytorch_tabular.models.common.layers.embeddings import Embedding1dLayer, Embedding2dLayer, PreEncoded1dLayer
 from pytorch_tabular.tabular_datamodule import TabularDatamodule
 from pytorch_tabular.utils import (
     OOMException,
@@ -140,7 +134,7 @@ def __init__(
         optimizer_config = self._read_parse_config(optimizer_config, OptimizerConfig)
         if model_config.task != "ssl":
             assert data_config.target is not None, (
-                "`target` in data_config should not be None for" f" {model_config.task} task"
+                f"`target` in data_config should not be None for {model_config.task} task"
             )
         if experiment_config is None:
             if self.verbose:
@@ -284,9 +278,7 @@ def _setup_experiment_tracking(self):
                 offline=False,
             )
         else:
-            raise NotImplementedError(
-                f"{self.config.log_target} is not implemented. Try one of [wandb," " tensorboard]"
-            )
+            raise NotImplementedError(f"{self.config.log_target} is not implemented. Try one of [wandb, tensorboard]")

     def _prepare_callbacks(self, callbacks=None) -> List:
         """Prepares the necesary callbacks to the Trainer based on the configuration.
@@ -374,11 +366,9 @@ def _check_and_set_target_transform(self, target_transform):
         elif isinstance(target_transform, TransformerMixin):
             pass
         else:
-            raise ValueError(
-                "`target_transform` should wither be an sklearn Transformer or a" " tuple of callables."
-            )
+            raise ValueError("`target_transform` should wither be an sklearn Transformer or a tuple of callables.")
         if self.config.task == "classification" and target_transform is not None:
-            logger.warning("For classification task, target transform is not used. Ignoring the" " parameter")
+            logger.warning("For classification task, target transform is not used. Ignoring the parameter")
             target_transform = None
         return target_transform

@@ -674,6 +664,8 @@ def train(
             self.model.reset_weights()
             # Parameters in models needs to be initialized again after LR find
             self.model.data_aware_initialization(self.datamodule)
+            # Update the Trainer to use the suggested LR
+            self._prepare_for_training(self.model, self.datamodule, callbacks, max_epochs, min_epochs)
         self.model.train()
         if self.verbose:
             logger.info("Training Started")
@@ -772,12 +764,12 @@ def fit(

         """
         assert self.config.task != "ssl", (
-            "`fit` is not valid for SSL task. Please use `pretrain` for" " semi-supervised learning"
+            "`fit` is not valid for SSL task. Please use `pretrain` for semi-supervised learning"
         )
         if metrics is not None:
-            assert len(metrics) == len(
-                metrics_prob_inputs or []
-            ), "The length of `metrics` and `metrics_prob_inputs` should be equal"
+            assert len(metrics) == len(metrics_prob_inputs or []), (
+                "The length of `metrics` and `metrics_prob_inputs` should be equal"
+            )
         seed = seed or self.config.seed
         if seed:
             seed_everything(seed)
@@ -855,7 +847,7 @@ def pretrain(

         """
         assert self.config.task == "ssl", (
-            f"`pretrain` is not valid for {self.config.task} task. Please use `fit`" " instead."
+            f"`pretrain` is not valid for {self.config.task} task. Please use `fit` instead."
         )
         seed = seed or self.config.seed
         if seed:
@@ -976,9 +968,9 @@ def create_finetune_model(
         config = self.config
         optimizer_params = optimizer_params or {}
         if target is None:
-            assert (
-                hasattr(config, "target") and config.target is not None
-            ), "`target` cannot be None if it was not set in the initial `DataConfig`"
+            assert hasattr(config, "target") and config.target is not None, (
+                "`target` cannot be None if it was not set in the initial `DataConfig`"
+            )
         else:
             assert isinstance(target, list), "`target` should be a list of strings"
             config.target = target
@@ -1001,7 +993,7 @@ def create_finetune_model(
         if self.track_experiment:
             # Renaming the experiment run so that a different log is created for finetuning
             if self.verbose:
-                logger.info("Renaming the experiment run for finetuning as" f" {config['run_name'] + '_finetuned'}")
+                logger.info(f"Renaming the experiment run for finetuning as {config['run_name'] + '_finetuned'}")
             config["run_name"] = config["run_name"] + "_finetuned"

         config_override = {"target": target} if target is not None else {}
@@ -1106,7 +1098,7 @@ def finetune(

         """
         assert self._is_finetune_model, (
-            "finetune() can only be called on a finetune model created using" " `TabularModel.create_finetune_model()`"
+            "finetune() can only be called on a finetune model created using `TabularModel.create_finetune_model()`"
         )
         seed_everything(self.config.seed)
         if freeze_backbone:
@@ -1294,15 +1286,15 @@ def _format_predicitons(
                     )
                     if is_probabilistic:
                         for j, q in enumerate(quantiles):
-                            col_ = f"{target_col}_q{int(q*100)}"
+                            col_ = f"{target_col}_q{int(q * 100)}"
                             pred_df[col_] = self.datamodule.target_transforms[i].inverse_transform(
                                 quantile_predictions[:, j, i].reshape(-1, 1)
                             )
                 else:
                     pred_df[f"{target_col}_prediction"] = point_predictions[:, i]
                     if is_probabilistic:
                         for j, q in enumerate(quantiles):
-                            pred_df[f"{target_col}_q{int(q*100)}"] = quantile_predictions[:, j, i].reshape(-1, 1)
+                            pred_df[f"{target_col}_q{int(q * 100)}"] = quantile_predictions[:, j, i].reshape(-1, 1)

         elif self.config.task == "classification":
             start_index = 0
@@ -1483,7 +1475,7 @@ def predict(
                 "min",
                 "max",
                 "hard_voting",
-            ], "aggregate should be one of 'mean', 'median', 'min', 'max', or" " 'hard_voting'"
+            ], "aggregate should be one of 'mean', 'median', 'min', 'max', or 'hard_voting'"
             if self.config.task == "regression":
                 assert aggregate_tta != "hard_voting", "hard_voting is only available for classification"

@@ -1538,11 +1530,9 @@ def load_best_model(self) -> None:
                 ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
                 self.model.load_state_dict(ckpt["state_dict"])
             else:
-                logger.warning("No best model available to load. Did you run it more than 1" " epoch?...")
+                logger.warning("No best model available to load. Did you run it more than 1 epoch?...")
         else:
-            logger.warning(
-                "No best model available to load. Checkpoint Callback needs to be" " enabled for this to work"
-            )
+            logger.warning("No best model available to load. Checkpoint Callback needs to be enabled for this to work")

     def save_datamodule(self, dir: str, inference_only: bool = False) -> None:
         """Saves the datamodule in the specified directory.
@@ -1707,7 +1697,7 @@ def ret_summary(self, model=None, max_depth: int = -1) -> str:
             summary_str += "Config\n"
             summary_str += "-" * 100 + "\n"
             summary_str += pformat(self.config.__dict__["_content"], indent=4, width=80, compact=True)
-            summary_str += "\nFull Model Summary once model has been " "initialized or passed in as an argument"
+            summary_str += "\nFull Model Summary once model has been initialized or passed in as an argument"
             return summary_str

     def __str__(self) -> str:
@@ -1936,9 +1926,7 @@ def _prepare_baselines_captum(
             else:
                 baselines = baselines.mean(dim=0, keepdim=True)
         else:
-            raise ValueError(
-                "Invalid value for `baselines`. Please refer to the documentation" " for more details."
-            )
+            raise ValueError("Invalid value for `baselines`. Please refer to the documentation for more details.")
         return baselines

     def _handle_categorical_embeddings_attributions(
@@ -2061,9 +2049,7 @@ def explain(
             hasattr(self.model.hparams, "embedding_dims") and self.model.hparams.embedding_dims is not None
         )
         if (not is_embedding1d) and (not is_embedding2d):
-            raise NotImplementedError(
-                "Attributions are not implemented for models with this type of" " embedding layer"
-            )
+            raise NotImplementedError("Attributions are not implemented for models with this type of embedding layer")
         test_dl = self.datamodule.prepare_inference_dataloader(data)
         self.model.eval()
         # prepare import for Captum
@@ -2095,7 +2081,7 @@ def explain(
                 "Something went wrong. The number of features in the attributions"
                 f" ({attributions.shape[1]}) does not match the number of features in"
                 " the model"
-                f" ({self.model.hparams.continuous_dim+self.model.hparams.categorical_dim})"
+                f" ({self.model.hparams.continuous_dim + self.model.hparams.categorical_dim})"
             )
             return pd.DataFrame(
                 attributions.detach().cpu().numpy(),
@@ -2215,7 +2201,7 @@ def cross_validate(
         oof_preds = []
         for fold, (train_idx, val_idx) in it:
             if verbose:
-                logger.info(f"Running Fold {fold+1}/{cv.get_n_splits()}")
+                logger.info(f"Running Fold {fold + 1}/{cv.get_n_splits()}")
             # train_fold = train.iloc[train_idx]
             # val_fold = train.iloc[val_idx]
             if reset_datamodule:
@@ -2247,7 +2233,7 @@ def cross_validate(
             result = self.evaluate(train.iloc[val_idx], verbose=False)
             cv_metrics.append(result[0][metric])
             if verbose:
-                logger.info(f"Fold {fold+1}/{cv.get_n_splits()} score: {cv_metrics[-1]}")
+                logger.info(f"Fold {fold + 1}/{cv.get_n_splits()} score: {cv_metrics[-1]}")
             self.model.reset_weights()
         return cv_metrics, oof_preds

@@ -2376,7 +2362,7 @@ def bagging_predict(
         ], "Bagging is only available for classification and regression"
         if not callable(aggregate):
             assert aggregate in ["mean", "median", "min", "max", "hard_voting"], (
-                "aggregate should be one of 'mean', 'median', 'min', 'max', or" " 'hard_voting'"
+                "aggregate should be one of 'mean', 'median', 'min', 'max', or 'hard_voting'"
             )
         if self.config.task == "regression":
             assert aggregate != "hard_voting", "hard_voting is only available for classification"
@@ -2387,7 +2373,7 @@ def bagging_predict(
         model = None
         for fold, (train_idx, val_idx) in enumerate(cv.split(train, y=train[self.config.target], groups=groups)):
             if verbose:
-                logger.info(f"Running Fold {fold+1}/{cv.get_n_splits()}")
+                logger.info(f"Running Fold {fold + 1}/{cv.get_n_splits()}")
             train_fold = train.iloc[train_idx]
             val_fold = train.iloc[val_idx]
             if reset_datamodule:
@@ -2412,7 +2398,7 @@ def bagging_predict(
             elif self.config.task == "regression":
                 pred_prob_l.append(fold_preds.values)
             if verbose:
-                logger.info(f"Fold {fold+1}/{cv.get_n_splits()} prediction done")
+                logger.info(f"Fold {fold + 1}/{cv.get_n_splits()} prediction done")
             self.model.reset_weights()
         pred_df = self._combine_predictions(pred_prob_l, pred_idx, aggregate, weights)
         if return_raw_predictions:
