Skip to content

Commit 69abe1d

Browse files
authored
Merge branch 'main' into dependabot-github_actions-pypa-gh-action-pypi-publish-1.11.0
2 parents 36067c5 + f354b9c commit 69abe1d

File tree

6 files changed

+18
-5
lines changed

6 files changed

+18
-5
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ repos:
2727
- id: detect-private-key
2828

2929
- repo: https://github.com/PyCQA/docformatter
30-
rev: v1.7.5
30+
rev: 06907d0267368b49b9180eed423fae5697c1e909 # todo: fix for docformatter after last 1.7.5
3131
hooks:
3232
- id: docformatter
3333
additional_dependencies: [tomli]

README.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,13 @@ loaded_model = TabularModel.load_model("examples/basic")
254254
<sub><b>Luca Actis Grosso</b></sub>
255255
</a>
256256
</td>
257+
<td align="center">
258+
<a href="https://github.com/snehilchatterjee">
259+
<img src="https://avatars.githubusercontent.com/u/127598707?v=4" width="100;" alt="snehilchatterjee"/>
260+
<br />
261+
<sub><b>Snehil Chatterjee</b></sub>
262+
</a>
263+
</td>
257264
<td align="center">
258265
<a href="https://github.com/sgbaird">
259266
<img src="https://avatars.githubusercontent.com/u/45469701?v=4" width="100;" alt="sgbaird"/>
@@ -275,15 +282,15 @@ loaded_model = TabularModel.load_model("examples/basic")
275282
<sub><b>Yinyu Nie</b></sub>
276283
</a>
277284
</td>
285+
</tr>
286+
<tr>
278287
<td align="center">
279288
<a href="https://github.com/YonyBresler">
280289
<img src="https://avatars.githubusercontent.com/u/24940683?v=4" width="100;" alt="YonyBresler"/>
281290
<br />
282291
<sub><b>YonyBresler</b></sub>
283292
</a>
284293
</td>
285-
</tr>
286-
<tr>
287294
<td align="center">
288295
<a href="https://github.com/HernandoR">
289296
<img src="https://avatars.githubusercontent.com/u/45709656?v=4" width="100;" alt="HernandoR"/>

requirements/base.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ pandas >=1.1.5
44
scikit-learn >=1.3.0
55
pytorch-lightning >=2.0.0, <2.5.0
66
omegaconf >=2.3.0
7-
torchmetrics >=0.10.0, <1.5.0
7+
torchmetrics >=0.10.0, <1.6.0
88
tensorboard >2.2.0, !=2.5.0
99
protobuf >=3.20.0, <5.29.0
1010
pytorch-tabnet ==4.1

requirements/extra.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
wandb >=0.15.0, <0.17.0
1+
wandb >=0.15.0, <0.19.0
22
plotly>=5.13.0, <5.25.0
33
kaleido >=0.2.0, <0.3.0
44
captum >=0.5.0, <0.8.0

src/pytorch_tabular/categorical_encoders.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def transform(self, X):
6262
not X[self.cols].isnull().any().any()
6363
), "`handle_missing` = `error` and missing values found in columns to encode."
6464
X_encoded = X.copy(deep=True)
65+
category_cols = X_encoded.select_dtypes(include="category").columns
66+
X_encoded[category_cols] = X_encoded[category_cols].astype("object")
6567
for col, mapping in self._mapping.items():
6668
X_encoded[col] = X_encoded[col].fillna(NAN_CATEGORY).map(mapping["value"])
6769

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,10 +301,14 @@ def _update_config(self, config) -> InferredConfig:
301301
else:
302302
raise ValueError(f"{config.task} is an unsupported task.")
303303
if self.train is not None:
304+
category_cols = self.train[config.categorical_cols].select_dtypes(include="category").columns
305+
self.train[category_cols] = self.train[category_cols].astype("object")
304306
categorical_cardinality = [
305307
int(x) + 1 for x in list(self.train[config.categorical_cols].fillna("NA").nunique().values)
306308
]
307309
else:
310+
category_cols = self.train_dataset.data[config.categorical_cols].select_dtypes(include="category").columns
311+
self.train_dataset.data[category_cols] = self.train_dataset.data[category_cols].astype("object")
308312
categorical_cardinality = [
309313
int(x) + 1 for x in list(self.train_dataset.data[config.categorical_cols].nunique().values)
310314
]

0 commit comments

Comments
 (0)