Skip to content

Commit caa3ea1

Browse files
Add dataloader_kwargs support in DataConfig (#492)
* Add dataloader_kwargs support in DataConfig * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a890dda commit caa3ea1

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

src/pytorch_tabular/config/config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ class DataConfig:
9696
handle_missing_values (bool): Whether to handle missing values in categorical columns as
9797
unknown
9898
99+
dataloader_kwargs (Dict[str, Any]): Additional kwargs to be passed to PyTorch DataLoader. See
100+
https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
101+
99102
"""
100103

101104
target: Optional[List[str]] = field(
@@ -176,6 +179,11 @@ class DataConfig:
176179
metadata={"help": "Whether or not to handle missing values in categorical columns as unknown"},
177180
)
178181

182+
dataloader_kwargs: Dict[str, Any] = field(
183+
default_factory=dict,
184+
metadata={"help": "Additional kwargs to be passed to PyTorch DataLoader."},
185+
)
186+
179187
def __post_init__(self):
180188
assert (
181189
len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,7 @@ def train_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
805805
num_workers=self.config.num_workers,
806806
sampler=self.train_sampler,
807807
pin_memory=self.config.pin_memory,
808+
**self.config.dataloader_kwargs,
808809
)
809810

810811
def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
@@ -823,6 +824,7 @@ def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
823824
shuffle=False,
824825
num_workers=self.config.num_workers,
825826
pin_memory=self.config.pin_memory,
827+
**self.config.dataloader_kwargs,
826828
)
827829

828830
def _prepare_inference_data(self, df: DataFrame) -> DataFrame:
@@ -865,6 +867,7 @@ def prepare_inference_dataloader(
865867
batch_size or self.batch_size,
866868
shuffle=False,
867869
num_workers=self.config.num_workers,
870+
**self.config.dataloader_kwargs,
868871
)
869872

870873
def save_dataloader(self, path: Union[str, Path]) -> None:

0 commit comments

Comments
 (0)