Commit b245ba2
Merge branch 'main' into dependabot-pip-mkdocstrings-python--eq-0.27.star
2 parents f436724 + de60f5f

File tree

21 files changed: 1983 additions, 34 deletions

.github/workflows/releasing.yml

Lines changed: 1 addition & 1 deletion

@@ -38,7 +38,7 @@ jobs:

       - name: Publish distribution 📦 to PyPI
         if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
-        uses: pypa/gh-action-pypi-publish@v1.11.0
+        uses: pypa/gh-action-pypi-publish@v1.12.2
         with:
           user: __token__
           password: ${{ secrets.pypi_password }}

docs/apidocs_model.md

Lines changed: 6 additions & 1 deletion

@@ -30,6 +30,9 @@
 ::: pytorch_tabular.models.TabTransformerConfig
     options:
         heading_level: 3
+::: pytorch_tabular.models.StackingModelConfig
+    options:
+        heading_level: 3
 ::: pytorch_tabular.config.ModelConfig
     options:
         heading_level: 3

@@ -66,7 +69,9 @@
 ::: pytorch_tabular.models.TabTransformerModel
     options:
         heading_level: 3
-
+::: pytorch_tabular.models.StackingModel
+    options:
+        heading_level: 3
 ## Base Model Class
 ::: pytorch_tabular.models.BaseModel
     options:
docs/imgs/model_stacking_concept.png

59.2 KB (binary image file; diff not rendered)

docs/models.md

Lines changed: 24 additions & 0 deletions

@@ -253,6 +253,30 @@
 **For a complete list of parameters refer to the API Docs**
 [pytorch_tabular.models.DANetConfig][]

+## Model Stacking
+
+Model stacking is an ensemble learning technique that combines multiple base models to create a more powerful predictive model. Each base model processes the input features independently, and their outputs are concatenated before making the final prediction. This allows the model to leverage the different learning patterns captured by each backbone architecture. You can use it by choosing `StackingModelConfig`.
+
+The following diagram shows the concept of model stacking in PyTorch Tabular.
+![Model Stacking](imgs/model_stacking_concept.png)
+
+The following model architectures are supported for stacking:
+- Category Embedding Model
+- TabNet Model
+- FTTransformer Model
+- Gated Additive Tree Ensemble Model
+- DANet Model
+- AutoInt Model
+- GANDALF Model
+- Node Model
+
+All the parameters have been set to provide flexibility while maintaining ease of use. Let's look at them:
+
+- `model_configs`: List[ModelConfig]: List of configurations for each base model. Each config should be a valid PyTorch Tabular model config (e.g., NodeConfig, GANDALFConfig)
+
+**For a complete list of parameters refer to the API Docs**
+[pytorch_tabular.models.StackingModelConfig][]
+
 ## Implementing New Architectures

 PyTorch Tabular is very easy to extend and infinitely customizable. All the models implemented in PyTorch Tabular inherit from an abstract class `BaseModel`, which is in fact a PyTorch Lightning model.
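For orientation, here is a minimal sketch of how the new config slots into the usual PyTorch Tabular workflow. The `StackingModelConfig` fields follow the docs added above; the column names, target, and choice of base models are hypothetical placeholders.

```python
# Minimal sketch of the new stacking API, assuming the fields documented above.
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import (
    CategoryEmbeddingModelConfig,
    GANDALFConfig,
    StackingModelConfig,
)

# Each entry in model_configs is an ordinary PyTorch Tabular model config.
model_config = StackingModelConfig(
    task="classification",
    model_configs=[
        CategoryEmbeddingModelConfig(task="classification"),
        GANDALFConfig(task="classification"),
    ],
)

tabular_model = TabularModel(
    data_config=DataConfig(
        target=["target"],                  # hypothetical column names
        continuous_cols=["num_1", "num_2"],
        categorical_cols=["cat_1"],
    ),
    model_config=model_config,
    optimizer_config=OptimizerConfig(),
    trainer_config=TrainerConfig(max_epochs=5),
)
# tabular_model.fit(train=train_df, validation=val_df)
```

Any of the supported backbones listed above can be mixed freely in `model_configs`.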

docs/tutorials/16-Model Stacking.ipynb

Lines changed: 1486 additions & 0 deletions
Large diffs are not rendered by default.

mkdocs.yml

Lines changed: 1 addition & 0 deletions

@@ -24,6 +24,7 @@ nav:
       - SHAP, Deep LIFT and so on through Captum Integration: "tutorials/14-Explainability.ipynb"
   - Custom PyTorch Models:
       - Implementing New Supervised Architectures: "tutorials/04-Implementing New Architectures.ipynb"
+      - Model Stacking: "tutorials/16-Model Stacking.ipynb"
   - Other Features:
       - Using Neural Categorical Embeddings in Scikit-Learn Workflows: "tutorials/03-Neural Embedding in Scikit-Learn Workflows.ipynb"
       - Self-Supervised Learning using Denoising Autoencoders: "tutorials/08-Self-Supervised Learning-DAE.ipynb"

requirements/base.txt

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@ pytorch-lightning >=2.0.0, <2.5.0
 omegaconf >=2.3.0
 torchmetrics >=0.10.0, <1.6.0
 tensorboard >2.2.0, !=2.5.0
-protobuf >=3.20.0, <5.29.0
+protobuf >=3.20.0, <5.30.0
 pytorch-tabnet ==4.1
 PyYAML >=5.4, <6.1.0
 # importlib-metadata <1,>=0.12

src/pytorch_tabular/categorical_encoders.py

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ def transform(self, X):
             X_encoded[col] = X_encoded[col].fillna(NAN_CATEGORY).map(mapping["value"])

             if self.handle_unseen == "impute":
-                X_encoded[col].fillna(self._imputed, inplace=True)
+                X_encoded[col] = X_encoded[col].fillna(self._imputed)
             elif self.handle_unseen == "error":
                 if np.unique(X_encoded[col]).shape[0] > mapping.shape[0]:
                     raise ValueError(f"Unseen categories found in `{col}` column.")

src/pytorch_tabular/config/config.py

Lines changed: 12 additions & 6 deletions

@@ -96,6 +96,8 @@ class DataConfig:
         handle_missing_values (bool): Whether to handle missing values in categorical columns as
             unknown

+        pickle_protocol (int): pickle protocol version passed to `torch.save` for dataset caching to disk
+
         dataloader_kwargs (Dict[str, Any]): Additional kwargs to be passed to PyTorch DataLoader. See
             https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

@@ -179,6 +181,11 @@ class DataConfig:
         metadata={"help": "Whether or not to handle missing values in categorical columns as unknown"},
     )

+    pickle_protocol: int = field(
+        default=2,
+        metadata={"help": "pickle protocol version passed to `torch.save` for dataset caching to disk"},
+    )
+
     dataloader_kwargs: Dict[str, Any] = field(
         default_factory=dict,
         metadata={"help": "Additional kwargs to be passed to PyTorch DataLoader."},

@@ -351,8 +358,8 @@ class TrainerConfig:

         progress_bar (str): Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `rich`.

-        precision (int): Precision of the model. Can be one of: `32`, `16`, `64`. Defaults to `32`.
-            Choices are: [`32`,`16`,`64`].
+        precision (str): Precision of the model. Defaults to `32`. See
+            https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision

         seed (int): Seed for random number generators. Defaults to 42

@@ -536,11 +543,10 @@ class TrainerConfig:
         default="rich",
         metadata={"help": "Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `rich`."},
     )
-    precision: int = field(
-        default=32,
+    precision: str = field(
+        default="32",
         metadata={
-            "help": "Precision of the model. Can be one of: `32`, `16`, `64`. Defaults to `32`.",
-            "choices": [32, 16, 64],
+            "help": "Precision of the model. Defaults to `32`.",
         },
     )
     seed: int = field(
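Together these hunks add a `DataConfig.pickle_protocol` knob and relax `TrainerConfig.precision` to a string so any Lightning precision flag can be passed through. A sketch of how the two options would be set; the column names are placeholders, and `"16-mixed"` is one of the precision strings Lightning 2.x accepts:

```python
from pytorch_tabular.config import DataConfig, TrainerConfig

data_config = DataConfig(
    target=["target"],                  # hypothetical column names
    continuous_cols=["num_1", "num_2"],
    pickle_protocol=4,                  # forwarded to torch.save when caching datasets to disk
)

trainer_config = TrainerConfig(
    max_epochs=10,
    precision="16-mixed",               # now a string, passed through to the Lightning Trainer
)
```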

src/pytorch_tabular/feature_extractor.py

Lines changed: 11 additions & 5 deletions

@@ -79,15 +79,21 @@ def transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame:
                 if k in ret_value.keys():
                     logits_predictions[k].append(ret_value[k].detach().cpu())

+        logits_dfs = []
         for k, v in logits_predictions.items():
             v = torch.cat(v, dim=0).numpy()
             if v.ndim == 1:
                 v = v.reshape(-1, 1)
-            for i in range(v.shape[-1]):
-                if v.shape[-1] > 1:
-                    X_encoded[f"{k}_{i}"] = v[:, i]
-                else:
-                    X_encoded[f"{k}"] = v[:, i]
+            if v.shape[-1] > 1:
+                temp_df = pd.DataFrame({f"{k}_{i}": v[:, i] for i in range(v.shape[-1])})
+            else:
+                temp_df = pd.DataFrame({f"{k}": v[:, 0]})
+
+            # Append the temp DataFrame to the list
+            logits_dfs.append(temp_df)
+
+        preds = pd.concat(logits_dfs, axis=1)
+        X_encoded = pd.concat([X_encoded, preds], axis=1)

         if self.drop_original:
             X_encoded.drop(columns=orig_features, inplace=True)
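The rewrite builds all new feature columns first and concatenates them in one call instead of inserting them column by column, which avoids repeated reallocation and pandas' `PerformanceWarning` about highly fragmented DataFrames. A toy version of the before/after, with hypothetical data:

```python
import numpy as np
import pandas as pd

X_encoded = pd.DataFrame({"a": range(4)})
v = np.arange(8.0).reshape(4, 2)  # pretend multi-column logits for key "backbone"

# Before: one insertion per column; each insert can reallocate and
# fragments the frame when repeated many times.
#   for i in range(v.shape[-1]):
#       X_encoded[f"backbone_{i}"] = v[:, i]

# After: build the new columns once and concatenate in a single step.
preds = pd.DataFrame({f"backbone_{i}": v[:, i] for i in range(v.shape[-1])})
X_encoded = pd.concat([X_encoded, preds], axis=1)
print(X_encoded)
```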
