Skip to content

Commit be9c563

Browse files
committed
fixed some issues and added test cases
1 parent f2c2780 commit be9c563

File tree

2 files changed

+91
-6
lines changed

2 files changed

+91
-6
lines changed

src/pytorch_tabular/tabular_model.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1832,7 +1832,20 @@ def ret_summary(self, model=None, max_depth: int = -1) -> str:
18321832

18331833
def __str__(self) -> str:
18341834
"""Returns a readable summary of the TabularModel object."""
1835-
return f"{self.__class__.__name__}(model={self.model.__class__.__name__ if self.has_model else 'None'})"
1835+
return f"{self.__class__.__name__}(model={self.model.__class__.__name__ if self.has_model else self.config._model_name+'(Not Initialized)'})"
1836+
1837+
def __repr__(self) -> str:
1838+
"""Returns an unambiguous representation of the TabularModel object."""
1839+
config_str = json.dumps(
1840+
OmegaConf.to_container(self.config, resolve=True), indent=4
1841+
)
1842+
ret_str = f"{self.__class__.__name__}(\n"
1843+
if self.has_model:
1844+
ret_str += f" model={self.model.__class__.__name__},\n"
1845+
else:
1846+
ret_str += f" model={self.config._model_name} (Not Initialized),\n"
1847+
ret_str += f" config={config_str},\n"
1848+
return ret_str
18361849

18371850
def _repr_html_(self):
18381851
"""Generate an HTML representation for Jupyter Notebook."""
@@ -1912,10 +1925,18 @@ def _repr_html_(self):
19121925
header_html = f"<div class='header'>{html.escape(self.model.__class__.__name__ if self.has_model else self.config._model_name)}{model_status}</div>"
19131926

19141927
# Config Section
1915-
config_html = self._generate_collapsible_section("Model Config", self.config, uid=uid, is_dict=True)
1928+
config_html = self._generate_collapsible_section(
1929+
"Model Config", self.config, uid=uid, is_dict=True
1930+
)
19161931

19171932
# Summary Section
1918-
summary_html = "" if not self.has_model else self._generate_collapsible_section("Model Summary", self._generate_model_summary_table(), uid=uid)
1933+
summary_html = (
1934+
""
1935+
if not self.has_model
1936+
else self._generate_collapsible_section(
1937+
"Model Summary", self._generate_model_summary_table(), uid=uid
1938+
)
1939+
)
19191940

19201941
# Combine sections
19211942
return f"""
@@ -1930,7 +1951,9 @@ def _repr_html_(self):
19301951
def _generate_collapsible_section(self, title, content, uid, is_dict=False):
19311952
container_id = title.lower().replace(" ", "_") + uid
19321953
if is_dict:
1933-
content = self._generate_nested_collapsible_sections(OmegaConf.to_container(content, resolve=True), container_id)
1954+
content = self._generate_nested_collapsible_sections(
1955+
OmegaConf.to_container(content, resolve=True), container_id
1956+
)
19341957
return f"""
19351958
<div>
19361959
<span class="toggle-button" onclick="toggleVisibility('{container_id}')">&#9654;</span>
@@ -1947,7 +1970,9 @@ def _generate_nested_collapsible_sections(self, content, parent_id):
19471970
if isinstance(value, dict):
19481971
nested_id = f"{parent_id}_{key}".replace(" ", "_")
19491972
nested_id = nested_id + str(uuid.uuid4())
1950-
nested_content = self._generate_nested_collapsible_sections(value, nested_id)
1973+
nested_content = self._generate_nested_collapsible_sections(
1974+
value, nested_id
1975+
)
19511976
html_content += f"""
19521977
<div>
19531978
<span class="toggle-button" onclick="toggleVisibility('{nested_id}')">&#9654;</span>

tests/test_common.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
MODEL_CONFIG_SAVE_TEST = [
3333
(CategoryEmbeddingModelConfig, {"layers": "10-20"}),
34-
(AutoIntConfig, {"num_heads": 1, "num_attn_blocks": 1}),
34+
(GANDALFConfig, {}),
3535
(NodeConfig, {"num_trees": 100, "depth": 2}),
3636
(TabNetModelConfig, {"n_a": 2, "n_d": 2}),
3737
]
@@ -1247,3 +1247,63 @@ def test_model_compare_regression(regression_data, model_list, continuous_cols,
12471247
# # there may be multiple models with the same score
12481248
# best_models = comp_df.loc[comp_df[f"test_{rank_metric[0]}"] == best_score, "model"].values.tolist()
12491249
# assert best_model.model._get_name() in best_models
1250+
1251+
@pytest.mark.parametrize("model_config_class", MODEL_CONFIG_SAVE_TEST)
1252+
@pytest.mark.parametrize("continuous_cols", [list(DATASET_CONTINUOUS_COLUMNS)])
1253+
@pytest.mark.parametrize("categorical_cols", [["HouseAgeBin"]])
1254+
@pytest.mark.parametrize("custom_metrics", [None, [fake_metric]])
1255+
@pytest.mark.parametrize("custom_loss", [None, torch.nn.L1Loss()])
1256+
@pytest.mark.parametrize("custom_optimizer", [None, torch.optim.Adagrad, "SGD", "torch_optimizer.AdaBound"])
1257+
def test_str_repr(
1258+
regression_data,
1259+
model_config_class,
1260+
continuous_cols,
1261+
categorical_cols,
1262+
custom_metrics,
1263+
custom_loss,
1264+
custom_optimizer,
1265+
):
1266+
(train, test, target) = regression_data
1267+
data_config = DataConfig(
1268+
target=target,
1269+
continuous_cols=continuous_cols,
1270+
categorical_cols=categorical_cols,
1271+
)
1272+
model_config_class, model_config_params = model_config_class
1273+
model_config_params["task"] = "regression"
1274+
model_config = model_config_class(**model_config_params)
1275+
trainer_config = TrainerConfig(
1276+
max_epochs=3,
1277+
checkpoints=None,
1278+
early_stopping=None,
1279+
accelerator="cpu",
1280+
fast_dev_run=True,
1281+
)
1282+
optimizer_config = OptimizerConfig()
1283+
1284+
tabular_model = TabularModel(
1285+
data_config=data_config,
1286+
model_config=model_config,
1287+
optimizer_config=optimizer_config,
1288+
trainer_config=trainer_config,
1289+
)
1290+
assert "Not Initialized" in str(tabular_model)
1291+
assert "Not Initialized" in repr(tabular_model)
1292+
assert "Model Summary" not in tabular_model._repr_html_()
1293+
assert "Model Config" in tabular_model._repr_html_()
1294+
assert "config" in tabular_model.__repr__()
1295+
assert "config" not in str(tabular_model)
1296+
tabular_model.fit(
1297+
train=train,
1298+
metrics=custom_metrics,
1299+
metrics_prob_inputs=None if custom_metrics is None else [False],
1300+
loss=custom_loss,
1301+
optimizer=custom_optimizer,
1302+
optimizer_params={}
1303+
)
1304+
assert model_config_class._model_name in str(tabular_model)
1305+
assert model_config_class._model_name in repr(tabular_model)
1306+
assert "Model Summary" in tabular_model._repr_html_()
1307+
assert "Model Config" in tabular_model._repr_html_()
1308+
assert "config" in tabular_model.__repr__()
1309+
assert model_config_class._model_name in tabular_model._repr_html_()

0 commit comments

Comments
 (0)