@@ -301,10 +301,14 @@ def _update_config(self, config) -> InferredConfig:
301301 else :
302302 raise ValueError (f"{ config .task } is an unsupported task." )
303303 if self .train is not None :
304+ category_cols = self .train [config .categorical_cols ].select_dtypes (include = 'category' ).columns
305+ self .train [category_cols ] = self .train [category_cols ].astype ('object' )
304306 categorical_cardinality = [
305307 int (x ) + 1 for x in list (self .train [config .categorical_cols ].fillna ("NA" ).nunique ().values )
306308 ]
307309 else :
310+ category_cols = self .train_dataset .data [config .categorical_cols ].select_dtypes (include = 'category' ).columns
311+ self .train_dataset .data [category_cols ] = self .train_dataset .data [category_cols ].astype ('object' )
308312 categorical_cardinality = [
309313 int (x ) + 1 for x in list (self .train_dataset .data [config .categorical_cols ].nunique ().values )
310314 ]
@@ -805,6 +809,7 @@ def train_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
805809 num_workers = self .config .num_workers ,
806810 sampler = self .train_sampler ,
807811 pin_memory = self .config .pin_memory ,
812+ ** self .config .dataloader_kwargs ,
808813 )
809814
810815 def val_dataloader (self , batch_size : Optional [int ] = None ) -> DataLoader :
@@ -823,6 +828,7 @@ def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
823828 shuffle = False ,
824829 num_workers = self .config .num_workers ,
825830 pin_memory = self .config .pin_memory ,
831+ ** self .config .dataloader_kwargs ,
826832 )
827833
828834 def _prepare_inference_data (self , df : DataFrame ) -> DataFrame :
@@ -865,6 +871,7 @@ def prepare_inference_dataloader(
865871 batch_size or self .batch_size ,
866872 shuffle = False ,
867873 num_workers = self .config .num_workers ,
874+ ** self .config .dataloader_kwargs ,
868875 )
869876
870877 def save_dataloader (self , path : Union [str , Path ]) -> None :
0 commit comments