Skip to content

Commit 5fc25d6

Browse files
authored
Merge branch 'main' into dependabot-pip-numpy-gt-1.20.0-and-lt-3.0
2 parents e7c033f + f354b9c commit 5fc25d6

File tree

7 files changed

+170
-132
lines changed

7 files changed

+170
-132
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ repos:
2727
- id: detect-private-key
2828

2929
- repo: https://github.com/PyCQA/docformatter
30-
rev: v1.7.5
30+
rev: 06907d0267368b49b9180eed423fae5697c1e909 # todo: fix for docformatter after last 1.7.5
3131
hooks:
3232
- id: docformatter
3333
additional_dependencies: [tomli]

README.md

Lines changed: 150 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -150,135 +150,156 @@ loaded_model = TabularModel.load_model("examples/basic")
150150

151151
<!-- readme: contributors -start -->
152152
<table>
153-
<tr>
154-
<td align="center">
155-
<a href="https://github.com/manujosephv">
156-
<img src="https://avatars.githubusercontent.com/u/10508493?v=4" width="100;" alt="manujosephv"/>
157-
<br />
158-
<sub><b>Manu Joseph</b></sub>
159-
</a>
160-
</td>
161-
<td align="center">
162-
<a href="https://github.com/Borda">
163-
<img src="https://avatars.githubusercontent.com/u/6035284?v=4" width="100;" alt="Borda"/>
164-
<br />
165-
<sub><b>Jirka Borovec</b></sub>
166-
</a>
167-
</td>
168-
<td align="center">
169-
<a href="https://github.com/wsad1">
170-
<img src="https://avatars.githubusercontent.com/u/13963626?v=4" width="100;" alt="wsad1"/>
171-
<br />
172-
<sub><b>Jinu Sunil</b></sub>
173-
</a>
174-
</td>
175-
<td align="center">
176-
<a href="https://github.com/ProgramadorArtificial">
177-
<img src="https://avatars.githubusercontent.com/u/130674366?v=4" width="100;" alt="ProgramadorArtificial"/>
178-
<br />
179-
<sub><b>Programador Artificial</b></sub>
180-
</a>
181-
</td>
182-
<td align="center">
183-
<a href="https://github.com/sorenmacbeth">
184-
<img src="https://avatars.githubusercontent.com/u/130043?v=4" width="100;" alt="sorenmacbeth"/>
185-
<br />
186-
<sub><b>Soren Macbeth</b></sub>
187-
</a>
188-
</td>
189-
<td align="center">
190-
<a href="https://github.com/fonnesbeck">
191-
<img src="https://avatars.githubusercontent.com/u/81476?v=4" width="100;" alt="fonnesbeck"/>
192-
<br />
193-
<sub><b>Chris Fonnesbeck</b></sub>
194-
</a>
195-
</td></tr>
196-
<tr>
197-
<td align="center">
198-
<a href="https://github.com/jxtrbtk">
199-
<img src="https://avatars.githubusercontent.com/u/40494970?v=4" width="100;" alt="jxtrbtk"/>
200-
<br />
201-
<sub><b>Null</b></sub>
202-
</a>
203-
</td>
204-
<td align="center">
205-
<a href="https://github.com/abhisharsinha">
206-
<img src="https://avatars.githubusercontent.com/u/24841841?v=4" width="100;" alt="abhisharsinha"/>
207-
<br />
208-
<sub><b>Abhishar Sinha</b></sub>
209-
</a>
210-
</td>
211-
<td align="center">
212-
<a href="https://github.com/ndrsfel">
213-
<img src="https://avatars.githubusercontent.com/u/21068727?v=4" width="100;" alt="ndrsfel"/>
214-
<br />
215-
<sub><b>Andreas</b></sub>
216-
</a>
217-
</td>
218-
<td align="center">
219-
<a href="https://github.com/charitarthchugh">
220-
<img src="https://avatars.githubusercontent.com/u/37895518?v=4" width="100;" alt="charitarthchugh"/>
221-
<br />
222-
<sub><b>Charitarth Chugh</b></sub>
223-
</a>
224-
</td>
225-
<td align="center">
226-
<a href="https://github.com/EeyoreLee">
227-
<img src="https://avatars.githubusercontent.com/u/49790022?v=4" width="100;" alt="EeyoreLee"/>
228-
<br />
229-
<sub><b>Earlee</b></sub>
230-
</a>
231-
</td>
232-
<td align="center">
233-
<a href="https://github.com/JulianRein">
234-
<img src="https://avatars.githubusercontent.com/u/35046938?v=4" width="100;" alt="JulianRein"/>
235-
<br />
236-
<sub><b>Null</b></sub>
237-
</a>
238-
</td></tr>
239-
<tr>
240-
<td align="center">
241-
<a href="https://github.com/krshrimali">
242-
<img src="https://avatars.githubusercontent.com/u/19997320?v=4" width="100;" alt="krshrimali"/>
243-
<br />
244-
<sub><b>Kushashwa Ravi Shrimali</b></sub>
245-
</a>
246-
</td>
247-
<td align="center">
248-
<a href="https://github.com/Actis92">
249-
<img src="https://avatars.githubusercontent.com/u/46601193?v=4" width="100;" alt="Actis92"/>
250-
<br />
251-
<sub><b>Luca Actis Grosso</b></sub>
252-
</a>
253-
</td>
254-
<td align="center">
255-
<a href="https://github.com/sgbaird">
256-
<img src="https://avatars.githubusercontent.com/u/45469701?v=4" width="100;" alt="sgbaird"/>
257-
<br />
258-
<sub><b>Sterling G. Baird</b></sub>
259-
</a>
260-
</td>
261-
<td align="center">
262-
<a href="https://github.com/furyhawk">
263-
<img src="https://avatars.githubusercontent.com/u/831682?v=4" width="100;" alt="furyhawk"/>
264-
<br />
265-
<sub><b>Teck Meng</b></sub>
266-
</a>
267-
</td>
268-
<td align="center">
269-
<a href="https://github.com/yinyunie">
270-
<img src="https://avatars.githubusercontent.com/u/25686434?v=4" width="100;" alt="yinyunie"/>
271-
<br />
272-
<sub><b>Yinyu Nie</b></sub>
273-
</a>
274-
</td>
275-
<td align="center">
276-
<a href="https://github.com/HernandoR">
277-
<img src="https://avatars.githubusercontent.com/u/45709656?v=4" width="100;" alt="HernandoR"/>
278-
<br />
279-
<sub><b>Liu Zhen</b></sub>
280-
</a>
281-
</td></tr>
153+
<tbody>
154+
<tr>
155+
<td align="center">
156+
<a href="https://github.com/manujosephv">
157+
<img src="https://avatars.githubusercontent.com/u/10508493?v=4" width="100;" alt="manujosephv"/>
158+
<br />
159+
<sub><b>Manu Joseph</b></sub>
160+
</a>
161+
</td>
162+
<td align="center">
163+
<a href="https://github.com/Borda">
164+
<img src="https://avatars.githubusercontent.com/u/6035284?v=4" width="100;" alt="Borda"/>
165+
<br />
166+
<sub><b>Jirka Borovec</b></sub>
167+
</a>
168+
</td>
169+
<td align="center">
170+
<a href="https://github.com/wsad1">
171+
<img src="https://avatars.githubusercontent.com/u/13963626?v=4" width="100;" alt="wsad1"/>
172+
<br />
173+
<sub><b>Jinu Sunil</b></sub>
174+
</a>
175+
</td>
176+
<td align="center">
177+
<a href="https://github.com/ProgramadorArtificial">
178+
<img src="https://avatars.githubusercontent.com/u/130674366?v=4" width="100;" alt="ProgramadorArtificial"/>
179+
<br />
180+
<sub><b>Programador Artificial</b></sub>
181+
</a>
182+
</td>
183+
<td align="center">
184+
<a href="https://github.com/sorenmacbeth">
185+
<img src="https://avatars.githubusercontent.com/u/130043?v=4" width="100;" alt="sorenmacbeth"/>
186+
<br />
187+
<sub><b>Soren Macbeth</b></sub>
188+
</a>
189+
</td>
190+
<td align="center">
191+
<a href="https://github.com/fonnesbeck">
192+
<img src="https://avatars.githubusercontent.com/u/81476?v=4" width="100;" alt="fonnesbeck"/>
193+
<br />
194+
<sub><b>Chris Fonnesbeck</b></sub>
195+
</a>
196+
</td>
197+
</tr>
198+
<tr>
199+
<td align="center">
200+
<a href="https://github.com/jxtrbtk">
201+
<img src="https://avatars.githubusercontent.com/u/40494970?v=4" width="100;" alt="jxtrbtk"/>
202+
<br />
203+
<sub><b>Null</b></sub>
204+
</a>
205+
</td>
206+
<td align="center">
207+
<a href="https://github.com/abhisharsinha">
208+
<img src="https://avatars.githubusercontent.com/u/24841841?v=4" width="100;" alt="abhisharsinha"/>
209+
<br />
210+
<sub><b>Abhishar Sinha</b></sub>
211+
</a>
212+
</td>
213+
<td align="center">
214+
<a href="https://github.com/ndrsfel">
215+
<img src="https://avatars.githubusercontent.com/u/21068727?v=4" width="100;" alt="ndrsfel"/>
216+
<br />
217+
<sub><b>Andreas</b></sub>
218+
</a>
219+
</td>
220+
<td align="center">
221+
<a href="https://github.com/charitarthchugh">
222+
<img src="https://avatars.githubusercontent.com/u/37895518?v=4" width="100;" alt="charitarthchugh"/>
223+
<br />
224+
<sub><b>Charitarth Chugh</b></sub>
225+
</a>
226+
</td>
227+
<td align="center">
228+
<a href="https://github.com/EeyoreLee">
229+
<img src="https://avatars.githubusercontent.com/u/49790022?v=4" width="100;" alt="EeyoreLee"/>
230+
<br />
231+
<sub><b>Earlee</b></sub>
232+
</a>
233+
</td>
234+
<td align="center">
235+
<a href="https://github.com/JulianRein">
236+
<img src="https://avatars.githubusercontent.com/u/35046938?v=4" width="100;" alt="JulianRein"/>
237+
<br />
238+
<sub><b>Null</b></sub>
239+
</a>
240+
</td>
241+
</tr>
242+
<tr>
243+
<td align="center">
244+
<a href="https://github.com/krshrimali">
245+
<img src="https://avatars.githubusercontent.com/u/19997320?v=4" width="100;" alt="krshrimali"/>
246+
<br />
247+
<sub><b>Kushashwa Ravi Shrimali</b></sub>
248+
</a>
249+
</td>
250+
<td align="center">
251+
<a href="https://github.com/Actis92">
252+
<img src="https://avatars.githubusercontent.com/u/46601193?v=4" width="100;" alt="Actis92"/>
253+
<br />
254+
<sub><b>Luca Actis Grosso</b></sub>
255+
</a>
256+
</td>
257+
<td align="center">
258+
<a href="https://github.com/snehilchatterjee">
259+
<img src="https://avatars.githubusercontent.com/u/127598707?v=4" width="100;" alt="snehilchatterjee"/>
260+
<br />
261+
<sub><b>Snehil Chatterjee</b></sub>
262+
</a>
263+
</td>
264+
<td align="center">
265+
<a href="https://github.com/sgbaird">
266+
<img src="https://avatars.githubusercontent.com/u/45469701?v=4" width="100;" alt="sgbaird"/>
267+
<br />
268+
<sub><b>Sterling G. Baird</b></sub>
269+
</a>
270+
</td>
271+
<td align="center">
272+
<a href="https://github.com/furyhawk">
273+
<img src="https://avatars.githubusercontent.com/u/831682?v=4" width="100;" alt="furyhawk"/>
274+
<br />
275+
<sub><b>Teck Meng</b></sub>
276+
</a>
277+
</td>
278+
<td align="center">
279+
<a href="https://github.com/yinyunie">
280+
<img src="https://avatars.githubusercontent.com/u/25686434?v=4" width="100;" alt="yinyunie"/>
281+
<br />
282+
<sub><b>Yinyu Nie</b></sub>
283+
</a>
284+
</td>
285+
</tr>
286+
<tr>
287+
<td align="center">
288+
<a href="https://github.com/YonyBresler">
289+
<img src="https://avatars.githubusercontent.com/u/24940683?v=4" width="100;" alt="YonyBresler"/>
290+
<br />
291+
<sub><b>YonyBresler</b></sub>
292+
</a>
293+
</td>
294+
<td align="center">
295+
<a href="https://github.com/HernandoR">
296+
<img src="https://avatars.githubusercontent.com/u/45709656?v=4" width="100;" alt="HernandoR"/>
297+
<br />
298+
<sub><b>Liu Zhen</b></sub>
299+
</a>
300+
</td>
301+
</tr>
302+
<tbody>
282303
</table>
283304
<!-- readme: contributors -end -->
284305

