Skip to content

Commit a34c6f6

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 9a1f71d commit a34c6f6

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

src/pytorch_tabular/config/config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,10 @@ class DataConfig:
9595
9696
handle_missing_values (bool): Whether to handle missing values in categorical columns as
9797
unknown
98-
98+
9999
dataloader_kwargs (Dict[str, Any]): Additional kwargs to be passed to PyTorch DataLoader. See
100100
https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
101+
101102
"""
102103

103104
target: Optional[List[str]] = field(
@@ -177,7 +178,7 @@ class DataConfig:
177178
default=True,
178179
metadata={"help": "Whether or not to handle missing values in categorical columns as unknown"},
179180
)
180-
181+
181182
dataloader_kwargs: Dict[str, Any] = field(
182183
default_factory=dict,
183184
metadata={"help": "Additional kwargs to be passed to PyTorch DataLoader."},

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,7 @@ def train_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
805805
num_workers=self.config.num_workers,
806806
sampler=self.train_sampler,
807807
pin_memory=self.config.pin_memory,
808-
**self.config.dataloader_kwargs
808+
**self.config.dataloader_kwargs,
809809
)
810810

811811
def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
@@ -824,7 +824,7 @@ def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
824824
shuffle=False,
825825
num_workers=self.config.num_workers,
826826
pin_memory=self.config.pin_memory,
827-
**self.config.dataloader_kwargs
827+
**self.config.dataloader_kwargs,
828828
)
829829

830830
def _prepare_inference_data(self, df: DataFrame) -> DataFrame:
@@ -867,7 +867,7 @@ def prepare_inference_dataloader(
867867
batch_size or self.batch_size,
868868
shuffle=False,
869869
num_workers=self.config.num_workers,
870-
**self.config.dataloader_kwargs
870+
**self.config.dataloader_kwargs,
871871
)
872872

873873
def save_dataloader(self, path: Union[str, Path]) -> None:

0 commit comments

Comments
 (0)