Skip to content

Commit 40d8772

Browse files
ProgramadorArtificial, pre-commit-ci[bot], and manujosephv
authored
Add tuner return best model (#374)
* Add to tuner return best model

* Fix bug with progress bar in tuner

* [pre-commit.ci] auto fixes from pre-commit.com hooks — for more information, see https://pre-commit.ci

* Remove datamodule before deepcopy and change tuner output

* Update documentation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Manu Joseph V <manujosephv@gmail.com>
1 parent beb29d8 commit 40d8772

File tree

3 files changed

+47
-7
lines changed

3 files changed

+47
-7
lines changed

docs/tutorials/10-Hyperparameter Tuning.ipynb

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1396,11 +1396,12 @@
13961396
"cell_type": "markdown",
13971397
"metadata": {},
13981398
"source": [
1399-
"Result is a namedtuple with trials_df, best_params, and best_score\\\n",
1399+
"Result is a namedtuple with trials_df, best_params, best_score and best_model\\\n",
14001400
"\n",
14011401
"- trials_df: A dataframe with all the hyperparameter combinations and their corresponding scores\n",
14021402
"- best_params: The best hyperparameter combination\n",
1403-
"- best_score: The best score"
1403+
"- best_score: The best score\n",
1404+
"- best_model: If return_best_model is True, return best_model otherwise return None"
14041405
]
14051406
},
14061407
{
@@ -1895,11 +1896,12 @@
18951896
"cell_type": "markdown",
18961897
"metadata": {},
18971898
"source": [
1898-
"Result is a namedtuple with trials_df, best_params, and best_score\\\n",
1899+
"Result is a namedtuple with trials_df, best_params, best_score and best_model\\\n",
18991900
"\n",
19001901
"- trials_df: A dataframe with all the hyperparameter combinations and their corresponding scores\n",
19011902
"- best_params: The best hyperparameter combination\n",
1902-
"- best_score: The best score"
1903+
"- best_score: The best score\n",
1904+
"- best_model: If return_best_model is True, return best_model otherwise return None"
19031905
]
19041906
},
19051907
{

src/pytorch_tabular/tabular_model_sweep.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,11 @@ def model_sweep(
232232
verbose (bool, optional): If True, will print the progress. Defaults to True.
233233
234234
suppress_lightning_logger (bool, optional): If True, will suppress the lightning logger. Defaults to True.
235+
236+
Returns:
237+
results: Training results.
238+
239+
best_model: If return_best_model is True, return best_model otherwise return None.
235240
"""
236241
_validate_args(
237242
task=task,
@@ -386,4 +391,4 @@ def _init_tabular_model(m):
386391
best_model.datamodule = datamodule
387392
return results, best_model
388393
else:
389-
return results
394+
return results, None

src/pytorch_tabular/tabular_model_tuner.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class TabularModelTuner:
3535
"""
3636

3737
ALLOWABLE_STRATEGIES = ["grid_search", "random_search"]
38-
OUTPUT = namedtuple("OUTPUT", ["trials_df", "best_params", "best_score"])
38+
OUTPUT = namedtuple("OUTPUT", ["trials_df", "best_params", "best_score", "best_model"])
3939

4040
def __init__(
4141
self,
@@ -88,6 +88,8 @@ def __init__(
8888
if trainer_config.fast_dev_run:
8989
warnings.warn("fast_dev_run is turned on. Tuning results won't be accurate.")
9090
if trainer_config.progress_bar != "none":
91+
# If config and tuner have progress bar enabled, it will result in a bug within the library (rich.progress)
92+
trainer_config.progress_bar = "none"
9193
warnings.warn("Turning off progress bar. Set progress_bar='none' in TrainerConfig to disable this warning.")
9294
trainer_config.trainer_kwargs.update({"enable_model_summary": False})
9395
self.data_config = data_config
@@ -153,6 +155,7 @@ def tune(
153155
cv: Optional[Union[int, Iterable, BaseCrossValidator]] = None,
154156
cv_agg_func: Optional[Callable] = np.mean,
155157
cv_kwargs: Optional[Dict] = {},
158+
return_best_model: bool = True,
156159
verbose: bool = False,
157160
progress_bar: bool = True,
158161
random_state: Optional[int] = 42,
@@ -200,6 +203,8 @@ def tune(
200203
cv_kwargs (Optional[Dict], optional): Additional keyword arguments to be passed to the cross validation
201204
method. Defaults to {}.
202205
206+
return_best_model (bool, optional): If True, will return the best model. Defaults to True.
207+
203208
verbose (bool, optional): Whether to print the results of each trial. Defaults to False.
204209
205210
progress_bar (bool, optional): Whether to show a progress bar. Defaults to True.
@@ -215,6 +220,7 @@ def tune(
215220
trials_df (DataFrame): A dataframe with the results of each trial
216221
best_params (Dict): The best parameters found
217222
best_score (float): The best score found
223+
best_model (TabularModel or None): If return_best_model is True, return best_model otherwise return None
218224
"""
219225
assert strategy in self.ALLOWABLE_STRATEGIES, f"tuner must be one of {self.ALLOWABLE_STRATEGIES}"
220226
assert mode in ["max", "min"], "mode must be one of ['max', 'min']"
@@ -270,6 +276,8 @@ def tune(
270276
metric_str = metric.__name__
271277
del temp_tabular_model
272278
trials = []
279+
best_model = None
280+
best_score = 0.0
273281
for i, params in enumerate(iterator):
274282
# Copying the configs as a base
275283
# Make sure all default parameters that you want to be set for all
@@ -334,6 +342,22 @@ def tune(
334342
else:
335343
result = tabular_model_t.evaluate(validation, verbose=False)
336344
params.update({k.replace("test_", ""): v for k, v in result[0].items()})
345+
346+
if return_best_model:
347+
tabular_model_t.datamodule = None
348+
if best_model is None:
349+
best_model = deepcopy(tabular_model_t)
350+
best_score = params[metric_str]
351+
else:
352+
if mode == "min":
353+
if params[metric_str] < best_score:
354+
best_model = deepcopy(tabular_model_t)
355+
best_score = params[metric_str]
356+
elif mode == "max":
357+
if params[metric_str] > best_score:
358+
best_model = deepcopy(tabular_model_t)
359+
best_score = params[metric_str]
360+
337361
params.update({"trial_id": i})
338362
trials.append(params)
339363
if verbose:
@@ -349,4 +373,13 @@ def tune(
349373
best_params = trials_df.iloc[best_idx].to_dict()
350374
best_score = best_params.pop(metric_str)
351375
trials_df.insert(0, "trial_id", trials)
352-
return self.OUTPUT(trials_df, best_params, best_score)
376+
377+
if verbose:
378+
logger.info("Model Tuner Finished")
379+
logger.info(f"Best Score ({metric_str}): {best_score}")
380+
381+
if return_best_model and best_model is not None:
382+
best_model.datamodule = datamodule
383+
return self.OUTPUT(trials_df, best_params, best_score, best_model)
384+
else:
385+
return self.OUTPUT(trials_df, best_params, best_score, None)

0 commit comments

Comments
 (0)