requirements/base.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ pandas >=1.1.5
44
scikit-learn >=1.3.0
55
pytorch-lightning >=2.0.0, <2.5.0
66
omegaconf >=2.3.0
7-
torchmetrics >=0.10.0, <1.5.0
7+
torchmetrics >=0.10.0, <1.6.0
88
tensorboard >2.2.0, !=2.5.0
99
protobuf >=3.20.0, <5.29.0
1010
pytorch-tabnet ==4.1

requirements/extra.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
wandb >=0.15.0, <0.17.0
1+
wandb >=0.15.0, <0.19.0
22
plotly>=5.13.0, <5.25.0
33
kaleido >=0.2.0, <0.3.0
44
captum >=0.5.0, <0.8.0

src/pytorch_tabular/categorical_encoders.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def transform(self, X):
6262
not X[self.cols].isnull().any().any()
6363
), "`handle_missing` = `error` and missing values found in columns to encode."
6464
X_encoded = X.copy(deep=True)
65+
category_cols = X_encoded.select_dtypes(include="category").columns
66+
X_encoded[category_cols] = X_encoded[category_cols].astype("object")
6567
for col, mapping in self._mapping.items():
6668
X_encoded[col] = X_encoded[col].fillna(NAN_CATEGORY).map(mapping["value"])
6769

