|
31 | 31 |
|
32 | 32 | MODEL_CONFIG_SAVE_TEST = [ |
33 | 33 | (CategoryEmbeddingModelConfig, {"layers": "10-20"}), |
34 | | - (AutoIntConfig, {"num_heads": 1, "num_attn_blocks": 1}), |
| 34 | + (GANDALFConfig, {}), |
35 | 35 | (NodeConfig, {"num_trees": 100, "depth": 2}), |
36 | 36 | (TabNetModelConfig, {"n_a": 2, "n_d": 2}), |
37 | 37 | ] |
@@ -1247,3 +1247,63 @@ def test_model_compare_regression(regression_data, model_list, continuous_cols, |
1247 | 1247 | # # there may be multiple models with the same score |
1248 | 1248 | # best_models = comp_df.loc[comp_df[f"test_{rank_metric[0]}"] == best_score, "model"].values.tolist() |
1249 | 1249 | # assert best_model.model._get_name() in best_models |
| 1250 | + |
| 1251 | +@pytest.mark.parametrize("model_config_class", MODEL_CONFIG_SAVE_TEST) |
| 1252 | +@pytest.mark.parametrize("continuous_cols", [list(DATASET_CONTINUOUS_COLUMNS)]) |
| 1253 | +@pytest.mark.parametrize("categorical_cols", [["HouseAgeBin"]]) |
| 1254 | +@pytest.mark.parametrize("custom_metrics", [None, [fake_metric]]) |
| 1255 | +@pytest.mark.parametrize("custom_loss", [None, torch.nn.L1Loss()]) |
| 1256 | +@pytest.mark.parametrize("custom_optimizer", [None, torch.optim.Adagrad, "SGD", "torch_optimizer.AdaBound"]) |
| 1257 | +def test_str_repr( |
| 1258 | + regression_data, |
| 1259 | + model_config_class, |
| 1260 | + continuous_cols, |
| 1261 | + categorical_cols, |
| 1262 | + custom_metrics, |
| 1263 | + custom_loss, |
| 1264 | + custom_optimizer, |
| 1265 | +): |
| 1266 | + (train, test, target) = regression_data |
| 1267 | + data_config = DataConfig( |
| 1268 | + target=target, |
| 1269 | + continuous_cols=continuous_cols, |
| 1270 | + categorical_cols=categorical_cols, |
| 1271 | + ) |
| 1272 | + model_config_class, model_config_params = model_config_class |
| 1273 | + model_config_params["task"] = "regression" |
| 1274 | + model_config = model_config_class(**model_config_params) |
| 1275 | + trainer_config = TrainerConfig( |
| 1276 | + max_epochs=3, |
| 1277 | + checkpoints=None, |
| 1278 | + early_stopping=None, |
| 1279 | + accelerator="cpu", |
| 1280 | + fast_dev_run=True, |
| 1281 | + ) |
| 1282 | + optimizer_config = OptimizerConfig() |
| 1283 | + |
| 1284 | + tabular_model = TabularModel( |
| 1285 | + data_config=data_config, |
| 1286 | + model_config=model_config, |
| 1287 | + optimizer_config=optimizer_config, |
| 1288 | + trainer_config=trainer_config, |
| 1289 | + ) |
| 1290 | + assert "Not Initialized" in str(tabular_model) |
| 1291 | + assert "Not Initialized" in repr(tabular_model) |
| 1292 | + assert "Model Summary" not in tabular_model._repr_html_() |
| 1293 | + assert "Model Config" in tabular_model._repr_html_() |
| 1294 | + assert "config" in tabular_model.__repr__() |
| 1295 | + assert "config" not in str(tabular_model) |
| 1296 | + tabular_model.fit( |
| 1297 | + train=train, |
| 1298 | + metrics=custom_metrics, |
| 1299 | + metrics_prob_inputs=None if custom_metrics is None else [False], |
| 1300 | + loss=custom_loss, |
| 1301 | + optimizer=custom_optimizer, |
| 1302 | + optimizer_params={} |
| 1303 | + ) |
| 1304 | + assert model_config_class._model_name in str(tabular_model) |
| 1305 | + assert model_config_class._model_name in repr(tabular_model) |
| 1306 | + assert "Model Summary" in tabular_model._repr_html_() |
| 1307 | + assert "Model Config" in tabular_model._repr_html_() |
| 1308 | + assert "config" in tabular_model.__repr__() |
| 1309 | + assert model_config_class._model_name in tabular_model._repr_html_() |
0 commit comments