Skip to content

Commit ba95a0d

Browse files
authored
Merge branch 'main' into pre-commit-ci-update-config
2 parents b847571 + de60f5f commit ba95a0d

File tree

13 files changed

+1917
-3
lines changed

13 files changed

+1917
-3
lines changed

.github/workflows/releasing.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ jobs:
3838

3939
- name: Publish distribution 📦 to PyPI
4040
if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
41-
uses: pypa/gh-action-pypi-publish@v1.11.0
41+
uses: pypa/gh-action-pypi-publish@v1.12.2
4242
with:
4343
user: __token__
4444
password: ${{ secrets.pypi_password }}

docs/apidocs_model.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030
::: pytorch_tabular.models.TabTransformerConfig
3131
options:
3232
heading_level: 3
33+
::: pytorch_tabular.models.StackingModelConfig
34+
options:
35+
heading_level: 3
3336
::: pytorch_tabular.config.ModelConfig
3437
options:
3538
heading_level: 3
@@ -66,7 +69,9 @@
6669
::: pytorch_tabular.models.TabTransformerModel
6770
options:
6871
heading_level: 3
69-
72+
::: pytorch_tabular.models.StackingModel
73+
options:
74+
heading_level: 3
7075
## Base Model Class
7176
::: pytorch_tabular.models.BaseModel
7277
options:
59.2 KB
Loading

docs/models.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,30 @@ All the parameters have beet set to recommended values from the paper. Let's loo
253253
**For a complete list of parameters refer to the API Docs**
254254
[pytorch_tabular.models.DANetConfig][]
255255

256+
## Model Stacking
257+
258+
Model stacking is an ensemble learning technique that combines multiple base models to create a more powerful predictive model. Each base model processes the input features independently, and their outputs are concatenated before making the final prediction. This allows the model to leverage different learning patterns captured by each backbone architecture. You can use it by choosing `StackingModelConfig`.
259+
260+
The following diagram shows the concept of model stacking in PyTorch Tabular.
261+
![Model Stacking](imgs/model_stacking_concept.png)
262+
263+
The following model architectures are supported for stacking:
264+
- Category Embedding Model
265+
- TabNet Model
266+
- FTTransformer Model
267+
- Gated Additive Tree Ensemble Model
268+
- DANet Model
269+
- AutoInt Model
270+
- GANDALF Model
271+
- Node Model
272+
273+
All the parameters have been set to provide flexibility while maintaining ease of use. Let's look at them:
274+
275+
- `model_configs`: List[ModelConfig]: List of configurations for each base model. Each config should be a valid PyTorch Tabular model config (e.g., NodeConfig, GANDALFConfig)
276+
277+
**For a complete list of parameters refer to the API Docs**
278+
[pytorch_tabular.models.StackingModelConfig][]
279+
256280
## Implementing New Architectures
257281

258282
PyTorch Tabular is very easy to extend and infinitely customizable. All the models that have been implemented in PyTorch Tabular inherits an Abstract Class `BaseModel` which is in fact a PyTorchLightning Model.

docs/tutorials/16-Model Stacking.ipynb

Lines changed: 1486 additions & 0 deletions
Large diffs are not rendered by default.

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ nav:
2424
- SHAP, Deep LIFT and so on through Captum Integration: "tutorials/14-Explainability.ipynb"
2525
- Custom PyTorch Models:
2626
- Implementing New Supervised Architectures: "tutorials/04-Implementing New Architectures.ipynb"
27+
- Model Stacking: "tutorials/16-Model Stacking.ipynb"
2728
- Other Features:
2829
- Using Neural Categorical Embeddings in Scikit-Learn Workflows: "tutorials/03-Neural Embedding in Scikit-Learn Workflows.ipynb"
2930
- Self-Supervised Learning using Denoising Autoencoders: "tutorials/08-Self-Supervised Learning-DAE.ipynb"

requirements/base.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ pytorch-lightning >=2.0.0, <2.5.0
66
omegaconf >=2.3.0
77
torchmetrics >=0.10.0, <1.6.0
88
tensorboard >2.2.0, !=2.5.0
9-
protobuf >=3.20.0, <5.29.0
9+
protobuf >=3.20.0, <5.30.0
1010
pytorch-tabnet ==4.1
1111
PyYAML >=5.4, <6.1.0
1212
# importlib-metadata <1,>=0.12

src/pytorch_tabular/models/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .gate import GatedAdditiveTreeEnsembleConfig, GatedAdditiveTreeEnsembleModel
2020
from .mixture_density import MDNConfig, MDNModel
2121
from .node import NodeConfig, NODEModel
22+
from .stacking import StackingModel, StackingModelConfig
2223
from .tab_transformer import TabTransformerConfig, TabTransformerModel
2324
from .tabnet import TabNetModel, TabNetModelConfig
2425

@@ -45,6 +46,8 @@
4546
"GANDALFBackbone",
4647
"DANetConfig",
4748
"DANetModel",
49+
"StackingModel",
50+
"StackingModelConfig",
4851
"category_embedding",
4952
"node",
5053
"mixture_density",
@@ -55,4 +58,5 @@
5558
"gate",
5659
"gandalf",
5760
"danet",
61+
"stacking",
5862
]
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .config import StackingModelConfig
2+
from .stacking_model import StackingBackbone, StackingModel
3+
4+
__all__ = ["StackingModel", "StackingModelConfig", "StackingBackbone"]
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from dataclasses import dataclass, field
2+
3+
from pytorch_tabular.config import ModelConfig
4+
5+
6+
@dataclass
7+
class StackingModelConfig(ModelConfig):
8+
"""StackingModelConfig is a configuration class for the StackingModel. It is used to stack multiple models
9+
together. Now, CategoryEmbeddingModel, TabNetModel, FTTransformerModel, GatedAdditiveTreeEnsembleModel, DANetModel,
10+
AutoIntModel, GANDALFModel, NodeModel are supported.
11+
12+
Args:
13+
model_configs (list[ModelConfig]): List of model configs to stack.
14+
15+
"""
16+
17+
model_configs: list = field(default_factory=list, metadata={"help": "List of model configs to stack"})
18+
_module_src: str = field(default="models.stacking")
19+
_model_name: str = field(default="StackingModel")
20+
_backbone_name: str = field(default="StackingBackbone")
21+
_config_name: str = field(default="StackingConfig")
22+
23+
24+
# if __name__ == "__main__":
25+
# from pytorch_tabular.utils import generate_doc_dataclass
26+
# print(generate_doc_dataclass(StackingModelConfig))

0 commit comments

Comments
 (0)