88from autoPyTorch .datasets .resampling_strategy import (
99 CrossValTypes ,
1010 HoldoutValTypes ,
11- get_cross_validators ,
12- get_holdout_validators
11+ CrossValFuncs ,
12+ HoldOutFuncs
1313)
1414
1515TIME_SERIES_FORECASTING_INPUT = Tuple [np .ndarray , np .ndarray ] # currently only numpy arrays are supported
@@ -60,8 +60,8 @@ def __init__(self,
6060 train_transforms = train_transforms ,
6161 val_transforms = val_transforms ,
6262 )
63- self .cross_validators = get_cross_validators (CrossValTypes .time_series_cross_validation )
64- self .holdout_validators = get_holdout_validators (HoldoutValTypes .holdout_validation )
63+ self .cross_validators = CrossValFuncs . get_cross_validators (CrossValTypes .time_series_cross_validation )
64+ self .holdout_validators = HoldOutFuncs . get_holdout_validators (HoldoutValTypes .holdout_validation )
6565
6666
6767def _check_time_series_forecasting_inputs (target_variables : Tuple [int ],
@@ -117,13 +117,13 @@ def __init__(self,
117117 val = val ,
118118 task_type = "time_series_classification" )
119119 super ().__init__ (train_tensors = train , val_tensors = val , shuffle = True )
120- self .cross_validators = get_cross_validators (
120+ self .cross_validators = CrossValFuncs . get_cross_validators (
121121 CrossValTypes .stratified_k_fold_cross_validation ,
122122 CrossValTypes .k_fold_cross_validation ,
123123 CrossValTypes .shuffle_split_cross_validation ,
124124 CrossValTypes .stratified_shuffle_split_cross_validation
125125 )
126- self .holdout_validators = get_holdout_validators (
126+ self .holdout_validators = HoldOutFuncs . get_holdout_validators (
127127 HoldoutValTypes .holdout_validation ,
128128 HoldoutValTypes .stratified_holdout_validation
129129 )
@@ -135,11 +135,11 @@ def __init__(self, train: Tuple[np.ndarray, np.ndarray], val: Optional[Tuple[np.
135135 val = val ,
136136 task_type = "time_series_regression" )
137137 super ().__init__ (train_tensors = train , val_tensors = val , shuffle = True )
138- self .cross_validators = get_cross_validators (
138+ self .cross_validators = CrossValFuncs . get_cross_validators (
139139 CrossValTypes .k_fold_cross_validation ,
140140 CrossValTypes .shuffle_split_cross_validation
141141 )
142- self .holdout_validators = get_holdout_validators (
142+ self .holdout_validators = HoldOutFuncs . get_holdout_validators (
143143 HoldoutValTypes .holdout_validation
144144 )
145145
0 commit comments