Skip to content

Commit 9a1f71d

Browse files
Add dataloader_kwargs support in DataConfig
1 parent a890dda commit 9a1f71d

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

src/pytorch_tabular/config/config.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ class DataConfig:
9595
9696
handle_missing_values (bool): Whether to handle missing values in categorical columns as
9797
unknown
98-
98+
99+
dataloader_kwargs (Dict[str, Any]): Additional kwargs to be passed to PyTorch DataLoader. See
100+
https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
99101
"""
100102

101103
target: Optional[List[str]] = field(
@@ -175,6 +177,11 @@ class DataConfig:
175177
default=True,
176178
metadata={"help": "Whether or not to handle missing values in categorical columns as unknown"},
177179
)
180+
181+
dataloader_kwargs: Dict[str, Any] = field(
182+
default_factory=dict,
183+
metadata={"help": "Additional kwargs to be passed to PyTorch DataLoader."},
184+
)
178185

179186
def __post_init__(self):
180187
assert (

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,7 @@ def train_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
805805
num_workers=self.config.num_workers,
806806
sampler=self.train_sampler,
807807
pin_memory=self.config.pin_memory,
808+
**self.config.dataloader_kwargs
808809
)
809810

810811
def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
@@ -823,6 +824,7 @@ def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
823824
shuffle=False,
824825
num_workers=self.config.num_workers,
825826
pin_memory=self.config.pin_memory,
827+
**self.config.dataloader_kwargs
826828
)
827829

828830
def _prepare_inference_data(self, df: DataFrame) -> DataFrame:
@@ -865,6 +867,7 @@ def prepare_inference_dataloader(
865867
batch_size or self.batch_size,
866868
shuffle=False,
867869
num_workers=self.config.num_workers,
870+
**self.config.dataloader_kwargs
868871
)
869872

870873
def save_dataloader(self, path: Union[str, Path]) -> None:

0 commit comments

Comments
 (0)