
Commit edab3f3

Categorical bug fix
1 parent a890dda commit edab3f3

File tree

src/pytorch_tabular/categorical_encoders.py
src/pytorch_tabular/tabular_datamodule.py

2 files changed: +10 -1 lines changed


src/pytorch_tabular/categorical_encoders.py

Lines changed: 3 additions & 1 deletion
@@ -62,6 +62,8 @@ def transform(self, X):
                 not X[self.cols].isnull().any().any()
             ), "`handle_missing` = `error` and missing values found in columns to encode."
         X_encoded = X.copy(deep=True)
+        category_cols = X_encoded.select_dtypes(include='category').columns
+        X_encoded[category_cols] = X_encoded[category_cols].astype('object')
         for col, mapping in self._mapping.items():
             X_encoded[col] = X_encoded[col].fillna(NAN_CATEGORY).map(mapping["value"])

@@ -267,4 +269,4 @@ def save_as_object_file(self, path):

     def load_from_object_file(self, path):
         for k, v in pickle.load(open(path, "rb")).items():
-            setattr(self, k, v)
+            setattr(self, k, v)
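
The likely motivation for the `astype('object')` cast: on a pandas `category` column, `fillna` with a value that is not already a registered category raises, so the existing `fillna(NAN_CATEGORY).map(...)` step fails for categorical-dtype inputs. A minimal sketch of that failure mode follows; the `NAN_CATEGORY` value and the mapping below are placeholders for illustration, not the library's actual constants.

# Sketch only: NAN_CATEGORY and the mapping are stand-ins for the library's internals.
import pandas as pd

NAN_CATEGORY = "NaN"  # placeholder value; the real sentinel lives in pytorch_tabular
df = pd.DataFrame({"col": pd.Series(["a", "b", None], dtype="category")})
mapping = {"a": 1, "b": 2, NAN_CATEGORY: 3}

try:
    df["col"].fillna(NAN_CATEGORY).map(mapping)
except (TypeError, ValueError) as err:
    # pandas refuses to fill a categorical with an unregistered category
    print(f"category dtype fails: {err}")

# The commit's workaround: cast 'category' columns to 'object' before encoding.
category_cols = df.select_dtypes(include="category").columns
df[category_cols] = df[category_cols].astype("object")
print(df["col"].fillna(NAN_CATEGORY).map(mapping).tolist())  # [1, 2, 3]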

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 7 additions & 0 deletions
@@ -301,10 +301,14 @@ def _update_config(self, config) -> InferredConfig:
         else:
             raise ValueError(f"{config.task} is an unsupported task.")
         if self.train is not None:
+            category_cols = self.train[config.categorical_cols].select_dtypes(include='category').columns
+            self.train[category_cols] = self.train[category_cols].astype('object')
             categorical_cardinality = [
                 int(x) + 1 for x in list(self.train[config.categorical_cols].fillna("NA").nunique().values)
             ]
         else:
+            category_cols = self.train_dataset.data[config.categorical_cols].select_dtypes(include='category').columns
+            self.train_dataset.data[category_cols] = self.train_dataset.data[category_cols].astype('object')
             categorical_cardinality = [
                 int(x) + 1 for x in list(self.train_dataset.data[config.categorical_cols].nunique().values)
             ]
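
For context, the cardinality computed in this hunk is the number of distinct values per categorical column after filling NaNs with "NA", plus one extra slot (presumably reserved for unseen or missing values). If a column arrives with pandas `category` dtype, the `fillna("NA")` step hits the same error shown above, hence the cast added here. A small illustration with made-up data (not from the repository):

# Illustration with assumed data; mirrors the list comprehension in the hunk above.
import pandas as pd

train = pd.DataFrame({
    "city": ["NY", "SF", None, "NY"],
    "plan": ["free", "pro", "pro", "free"],
})
categorical_cols = ["city", "plan"]

categorical_cardinality = [
    int(x) + 1 for x in list(train[categorical_cols].fillna("NA").nunique().values)
]
print(categorical_cardinality)  # [4, 3]: {NY, SF, NA} + 1 and {free, pro} + 1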
@@ -805,6 +809,7 @@ def train_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
             num_workers=self.config.num_workers,
             sampler=self.train_sampler,
             pin_memory=self.config.pin_memory,
+            **self.config.dataloader_kwargs,
         )

     def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:

@@ -823,6 +828,7 @@ def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
             shuffle=False,
             num_workers=self.config.num_workers,
             pin_memory=self.config.pin_memory,
+            **self.config.dataloader_kwargs,
         )

     def _prepare_inference_data(self, df: DataFrame) -> DataFrame:

@@ -865,6 +871,7 @@ def prepare_inference_dataloader(
             batch_size or self.batch_size,
             shuffle=False,
             num_workers=self.config.num_workers,
+            **self.config.dataloader_kwargs,
         )

     def save_dataloader(self, path: Union[str, Path]) -> None:
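
The three `**self.config.dataloader_kwargs` additions forward any extra keyword arguments collected in the data config straight to `torch.utils.data.DataLoader`. A minimal sketch of the effect, with assumed dataset and kwargs (not values from the repository):

# Sketch only: the dataset and the extra kwargs below are assumptions for illustration.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))
dataloader_kwargs = {"drop_last": True}  # e.g. extras gathered from the data config

loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=False,
    num_workers=0,
    pin_memory=False,
    **dataloader_kwargs,  # mirrors the passthrough added in this commit
)
print(len(loader))  # 3 full batches of 32; the trailing 4 samples are dropped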
