Skip to content

Commit 8c9b895

Browse files
committed
[fix] Bring back the data generator shuffle
1 parent 910e7d4 commit 8c9b895

File tree

4 files changed

+18
-14
lines changed

4 files changed

+18
-14
lines changed

autoPyTorch/datasets/base_dataset.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def __init__(
7171
test_tensors: Optional[BaseDatasetInputType] = None,
7272
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
7373
resampling_strategy_args: Optional[Dict[str, Any]] = None,
74+
shuffle: Optional[bool] = True,
7475
seed: Optional[int] = 42,
7576
train_transforms: Optional[torchvision.transforms.Compose] = None,
7677
val_transforms: Optional[torchvision.transforms.Compose] = None,
@@ -91,7 +92,7 @@ def __init__(
9192
resampling_strategy_args (Optional[Dict[str, Any]]):
9293
arguments required for the chosen resampling strategy.
9394
The details are provided in autoPyTorch/datasets/resampling_strategy.py
94-
shuffle: Whether to shuffle the data when performing splits
95+
shuffle: Whether to shuffle the data before performing splits
9596
seed (int), (default=42): seed to be used for reproducibility.
9697
train_transforms (Optional[torchvision.transforms.Compose]):
9798
Additional Transforms to be applied to the training data
@@ -107,12 +108,14 @@ def __init__(
107108
type_check(train_tensors, val_tensors)
108109
self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
109110
self.random_state = np.random.RandomState(seed=seed)
111+
self.shuffle = shuffle
112+
110113
self.resampling_strategy = resampling_strategy
111114
self.resampling_strategy_args: Dict[str, Any] = {}
112115
if resampling_strategy_args is not None:
113116
self.resampling_strategy_args = resampling_strategy_args
114117

115-
self.shuffle = self.resampling_strategy_args.get('shuffle', False)
118+
self.shuffle_split = self.resampling_strategy_args.get('shuffle', False)
116119
self.is_stratify = self.resampling_strategy_args.get('stratify', False)
117120

118121
self.task_type: Optional[str] = None
@@ -195,7 +198,7 @@ def __len__(self) -> int:
195198
return self.train_tensors[0].shape[0]
196199

197200
def _get_indices(self) -> np.ndarray:
198-
return np.arange(len(self))
201+
return self.random_state.permutation(len(self)) if self.shuffle else np.arange(len(self))
199202

200203
def _process_resampling_strategy_args(self) -> None:
201204
if not any(isinstance(self.resampling_strategy, val_type)
@@ -238,7 +241,7 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]
238241
return self.resampling_strategy(
239242
random_state=self.random_state,
240243
val_share=val_share,
241-
shuffle=self.shuffle,
244+
shuffle=self.shuffle_split,
242245
indices=self._get_indices(),
243246
labels_to_stratify=labels_to_stratify
244247
)
@@ -248,7 +251,7 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]
248251
return self.resampling_strategy(
249252
random_state=self.random_state,
250253
num_splits=num_splits,
251-
shuffle=self.shuffle,
254+
shuffle=self.shuffle_split,
252255
indices=self._get_indices(),
253256
labels_to_stratify=labels_to_stratify
254257
)

autoPyTorch/datasets/image_dataset.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class ImageDataset(BaseDataset):
4545
resampling_strategy_args (Optional[Dict[str, Any]]):
4646
arguments required for the chosen resampling strategy.
4747
The details are provided in autoPyTorch/datasets/resampling_strategy.py
48-
shuffle: Whether to shuffle the data when performing splits
48+
shuffle: Whether to shuffle the data before performing splits
4949
seed (int), (default=42): seed to be used for reproducibility.
5050
train_transforms (Optional[torchvision.transforms.Compose]):
5151
Additional Transforms to be applied to the training data
@@ -58,6 +58,7 @@ def __init__(self,
5858
test: Optional[IMAGE_DATASET_INPUT] = None,
5959
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
6060
resampling_strategy_args: Optional[Dict[str, Any]] = None,
61+
shuffle: Optional[bool] = True,
6162
seed: Optional[int] = 42,
6263
train_transforms: Optional[torchvision.transforms.Compose] = None,
6364
val_transforms: Optional[torchvision.transforms.Compose] = None,
@@ -70,7 +71,7 @@ def __init__(self,
7071
test = _create_image_dataset(data=test)
7172
self.mean, self.std = _calc_mean_std(train=train)
7273

73-
super().__init__(train_tensors=train, val_tensors=val, test_tensors=test,
74+
super().__init__(train_tensors=train, val_tensors=val, test_tensors=test, shuffle=shuffle,
7475
resampling_strategy=resampling_strategy, resampling_strategy_args=resampling_strategy_args,
7576
seed=seed,
7677
train_transforms=train_transforms,

autoPyTorch/datasets/tabular_dataset.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class TabularDataset(BaseDataset):
5050
resampling_strategy_args (Optional[Dict[str, Any]]):
5151
arguments required for the chosen resampling strategy.
5252
The details are provided in autoPyTorch/datasets/resampling_strategy.py
53-
shuffle: Whether to shuffle the data when performing splits
53+
shuffle: Whether to shuffle the data before performing splits
5454
seed (int), (default=42): seed to be used for reproducibility.
5555
train_transforms (Optional[torchvision.transforms.Compose]):
5656
Additional Transforms to be applied to the training data.
@@ -68,6 +68,7 @@ def __init__(self,
6868
Y_test: Optional[Union[np.ndarray, pd.DataFrame]] = None,
6969
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
7070
resampling_strategy_args: Optional[Dict[str, Any]] = None,
71+
shuffle: Optional[bool] = True,
7172
seed: Optional[int] = 42,
7273
train_transforms: Optional[torchvision.transforms.Compose] = None,
7374
val_transforms: Optional[torchvision.transforms.Compose] = None,
@@ -90,7 +91,7 @@ def __init__(self,
9091
self.num_features = validator.feature_validator.num_features
9192
self.categories = validator.feature_validator.categories
9293

93-
super().__init__(train_tensors=(X, Y), test_tensors=(X_test, Y_test),
94+
super().__init__(train_tensors=(X, Y), test_tensors=(X_test, Y_test), shuffle=shuffle,
9495
resampling_strategy=resampling_strategy,
9596
resampling_strategy_args=resampling_strategy_args,
9697
seed=seed, train_transforms=train_transforms,

autoPyTorch/datasets/time_series_dataset.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def __init__(self,
4141
val: Optional[TIME_SERIES_FORECASTING_INPUT] = None,
4242
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
4343
resampling_strategy_args: Optional[Dict[str, Any]] = None,
44+
shuffle: Optional[bool] = False,
4445
seed: Optional[int] = 42,
4546
train_transforms: Optional[torchvision.transforms.Compose] = None,
4647
val_transforms: Optional[torchvision.transforms.Compose] = None,
@@ -68,7 +69,7 @@ def __init__(self,
6869
target_variables=target_variables,
6970
sequence_length=sequence_length,
7071
n_steps=n_steps)
71-
super().__init__(train_tensors=train, val_tensors=val,
72+
super().__init__(train_tensors=train, val_tensors=val, shuffle=shuffle,
7273
resampling_strategy=resampling_strategy, resampling_strategy_args=resampling_strategy_args,
7374
seed=seed,
7475
train_transforms=train_transforms,
@@ -128,17 +129,15 @@ def __init__(self,
128129
_check_time_series_inputs(train=train,
129130
val=val,
130131
task_type="time_series_classification")
131-
resampling_strategy_args = {'shuffle': True}
132-
super().__init__(train_tensors=train, val_tensors=val, resampling_strategy_args=resampling_strategy_args)
132+
super().__init__(train_tensors=train, val_tensors=val, shuffle=True)
133133

134134

135135
class TimeSeriesRegressionDataset(BaseDataset):
136136
def __init__(self, train: Tuple[np.ndarray, np.ndarray], val: Optional[Tuple[np.ndarray, np.ndarray]] = None):
137137
_check_time_series_inputs(train=train,
138138
val=val,
139139
task_type="time_series_regression")
140-
resampling_strategy_args = {'shuffle': True}
141-
super().__init__(train_tensors=train, val_tensors=val, resampling_strategy_args=resampling_strategy_args)
140+
super().__init__(train_tensors=train, val_tensors=val, shuffle=True)
142141

143142

144143
def _check_time_series_inputs(task_type: str,

0 commit comments

Comments
 (0)