@@ -35,7 +35,7 @@ class TabularModelTuner:
3535 """
3636
3737 ALLOWABLE_STRATEGIES = ["grid_search" , "random_search" ]
38- OUTPUT = namedtuple ("OUTPUT" , ["trials_df" , "best_params" , "best_score" ])
38+ OUTPUT = namedtuple ("OUTPUT" , ["trials_df" , "best_params" , "best_score" , "best_model" ])
3939
4040 def __init__ (
4141 self ,
@@ -88,6 +88,8 @@ def __init__(
8888 if trainer_config .fast_dev_run :
8989 warnings .warn ("fast_dev_run is turned on. Tuning results won't be accurate." )
9090 if trainer_config .progress_bar != "none" :
91+ # If both the trainer config and the tuner enable a progress bar, the nested rich.progress displays conflict and raise an error, so force it off here
92+ trainer_config .progress_bar = "none"
9193 warnings .warn ("Turning off progress bar. Set progress_bar='none' in TrainerConfig to disable this warning." )
9294 trainer_config .trainer_kwargs .update ({"enable_model_summary" : False })
9395 self .data_config = data_config
@@ -153,6 +155,7 @@ def tune(
153155 cv : Optional [Union [int , Iterable , BaseCrossValidator ]] = None ,
154156 cv_agg_func : Optional [Callable ] = np .mean ,
155157 cv_kwargs : Optional [Dict ] = {},
158+ return_best_model : bool = True ,
156159 verbose : bool = False ,
157160 progress_bar : bool = True ,
158161 random_state : Optional [int ] = 42 ,
@@ -200,6 +203,8 @@ def tune(
200203 cv_kwargs (Optional[Dict], optional): Additional keyword arguments to be passed to the cross validation
201204 method. Defaults to {}.
202205
206+ return_best_model (bool, optional): If True, will return the best model. Defaults to True.
207+
203208 verbose (bool, optional): Whether to print the results of each trial. Defaults to False.
204209
205210 progress_bar (bool, optional): Whether to show a progress bar. Defaults to True.
@@ -215,6 +220,7 @@ def tune(
215220 trials_df (DataFrame): A dataframe with the results of each trial
216221 best_params (Dict): The best parameters found
217222 best_score (float): The best score found
223+ best_model (TabularModel or None): The best model found during tuning if return_best_model is True, otherwise None
218224 """
219225 assert strategy in self .ALLOWABLE_STRATEGIES , f"tuner must be one of { self .ALLOWABLE_STRATEGIES } "
220226 assert mode in ["max" , "min" ], "mode must be one of ['max', 'min']"
@@ -270,6 +276,8 @@ def tune(
270276 metric_str = metric .__name__
271277 del temp_tabular_model
272278 trials = []
279+ best_model = None
280+ best_score = 0.0
273281 for i , params in enumerate (iterator ):
274282 # Copying the configs as a base
275283 # Make sure all default parameters that you want to be set for all
@@ -334,6 +342,22 @@ def tune(
334342 else :
335343 result = tabular_model_t .evaluate (validation , verbose = False )
336344 params .update ({k .replace ("test_" , "" ): v for k , v in result [0 ].items ()})
345+
346+ if return_best_model :
347+ tabular_model_t .datamodule = None
348+ if best_model is None :
349+ best_model = deepcopy (tabular_model_t )
350+ best_score = params [metric_str ]
351+ else :
352+ if mode == "min" :
353+ if params [metric_str ] < best_score :
354+ best_model = deepcopy (tabular_model_t )
355+ best_score = params [metric_str ]
356+ elif mode == "max" :
357+ if params [metric_str ] > best_score :
358+ best_model = deepcopy (tabular_model_t )
359+ best_score = params [metric_str ]
360+
337361 params .update ({"trial_id" : i })
338362 trials .append (params )
339363 if verbose :
@@ -349,4 +373,13 @@ def tune(
349373 best_params = trials_df .iloc [best_idx ].to_dict ()
350374 best_score = best_params .pop (metric_str )
351375 trials_df .insert (0 , "trial_id" , trials )
352- return self .OUTPUT (trials_df , best_params , best_score )
376+
377+ if verbose :
378+ logger .info ("Model Tuner Finished" )
379+ logger .info (f"Best Score ({ metric_str } ): { best_score } " )
380+
381+ if return_best_model and best_model is not None :
382+ best_model .datamodule = datamodule
383+ return self .OUTPUT (trials_df , best_params , best_score , best_model )
384+ else :
385+ return self .OUTPUT (trials_df , best_params , best_score , None )
0 commit comments