Skip to content

Commit bef4323

Browse files
committed
[fix] Fix pytest errors
1 parent 93e6862 commit bef4323

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

autoPyTorch/datasets/resampling_strategy.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,11 @@ def holdout_validation(
3232
labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
3333
) -> List[Tuple[np.ndarray, np.ndarray]]:
3434

35+
""" SKLearn requires shuffle=True for stratify """
3536
train, val = train_test_split(
3637
indices, test_size=val_share,
37-
shuffle=shuffle, random_state=random_state,
38+
shuffle=shuffle if labels_to_stratify is None else True,
39+
random_state=random_state,
3840
stratify=labels_to_stratify
3941
)
4042
return [(train, val)]

0 commit comments

Comments
 (0)