src/pytorch_tabular/tabular_model_sweep.py (11 changes: 10 additions & 1 deletion)
@@ -92,6 +92,7 @@ def _validate_args(
     experiment_config: Optional[Union[ExperimentConfig, str]] = None,
     common_model_args: Optional[dict] = {},
     rank_metric: Optional[str] = "loss",
+    custom_fit_params: Optional[dict] = {},
 ):
     assert task in [
         "classification",
@@ -149,6 +150,8 @@ def _validate_args(
         "lower_is_better",
         "higher_is_better",
     ], "rank_metric[1] must be one of ['lower_is_better', 'higher_is_better'], but" f" got {rank_metric[1]}"
+    if "metrics" in custom_fit_params.keys():
+        assert rank_metric[0] == "loss", "only loss is supported as the rank_metric when using custom metrics"
 
 
 def model_sweep(
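
The two lines added above tie custom metrics to loss-based ranking: when custom_fit_params carries a "metrics" key, only "loss" is accepted as the ranking metric. A standalone restatement of the check, with placeholder values, for clarity:

    # Placeholder dict mirroring the validation above; "my_metric" is hypothetical.
    custom_fit_params = {"metrics": ["my_metric"]}
    rank_metric = ("loss", "lower_is_better")

    if "metrics" in custom_fit_params.keys():
        assert rank_metric[0] == "loss", "only loss is supported as the rank_metric when using custom metrics"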
@@ -172,6 +175,7 @@ def model_sweep(
     progress_bar: bool = True,
     verbose: bool = True,
     suppress_lightning_logger: bool = True,
+    custom_fit_params: Optional[dict] = {},
 ):
     """Compare multiple models on the same dataset.
 
Expand Down Expand Up @@ -231,6 +235,10 @@ def model_sweep(

suppress_lightning_logger (bool, optional): If True, will suppress the lightning logger. Defaults to True.

custom_fit_params (dict, optional): A dict specifying custom loss, metrics and optimizer.
The behviour of these custom parameters is similar to those passed through the `fit` method
Copy link

Copilot AI Nov 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in documentation: "behviour" should be "behaviour".

Suggested change
The behviour of these custom parameters is similar to those passed through the `fit` method
The behaviour of these custom parameters is similar to those passed through the `fit` method

Copilot uses AI. Check for mistakes.
of `TabularModel`.

Returns:
results: Training results.

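
For orientation, a minimal sketch of how the new argument might be called once merged; the DataFrames (train_df, test_df) and config values below are illustrative assumptions, not taken from this PR:

    import torch
    from pytorch_tabular import model_sweep
    from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig

    # Assumed inputs: train_df/test_df are pandas DataFrames with a "target" column.
    data_config = DataConfig(target=["target"], continuous_cols=["feat_a", "feat_b"])
    trainer_config = TrainerConfig(max_epochs=3, checkpoints=None, early_stopping=None)

    comp_df, best_model = model_sweep(
        task="regression",
        train=train_df,
        test=test_df,
        data_config=data_config,
        optimizer_config=OptimizerConfig(),
        trainer_config=trainer_config,
        model_list="lite",
        rank_metric=("loss", "lower_is_better"),  # "loss" is required once custom metrics are supplied
        custom_fit_params={
            "loss": torch.nn.L1Loss(),         # overrides the task's default loss
            "optimizer": torch.optim.Adagrad,  # optimizer class, as in the test below
        },
    )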
@@ -252,6 +260,7 @@ def model_sweep(
         experiment_config=experiment_config,
         common_model_args=common_model_args,
         rank_metric=rank_metric,
+        custom_fit_params=custom_fit_params,
     )
     if suppress_lightning_logger:
         suppress_lightning_logs()
@@ -326,7 +335,7 @@ def _init_tabular_model(m):
         name = tabular_model.name
         if verbose:
             logger.info(f"Training {name}")
-        model = tabular_model.prepare_model(datamodule)
+        model = tabular_model.prepare_model(datamodule, **custom_fit_params)
         if progress_bar:
            progress.update(task_p, description=f"Training {name}", advance=1)
         with OutOfMemoryHandler(handle_oom=True) as handler:
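
Because custom_fit_params is unpacked verbatim into prepare_model for every model in the sweep, its keys must match prepare_model's keyword arguments. A sketch of such a dict using the keys exercised by this PR's test (values illustrative):

    import torch

    custom_fit_params = {
        "loss": torch.nn.L1Loss(),         # custom loss module
        "metrics": [],                     # list of custom metric callables
        "metrics_prob_inputs": [],         # per-metric flag: metric expects probabilities
        "optimizer": torch.optim.Adagrad,  # optimizer class, instantiated by the library
    }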
tests/test_common.py (72 changes: 71 additions & 1 deletion)
@@ -1143,7 +1143,16 @@ def test_tta_regression(
 
 
 def _run_model_compare(
-    task, model_list, data_config, trainer_config, optimizer_config, train, test, metric, rank_metric
+    task,
+    model_list,
+    data_config,
+    trainer_config,
+    optimizer_config,
+    train,
+    test,
+    metric,
+    rank_metric,
+    custom_fit_params={},
 ):
     model_list = copy.deepcopy(model_list)
     if isinstance(model_list, list):
@@ -1161,6 +1170,7 @@ def _run_model_compare(
         metrics_params=metric[1],
         metrics_prob_input=metric[2],
         rank_metric=rank_metric,
+        custom_fit_params=custom_fit_params,
     )
 
 
@@ -1249,6 +1259,66 @@ def test_model_compare_regression(regression_data, model_list, continuous_cols,
     # assert best_model.model._get_name() in best_models
 
 
+@pytest.mark.parametrize("model_list", ["lite", MODEL_CONFIG_MODEL_SWEEP_TEST])
+@pytest.mark.parametrize("continuous_cols", [list(DATASET_CONTINUOUS_COLUMNS)])
+@pytest.mark.parametrize("categorical_cols", [["HouseAgeBin"]])
+@pytest.mark.parametrize(
+    "metric",
+    [
+        (["mean_squared_error"], [{}], [False]),
+    ],
+)
+@pytest.mark.parametrize("rank_metric", [("loss", "lower_is_better")])
+@pytest.mark.parametrize(
+    "custom_fit_params",
+    [
+        {
+            "loss": torch.nn.L1Loss(),
+            "metrics": [fake_metric],
+            "metrics_prob_inputs": [True],
+            "optimizer": torch.optim.Adagrad,
+        },
+    ],
+)
+def test_model_compare_custom(
+    regression_data, model_list, continuous_cols, categorical_cols, metric, rank_metric, custom_fit_params
+):
+    (train, test, target) = regression_data
+    data_config = DataConfig(
+        target=target,
+        continuous_cols=continuous_cols,
+        categorical_cols=categorical_cols,
+        handle_missing_values=True,
+        handle_unknown_categories=True,
+    )
+    trainer_config = TrainerConfig(
+        max_epochs=3,
+        checkpoints=None,
+        early_stopping=None,
+        accelerator="cpu",
+        fast_dev_run=True,
+    )
+    optimizer_config = OptimizerConfig()
+    comp_df, best_model = _run_model_compare(
+        "regression",
+        model_list,
+        data_config,
+        trainer_config,
+        optimizer_config,
+        train,
+        test,
+        metric,
+        rank_metric,
+        custom_fit_params=custom_fit_params,
+    )
+    if model_list == "lite":
+        assert len(comp_df) == 3
+    else:
+        assert len(comp_df) == len(model_list)
+    if custom_fit_params.get("metric", None) == fake_metric:
+        assert "test_fake_metric" in comp_df.columns()
+
+
 @pytest.mark.parametrize("model_config_class", MODEL_CONFIG_SAVE_TEST)
 @pytest.mark.parametrize("continuous_cols", [list(DATASET_CONTINUOUS_COLUMNS)])
 @pytest.mark.parametrize("categorical_cols", [["HouseAgeBin"]])

Copilot AI commented on Nov 13, 2025:

    Bug in test assertion: should check "metrics" (plural) instead of "metric" (singular). The custom_fit_params dictionary uses the key "metrics" (line 1277), so this condition will never be true, making this assertion ineffective.

    Suggested change:
    -    if custom_fit_params.get("metric", None) == fake_metric:
    +    if fake_metric in custom_fit_params.get("metrics", []):

Copilot AI commented on Nov 13, 2025:

    Bug: columns() is being called as a method, but pandas DataFrame's columns is a property, not a method. This should be comp_df.columns instead of comp_df.columns().

    Suggested change:
    -        assert "test_fake_metric" in comp_df.columns()
    +        assert "test_fake_metric" in comp_df.columns