@@ -1143,7 +1143,16 @@ def test_tta_regression(
11431143
11441144
11451145def _run_model_compare (
1146- task , model_list , data_config , trainer_config , optimizer_config , train , test , metric , rank_metric , custom_fit_params = {},
1146+ task ,
1147+ model_list ,
1148+ data_config ,
1149+ trainer_config ,
1150+ optimizer_config ,
1151+ train ,
1152+ test ,
1153+ metric ,
1154+ rank_metric ,
1155+ custom_fit_params = {},
11471156):
11481157 model_list = copy .deepcopy (model_list )
11491158 if isinstance (model_list , list ):
@@ -1249,12 +1258,16 @@ def test_model_compare_regression(regression_data, model_list, continuous_cols,
12491258 # best_models = comp_df.loc[comp_df[f"test_{rank_metric[0]}"] == best_score, "model"].values.tolist()
12501259 # assert best_model.model._get_name() in best_models
12511260
1261+
12521262@pytest .mark .parametrize ("model_list" , ["lite" , MODEL_CONFIG_MODEL_SWEEP_TEST ])
12531263@pytest .mark .parametrize ("continuous_cols" , [list (DATASET_CONTINUOUS_COLUMNS )])
12541264@pytest .mark .parametrize ("categorical_cols" , [["HouseAgeBin" ]])
1255- @pytest .mark .parametrize ("metric" , [
1265+ @pytest .mark .parametrize (
1266+ "metric" ,
1267+ [
12561268 (["mean_squared_error" ], [{}], [False ]),
1257- ])
1269+ ],
1270+ )
12581271@pytest .mark .parametrize ("rank_metric" , [("loss" , "lower_is_better" )])
12591272@pytest .mark .parametrize (
12601273 "custom_fit_params" ,
@@ -1263,11 +1276,13 @@ def test_model_compare_regression(regression_data, model_list, continuous_cols,
12631276 "loss" : torch .nn .L1Loss (),
12641277 "metrics" : [fake_metric ],
12651278 "metrics_prob_inputs" : [True ],
1266- "optimizer" : torch .optim .Adagrad ,
1279+ "optimizer" : torch .optim .Adagrad ,
12671280 },
1268- ]
1281+ ],
12691282)
1270- def test_model_compare_custom (regression_data , model_list , continuous_cols , categorical_cols , metric , rank_metric , custom_fit_params ):
1283+ def test_model_compare_custom (
1284+ regression_data , model_list , continuous_cols , categorical_cols , metric , rank_metric , custom_fit_params
1285+ ):
12711286 (train , test , target ) = regression_data
12721287 data_config = DataConfig (
12731288 target = target ,
@@ -1285,7 +1300,16 @@ def test_model_compare_custom(regression_data, model_list, continuous_cols, cate
12851300 )
12861301 optimizer_config = OptimizerConfig ()
12871302 comp_df , best_model = _run_model_compare (
1288- "regression" , model_list , data_config , trainer_config , optimizer_config , train , test , metric , rank_metric , custom_fit_params = custom_fit_params
1303+ "regression" ,
1304+ model_list ,
1305+ data_config ,
1306+ trainer_config ,
1307+ optimizer_config ,
1308+ train ,
1309+ test ,
1310+ metric ,
1311+ rank_metric ,
1312+ custom_fit_params = custom_fit_params ,
12891313 )
12901314 if model_list == "lite" :
12911315 assert len (comp_df ) == 3
@@ -1294,6 +1318,7 @@ def test_model_compare_custom(regression_data, model_list, continuous_cols, cate
12941318 if custom_fit_params .get ("metric" , None ) == fake_metric :
12951319 assert "test_fake_metric" in comp_df .columns ()
12961320
1321+
12971322@pytest .mark .parametrize ("model_config_class" , MODEL_CONFIG_SAVE_TEST )
12981323@pytest .mark .parametrize ("continuous_cols" , [list (DATASET_CONTINUOUS_COLUMNS )])
12991324@pytest .mark .parametrize ("categorical_cols" , [["HouseAgeBin" ]])
0 commit comments