Skip to content

Commit 5943159

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 6afe604 commit 5943159

File tree

1 file changed

+32
-7
lines changed

1 file changed

+32
-7
lines changed

tests/test_common.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,7 +1143,16 @@ def test_tta_regression(
11431143

11441144

11451145
def _run_model_compare(
1146-
task, model_list, data_config, trainer_config, optimizer_config, train, test, metric, rank_metric, custom_fit_params={},
1146+
task,
1147+
model_list,
1148+
data_config,
1149+
trainer_config,
1150+
optimizer_config,
1151+
train,
1152+
test,
1153+
metric,
1154+
rank_metric,
1155+
custom_fit_params={},
11471156
):
11481157
model_list = copy.deepcopy(model_list)
11491158
if isinstance(model_list, list):
@@ -1249,12 +1258,16 @@ def test_model_compare_regression(regression_data, model_list, continuous_cols,
12491258
# best_models = comp_df.loc[comp_df[f"test_{rank_metric[0]}"] == best_score, "model"].values.tolist()
12501259
# assert best_model.model._get_name() in best_models
12511260

1261+
12521262
@pytest.mark.parametrize("model_list", ["lite", MODEL_CONFIG_MODEL_SWEEP_TEST])
12531263
@pytest.mark.parametrize("continuous_cols", [list(DATASET_CONTINUOUS_COLUMNS)])
12541264
@pytest.mark.parametrize("categorical_cols", [["HouseAgeBin"]])
1255-
@pytest.mark.parametrize("metric", [
1265+
@pytest.mark.parametrize(
1266+
"metric",
1267+
[
12561268
(["mean_squared_error"], [{}], [False]),
1257-
])
1269+
],
1270+
)
12581271
@pytest.mark.parametrize("rank_metric", [("loss", "lower_is_better")])
12591272
@pytest.mark.parametrize(
12601273
"custom_fit_params",
@@ -1263,11 +1276,13 @@ def test_model_compare_regression(regression_data, model_list, continuous_cols,
12631276
"loss": torch.nn.L1Loss(),
12641277
"metrics": [fake_metric],
12651278
"metrics_prob_inputs": [True],
1266-
"optimizer": torch.optim.Adagrad,
1279+
"optimizer": torch.optim.Adagrad,
12671280
},
1268-
]
1281+
],
12691282
)
1270-
def test_model_compare_custom(regression_data, model_list, continuous_cols, categorical_cols, metric, rank_metric, custom_fit_params):
1283+
def test_model_compare_custom(
1284+
regression_data, model_list, continuous_cols, categorical_cols, metric, rank_metric, custom_fit_params
1285+
):
12711286
(train, test, target) = regression_data
12721287
data_config = DataConfig(
12731288
target=target,
@@ -1285,7 +1300,16 @@ def test_model_compare_custom(regression_data, model_list, continuous_cols, cate
12851300
)
12861301
optimizer_config = OptimizerConfig()
12871302
comp_df, best_model = _run_model_compare(
1288-
"regression", model_list, data_config, trainer_config, optimizer_config, train, test, metric, rank_metric, custom_fit_params=custom_fit_params
1303+
"regression",
1304+
model_list,
1305+
data_config,
1306+
trainer_config,
1307+
optimizer_config,
1308+
train,
1309+
test,
1310+
metric,
1311+
rank_metric,
1312+
custom_fit_params=custom_fit_params,
12891313
)
12901314
if model_list == "lite":
12911315
assert len(comp_df) == 3
@@ -1294,6 +1318,7 @@ def test_model_compare_custom(regression_data, model_list, continuous_cols, cate
12941318
if custom_fit_params.get("metric", None) == fake_metric:
12951319
assert "test_fake_metric" in comp_df.columns()
12961320

1321+
12971322
@pytest.mark.parametrize("model_config_class", MODEL_CONFIG_SAVE_TEST)
12981323
@pytest.mark.parametrize("continuous_cols", [list(DATASET_CONTINUOUS_COLUMNS)])
12991324
@pytest.mark.parametrize("categorical_cols", [["HouseAgeBin"]])

0 commit comments

Comments
 (0)