@@ -1,6 +1,4 @@
-import os
-import uuid
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np

@@ -13,11 +11,13 @@
     TASK_TYPES_TO_STRING,
 )
 from autoPyTorch.data.tabular_validator import TabularInputValidator
+from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
 from autoPyTorch.datasets.resampling_strategy import (
     CrossValTypes,
     HoldoutValTypes,
 )
 from autoPyTorch.datasets.tabular_dataset import TabularDataset
+from autoPyTorch.evaluation.utils import DisableFileOutputParameters
 from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
 from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates

@@ -54,13 +54,16 @@ class TabularClassificationTask(BaseTask):
         delete_tmp_folder_after_terminate (bool):
             Determines whether to delete the temporary directory,
             when finished
-        include_components (Optional[Dict]):
-            If None, all possible components are used.
-            Otherwise specifies set of components to use.
-        exclude_components (Optional[Dict]):
-            If None, all possible components are used.
-            Otherwise specifies set of components not to use.
-            Incompatible with include components.
+        include_components (Optional[Dict[str, Any]]):
+            Dictionary containing components to include. Key is the node
+            name and Value is an Iterable of the names of the components
+            to include. Only these components will be present in the
+            search space.
+        exclude_components (Optional[Dict[str, Any]]):
+            Dictionary containing components to exclude. Key is the node
+            name and Value is an Iterable of the names of the components
+            to exclude. All except these components will be present in
+            the search space.
         search_space_updates (Optional[HyperparameterSearchSpaceUpdates]):
             search space updates that can be used to modify the search
             space of particular components or choice modules of the pipeline
@@ -78,8 +81,8 @@ def __init__(
         output_directory: Optional[str] = None,
         delete_tmp_folder_after_terminate: bool = True,
         delete_output_folder_after_terminate: bool = True,
-        include_components: Optional[Dict] = None,
-        exclude_components: Optional[Dict] = None,
+        include_components: Optional[Dict[str, Any]] = None,
+        exclude_components: Optional[Dict[str, Any]] = None,
         resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         backend: Optional[Backend] = None,
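
A usage sketch of the tightened constructor typing. The node and component names below (`network_backbone`, `MLPBackbone`, `ResNetBackbone`) are illustrative; valid keys and values depend on the pipeline nodes and components available in the installation:

    from autoPyTorch.api.tabular_classification import TabularClassificationTask

    # Keep only MLP backbones in the search space...
    api = TabularClassificationTask(
        include_components={'network_backbone': ['MLPBackbone']},
    )

    # ...or search everything except ResNet backbones. Per the original
    # docstring, include and exclude must not target the same components.
    api = TabularClassificationTask(
        exclude_components={'network_backbone': ['ResNetBackbone']},
    )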
@@ -106,18 +109,109 @@ def __init__(
             task_type=TASK_TYPES_TO_STRING[TABULAR_CLASSIFICATION],
         )

-    def build_pipeline(self, dataset_properties: Dict[str, Any]) -> TabularClassificationPipeline:
+    def build_pipeline(
+        self,
+        dataset_properties: Dict[str, BaseDatasetPropertiesType],
+        include_components: Optional[Dict[str, Any]] = None,
+        exclude_components: Optional[Dict[str, Any]] = None,
+        search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
+    ) -> TabularClassificationPipeline:
         """
-        Build pipeline according to current task and for the passed dataset properties
+        Build pipeline according to current task
+        and for the passed dataset properties

         Args:
-            dataset_properties (Dict[str,Any])
+            dataset_properties (Dict[str, BaseDatasetPropertiesType]):
+                Characteristics of the dataset to guide the pipeline
+                choices of components
+            include_components (Optional[Dict[str, Any]]):
+                Dictionary containing components to include. Key is the node
+                name and Value is an Iterable of the names of the components
+                to include. Only these components will be present in the
+                search space.
+            exclude_components (Optional[Dict[str, Any]]):
+                Dictionary containing components to exclude. Key is the node
+                name and Value is an Iterable of the names of the components
+                to exclude. All except these components will be present in
+                the search space.
+            search_space_updates (Optional[HyperparameterSearchSpaceUpdates]):
+                Search space updates that can be used to modify the search
+                space of particular components or choice modules of the pipeline

         Returns:
-            TabularClassificationPipeline:
-                Pipeline compatible with the given dataset properties.
+            TabularClassificationPipeline:
+                Pipeline compatible with the given dataset properties.
+        """
+        return TabularClassificationPipeline(dataset_properties=dataset_properties,
+                                             include=include_components,
+                                             exclude=exclude_components,
+                                             search_space_updates=search_space_updates)
+
+    def _get_dataset_input_validator(
+        self,
+        X_train: Union[List, pd.DataFrame, np.ndarray],
+        y_train: Union[List, pd.DataFrame, np.ndarray],
+        X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
+        y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
+        resampling_strategy: Optional[Union[CrossValTypes, HoldoutValTypes]] = None,
+        resampling_strategy_args: Optional[Dict[str, Any]] = None,
+        dataset_name: Optional[str] = None,
+    ) -> Tuple[TabularDataset, TabularInputValidator]:
         """
-        return TabularClassificationPipeline(dataset_properties=dataset_properties)
+        Returns an object of `TabularDataset` and an object of
+        `TabularInputValidator` according to the current task.
+
+        Args:
+            X_train (Union[List, pd.DataFrame, np.ndarray]):
+                Training feature set.
+            y_train (Union[List, pd.DataFrame, np.ndarray]):
+                Training target set.
+            X_test (Optional[Union[List, pd.DataFrame, np.ndarray]]):
+                Testing feature set.
+            y_test (Optional[Union[List, pd.DataFrame, np.ndarray]]):
+                Testing target set.
+            resampling_strategy (Optional[Union[CrossValTypes, HoldoutValTypes]]):
+                Strategy to split the training data. If None, uses
+                HoldoutValTypes.holdout_validation.
+            resampling_strategy_args (Optional[Dict[str, Any]]):
+                Arguments required for the chosen resampling strategy. If None, uses
+                the default values provided in DEFAULT_RESAMPLING_PARAMETERS
+                in ``datasets/resampling_strategy.py``.
+            dataset_name (Optional[str]):
+                Name of the dataset, used as experiment name.
+        Returns:
+            TabularDataset:
+                the dataset object.
+            TabularInputValidator:
+                the input validator fitted on the data.
+        """
+
+        resampling_strategy = resampling_strategy if resampling_strategy is not None else self.resampling_strategy
+        resampling_strategy_args = resampling_strategy_args if resampling_strategy_args is not None else \
+            self.resampling_strategy_args
+
+        # Create a validator object to make sure that the data provided by
+        # the user matches the autoPyTorch requirements
+        InputValidator = TabularInputValidator(
+            is_classification=True,
+            logger_port=self._logger_port,
+        )
+
+        # Fit an input validator to check the provided data
+        # Also, an encoder is fit to both train and test data,
+        # to prevent unseen categories during inference
+        InputValidator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)
+
+        dataset = TabularDataset(
+            X=X_train, Y=y_train,
+            X_test=X_test, Y_test=y_test,
+            validator=InputValidator,
+            resampling_strategy=resampling_strategy,
+            resampling_strategy_args=resampling_strategy_args,
+            dataset_name=dataset_name
+        )
+
+        return dataset, InputValidator

     def search(
         self,
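
The refactor moves validator fitting and dataset construction out of search() and into the helper above. A minimal sketch of the resulting flow; note that the helper is internal (leading underscore), so this is illustrative rather than public API, and the random arrays stand in for a real dataset:

    import numpy as np

    api = TabularClassificationTask()
    X = np.random.rand(200, 4)
    y = np.random.randint(0, 2, size=200)

    # One call now fits the TabularInputValidator and builds the TabularDataset,
    # falling back to the strategy configured on the estimator when None is passed.
    dataset, validator = api._get_dataset_input_validator(X_train=X, y_train=y)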
@@ -138,7 +232,7 @@ def search(
         get_smac_object_callback: Optional[Callable] = None,
         all_supported_metrics: bool = True,
         precision: int = 32,
-        disable_file_output: List = [],
+        disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
         load_models: bool = True,
         portfolio_selection: Optional[str] = None,
     ) -> 'BaseTask':
@@ -237,10 +331,10 @@ def search(
             precision (int: default=32):
                 Numeric precision used when loading ensemble data.
                 Can be either '16', '32' or '64'.
-            disable_file_output (Union[bool, List]):
-                If True, disable model and prediction output.
-                Can also be used as a list to pass more fine-grained
-                information on what to save. Allowed elements in the list are:
+            disable_file_output (Optional[List[Union[str, DisableFileOutputParameters]]]):
+                Used as a list to pass fine-grained information on what to
+                save. Each element must be a member of `DisableFileOutputParameters`.
+                Allowed elements in the list are:

                 + `y_optimization`:
                     do not save the predictions for the optimization set,
@@ -253,6 +347,9 @@ def search(
                     pipelines fit on each fold.
                 + `y_test`:
                     do not save the predictions for the test set.
+                + `all`:
+                    do not save any of the above.
+                For more information check `autoPyTorch.evaluation.utils.DisableFileOutputParameters`.
             load_models (bool: default=True):
                 Whether to load the models after fitting AutoPyTorch.
             portfolio_selection (Optional[str]):
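
With the new annotation, disable_file_output takes a list of enum members (or their string names) instead of a bool/list mix. A hedged sketch, continuing the variables from the earlier snippet and assuming the remaining search arguments may stay at their defaults:

    from autoPyTorch.evaluation.utils import DisableFileOutputParameters

    api.search(
        X_train=X, y_train=y,
        optimize_metric='accuracy',
        # Skip saving optimization-set predictions; plain strings matching
        # the member names are also accepted by the new type hint.
        disable_file_output=[DisableFileOutputParameters.y_optimization],
    )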
@@ -269,32 +366,15 @@ def search(
             self

         """
-        if dataset_name is None:
-            dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))

-        # we have to create a logger for at this point for the validator
-        self._logger = self._get_logger(dataset_name)
-
-        # Create a validator object to make sure that the data provided by
-        # the user matches the autopytorch requirements
-        self.InputValidator = TabularInputValidator(
-            is_classification=True,
-            logger_port=self._logger_port,
-        )
-
-        # Fit a input validator to check the provided data
-        # Also, an encoder is fit to both train and test data,
-        # to prevent unseen categories during inference
-        self.InputValidator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)
-
-        self.dataset = TabularDataset(
-            X=X_train, Y=y_train,
-            X_test=X_test, Y_test=y_test,
-            validator=self.InputValidator,
-            dataset_name=dataset_name,
+        self.dataset, self.InputValidator = self._get_dataset_input_validator(
+            X_train=X_train,
+            y_train=y_train,
+            X_test=X_test,
+            y_test=y_test,
             resampling_strategy=self.resampling_strategy,
             resampling_strategy_args=self.resampling_strategy_args,
-        )
+            dataset_name=dataset_name)

         return self._search(
             dataset=self.dataset,
@@ -333,7 +413,7 @@ def predict(
         """
         if self.InputValidator is None or not self.InputValidator._is_fitted:
             raise ValueError("predict() is only supported after calling search. Kindly call first "
-                             "the estimator fit() method.")
+                             "the estimator search() method.")

         X_test = self.InputValidator.feature_validator.transform(X_test)
         predicted_probabilities = super().predict(X_test, batch_size=batch_size,
@@ -353,6 +433,6 @@ def predict_proba(self,
                       batch_size: Optional[int] = None, n_jobs: int = 1) -> np.ndarray:
         if self.InputValidator is None or not self.InputValidator._is_fitted:
             raise ValueError("predict() is only supported after calling search. Kindly call first "
-                             "the estimator fit() method.")
+                             "the estimator search() method.")
         X_test = self.InputValidator.feature_validator.transform(X_test)
         return super().predict(X_test, batch_size=batch_size, n_jobs=n_jobs)
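
Both corrected error messages now point users at search(), so the expected call order is as follows (a sketch continuing the variables from the earlier snippets; the test array is illustrative):

    X_test = np.random.rand(50, 4)

    api.search(X_train=X, y_train=y, optimize_metric='accuracy')
    y_pred = api.predict(X_test)        # class labels
    y_prob = api.predict_proba(X_test)  # class probabilities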