@@ -71,6 +71,7 @@ def __init__(
7171 test_tensors : Optional [BaseDatasetInputType ] = None ,
7272 resampling_strategy : Union [CrossValTypes , HoldoutValTypes ] = HoldoutValTypes .holdout_validation ,
7373 resampling_strategy_args : Optional [Dict [str , Any ]] = None ,
74+ shuffle : Optional [bool ] = True ,
7475 seed : Optional [int ] = 42 ,
7576 train_transforms : Optional [torchvision .transforms .Compose ] = None ,
7677 val_transforms : Optional [torchvision .transforms .Compose ] = None ,
@@ -91,7 +92,7 @@ def __init__(
9192 resampling_strategy_args (Optional[Dict[str, Any]]):
9293 arguments required for the chosen resampling strategy.
9394 The details are provided in autoPytorch/datasets/resampling_strategy.py
94- shuffle: Whether to shuffle the data when performing splits
95+ shuffle: Whether to shuffle the data before performing splits
9596 seed (int), (default=1): seed to be used for reproducibility.
9697 train_transforms (Optional[torchvision.transforms.Compose]):
9798 Additional Transforms to be applied to the training data
@@ -107,12 +108,14 @@ def __init__(
107108 type_check (train_tensors , val_tensors )
108109 self .train_tensors , self .val_tensors , self .test_tensors = train_tensors , val_tensors , test_tensors
109110 self .random_state = np .random .RandomState (seed = seed )
111+ self .shuffle = shuffle
112+
110113 self .resampling_strategy = resampling_strategy
111114 self .resampling_strategy_args : Dict [str , Any ] = {}
112115 if resampling_strategy_args is not None :
113116 self .resampling_strategy_args = resampling_strategy_args
114117
115- self .shuffle = self .resampling_strategy_args .get ('shuffle' , False )
118+ self .shuffle_split = self .resampling_strategy_args .get ('shuffle' , False )
116119 self .is_stratify = self .resampling_strategy_args .get ('stratify' , False )
117120
118121 self .task_type : Optional [str ] = None
@@ -195,7 +198,7 @@ def __len__(self) -> int:
195198 return self .train_tensors [0 ].shape [0 ]
196199
197200 def _get_indices (self ) -> np .ndarray :
198- return np .arange (len (self ))
201+ return self . random_state . permutation ( len ( self )) if self . shuffle else np .arange (len (self ))
199202
200203 def _process_resampling_strategy_args (self ) -> None :
201204 if not any (isinstance (self .resampling_strategy , val_type )
@@ -238,7 +241,7 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]
238241 return self .resampling_strategy (
239242 random_state = self .random_state ,
240243 val_share = val_share ,
241- shuffle = self .shuffle ,
244+ shuffle = self .shuffle_split ,
242245 indices = self ._get_indices (),
243246 labels_to_stratify = labels_to_stratify
244247 )
@@ -248,7 +251,7 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]
248251 return self .resampling_strategy (
249252 random_state = self .random_state ,
250253 num_splits = num_splits ,
251- shuffle = self .shuffle ,
254+ shuffle = self .shuffle_split ,
252255 indices = self ._get_indices (),
253256 labels_to_stratify = labels_to_stratify
254257 )
0 commit comments