Skip to content

Commit 82a30fe

Browse files
committed
-- added testcase for saving NODE
1 parent 6b892b6 commit 82a30fe

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

tests/test_common.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515
)
1616

1717
MODEL_CONFIG_SAVE_TEST = [
18-
CategoryEmbeddingModelConfig,
19-
AutoIntConfig,
20-
TabNetModelConfig,
18+
(CategoryEmbeddingModelConfig, dict(layers="10-20")),
19+
(AutoIntConfig, dict(num_heads=1,num_attn_blocks=1,)),
20+
(NodeConfig, dict(num_trees=100, depth=2)),
21+
(TabNetModelConfig, dict(n_a=2, n_d=2)),
2122
]
2223

2324
MODEL_CONFIG_FEATURE_EXT_TEST = [
@@ -67,7 +68,8 @@ def test_save_load(
6768
continuous_cols=continuous_cols,
6869
categorical_cols=categorical_cols,
6970
)
70-
model_config_params = dict(task="regression")
71+
model_config_class, model_config_params = model_config_class
72+
model_config_params['task']="regression"
7173
model_config = model_config_class(**model_config_params)
7274
trainer_config = TrainerConfig(
7375
max_epochs=3, checkpoints=None, early_stopping=None, gpus=0, fast_dev_run=True

0 commit comments

Comments
 (0)