Skip to content

Commit f8f6c2c

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent edab3f3 commit f8f6c2c

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

src/pytorch_tabular/categorical_encoders.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ def transform(self, X):
6262
not X[self.cols].isnull().any().any()
6363
), "`handle_missing` = `error` and missing values found in columns to encode."
6464
X_encoded = X.copy(deep=True)
65-
category_cols = X_encoded.select_dtypes(include='category').columns
66-
X_encoded[category_cols] = X_encoded[category_cols].astype('object')
65+
category_cols = X_encoded.select_dtypes(include="category").columns
66+
X_encoded[category_cols] = X_encoded[category_cols].astype("object")
6767
for col, mapping in self._mapping.items():
6868
X_encoded[col] = X_encoded[col].fillna(NAN_CATEGORY).map(mapping["value"])
6969

@@ -269,4 +269,4 @@ def save_as_object_file(self, path):
269269

270270
def load_from_object_file(self, path):
271271
for k, v in pickle.load(open(path, "rb")).items():
272-
setattr(self, k, v)
272+
setattr(self, k, v)

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -301,14 +301,14 @@ def _update_config(self, config) -> InferredConfig:
301301
else:
302302
raise ValueError(f"{config.task} is an unsupported task.")
303303
if self.train is not None:
304-
category_cols = self.train[config.categorical_cols].select_dtypes(include='category').columns
305-
self.train[category_cols] = self.train[category_cols].astype('object')
304+
category_cols = self.train[config.categorical_cols].select_dtypes(include="category").columns
305+
self.train[category_cols] = self.train[category_cols].astype("object")
306306
categorical_cardinality = [
307307
int(x) + 1 for x in list(self.train[config.categorical_cols].fillna("NA").nunique().values)
308308
]
309309
else:
310-
category_cols = self.train_dataset.data[config.categorical_cols].select_dtypes(include='category').columns
311-
self.train_dataset.data[category_cols] = self.train_dataset.data[category_cols].astype('object')
310+
category_cols = self.train_dataset.data[config.categorical_cols].select_dtypes(include="category").columns
311+
self.train_dataset.data[category_cols] = self.train_dataset.data[category_cols].astype("object")
312312
categorical_cardinality = [
313313
int(x) + 1 for x in list(self.train_dataset.data[config.categorical_cols].nunique().values)
314314
]

0 commit comments

Comments
 (0)