src/pytorch_tabular/config/config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ class DataConfig:
9696
handle_missing_values (bool): Whether to handle missing values in categorical columns as
9797
unknown
9898
99+
dataloader_kwargs (Dict[str, Any]): Additional kwargs to be passed to PyTorch DataLoader. See
100+
https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
101+
99102
"""
100103

101104
target: Optional[List[str]] = field(
@@ -176,6 +179,11 @@ class DataConfig:
176179
metadata={"help": "Whether or not to handle missing values in categorical columns as unknown"},
177180
)
178181

182+
dataloader_kwargs: Dict[str, Any] = field(
183+
default_factory=dict,
184+
metadata={"help": "Additional kwargs to be passed to PyTorch DataLoader."},
185+
)
186+
179187
def __post_init__(self):
180188
assert (
181189
len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,10 +301,14 @@ def _update_config(self, config) -> InferredConfig:
301301
else:
302302
raise ValueError(f"{config.task} is an unsupported task.")
303303
if self.train is not None:
304+
category_cols = self.train[config.categorical_cols].select_dtypes(include="category").columns
305+
self.train[category_cols] = self.train[category_cols].astype("object")
304306
categorical_cardinality = [
305307
int(x) + 1 for x in list(self.train[config.categorical_cols].fillna("NA").nunique().values)
306308
]
307309
else:
310+
category_cols = self.train_dataset.data[config.categorical_cols].select_dtypes(include="category").columns
311+
self.train_dataset.data[category_cols] = self.train_dataset.data[category_cols].astype("object")
308312
categorical_cardinality = [
309313
int(x) + 1 for x in list(self.train_dataset.data[config.categorical_cols].nunique().values)
310314
]
@@ -805,6 +809,7 @@ def train_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
805809
num_workers=self.config.num_workers,
806810
sampler=self.train_sampler,
807811
pin_memory=self.config.pin_memory,
812+
**self.config.dataloader_kwargs,
808813
)
809814

810815
def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
@@ -823,6 +828,7 @@ def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
823828
shuffle=False,
824829
num_workers=self.config.num_workers,
825830
pin_memory=self.config.pin_memory,
831+
**self.config.dataloader_kwargs,
826832
)
827833

828834
def _prepare_inference_data(self, df: DataFrame) -> DataFrame:
@@ -865,6 +871,7 @@ def prepare_inference_dataloader(
865871
batch_size or self.batch_size,
866872
shuffle=False,
867873
num_workers=self.config.num_workers,
874+
**self.config.dataloader_kwargs,
868875
)
869876

870877
def save_dataloader(self, path: Union[str, Path]) -> None:

0 commit comments

Comments
 (0)