99import time
1010import typing
1111import unittest .mock
12+ import uuid
1213import warnings
1314from abc import abstractmethod
1415from typing import Any , Callable , Dict , List , Optional , Union , cast
@@ -122,21 +123,24 @@ class BaseTask:
122123 """
123124
124125 def __init__ (
125- self ,
126- seed : int = 1 ,
127- n_jobs : int = 1 ,
128- logging_config : Optional [Dict ] = None ,
129- ensemble_size : int = 50 ,
130- ensemble_nbest : int = 50 ,
131- max_models_on_disc : int = 50 ,
132- temporary_directory : Optional [str ] = None ,
133- output_directory : Optional [str ] = None ,
134- delete_tmp_folder_after_terminate : bool = True ,
135- delete_output_folder_after_terminate : bool = True ,
136- include_components : Optional [Dict ] = None ,
137- exclude_components : Optional [Dict ] = None ,
138- backend : Optional [Backend ] = None ,
139- search_space_updates : Optional [HyperparameterSearchSpaceUpdates ] = None
126+ self ,
127+ seed : int = 1 ,
128+ n_jobs : int = 1 ,
129+ logging_config : Optional [Dict ] = None ,
130+ ensemble_size : int = 50 ,
131+ ensemble_nbest : int = 50 ,
132+ max_models_on_disc : int = 50 ,
133+ temporary_directory : Optional [str ] = None ,
134+ output_directory : Optional [str ] = None ,
135+ delete_tmp_folder_after_terminate : bool = True ,
136+ delete_output_folder_after_terminate : bool = True ,
137+ include_components : Optional [Dict ] = None ,
138+ exclude_components : Optional [Dict ] = None ,
139+ backend : Optional [Backend ] = None ,
140+ resampling_strategy : Union [CrossValTypes , HoldoutValTypes ] = HoldoutValTypes .holdout_validation ,
141+ resampling_strategy_args : Optional [Dict [str , Any ]] = None ,
142+ search_space_updates : Optional [HyperparameterSearchSpaceUpdates ] = None ,
143+ task_type : Optional [str ] = None
140144 ) -> None :
141145 self .seed = seed
142146 self .n_jobs = n_jobs
@@ -157,14 +161,14 @@ def __init__(
157161 delete_tmp_folder_after_terminate = delete_tmp_folder_after_terminate ,
158162 delete_output_folder_after_terminate = delete_output_folder_after_terminate ,
159163 )
164+ self .task_type = task_type
160165 self ._stopwatch = StopWatch ()
161166
162167 self .pipeline_options = replace_string_bool_to_bool (json .load (open (
163168 os .path .join (os .path .dirname (__file__ ), '../configs/default_pipeline_options.json' ))))
164169
165170 self .search_space : Optional [ConfigurationSpace ] = None
166171 self ._dataset_requirements : Optional [List [FitRequirement ]] = None
167- self .task_type : Optional [str ] = None
168172 self ._metric : Optional [autoPyTorchMetric ] = None
169173 self ._logger : Optional [PicklableClientLogger ] = None
170174 self .run_history : Optional [RunHistory ] = None
@@ -176,7 +180,8 @@ def __init__(
176180 self ._logger_port = logging .handlers .DEFAULT_TCP_LOGGING_PORT
177181
178182 # Store the resampling strategy from the dataset, to load models as needed
179- self .resampling_strategy = None # type: Optional[Union[CrossValTypes, HoldoutValTypes]]
183+ self .resampling_strategy = resampling_strategy
184+ self .resampling_strategy_args = resampling_strategy_args
180185
181186 self .stop_logging_server = None # type: Optional[multiprocessing.synchronize.Event]
182187
@@ -287,7 +292,7 @@ def _get_logger(self, name: str) -> PicklableClientLogger:
287292 output_dir = self ._backend .temporary_directory ,
288293 )
289294
290- # As Auto-sklearn works with distributed process,
295+ # As AutoPyTorch works with distributed processes,
291296 # we implement a logger server that can receive tcp
292297 # pickled messages. They are unpickled and processed locally
293298 # under the above logging configuration setting
@@ -398,20 +403,16 @@ def _close_dask_client(self) -> None:
398403 self ._is_dask_client_internally_created = False
399404 del self ._is_dask_client_internally_created
400405
401- def _load_models (self , resampling_strategy : Optional [Union [CrossValTypes , HoldoutValTypes ]]
402- ) -> bool :
406+ def _load_models (self ) -> bool :
403407
404408 """
405409 Loads the models saved in the temporary directory
406410 during the smac run and the final ensemble created
407- Args:
408- resampling_strategy (Union[CrossValTypes, HoldoutValTypes]): resampling strategy used to split the data
409- and to validate the performance of a candidate pipeline
410411
411412 Returns:
412413 None
413414 """
414- if resampling_strategy is None :
415+ if self . resampling_strategy is None :
415416 raise ValueError ("Resampling strategy is needed to determine what models to load" )
416417 self .ensemble_ = self ._backend .load_ensemble (self .seed )
417418
@@ -422,10 +423,10 @@ def _load_models(self, resampling_strategy: Optional[Union[CrossValTypes, Holdou
422423 if self .ensemble_ :
423424 identifiers = self .ensemble_ .get_selected_model_identifiers ()
424425 self .models_ = self ._backend .load_models_by_identifiers (identifiers )
425- if isinstance (resampling_strategy , CrossValTypes ):
426+ if isinstance (self . resampling_strategy , CrossValTypes ):
426427 self .cv_models_ = self ._backend .load_cv_models_by_identifiers (identifiers )
427428
428- if isinstance (resampling_strategy , CrossValTypes ):
429+ if isinstance (self . resampling_strategy , CrossValTypes ):
429430 if len (self .cv_models_ ) == 0 :
430431 raise ValueError ('No models fitted!' )
431432
@@ -610,10 +611,10 @@ def _do_traditional_prediction(self, num_run: int, time_for_traditional: int) ->
610611 )
611612 return num_run
612613
613- def search (
614+ def _search (
614615 self ,
615- dataset : BaseDataset ,
616616 optimize_metric : str ,
617+ dataset : BaseDataset ,
617618 budget_type : Optional [str ] = None ,
618619 budget : Optional [float ] = None ,
619620 total_walltime_limit : int = 100 ,
@@ -638,6 +639,7 @@ def search(
638639 The argument that will provide the dataset splits. It is
639640 a subclass of the base dataset object which can
640641 generate the splits based on different restrictions.
642+ Providing X_train, y_train and dataset together is not supported.
641643 optimize_metric (str): name of the metric that is used to
642644 evaluate a pipeline.
643645 budget_type (Optional[str]):
@@ -692,6 +694,7 @@ def search(
692694 self
693695
694696 """
697+
695698 if self .task_type != dataset .task_type :
696699 raise ValueError ("Incompatible dataset entered for current task,"
697700 "expected dataset to have task type :{} got "
@@ -705,8 +708,8 @@ def search(
705708 dataset_properties = dataset .get_dataset_properties (dataset_requirements )
706709 self ._stopwatch .start_task (experiment_task_name )
707710 self .dataset_name = dataset .dataset_name
708- self .resampling_strategy = dataset . resampling_strategy
709- self ._logger = self ._get_logger (self .dataset_name )
711+ if self ._logger is None :
712+ self ._logger = self ._get_logger (self .dataset_name )
710713 self ._all_supported_metrics = all_supported_metrics
711714 self ._disable_file_output = disable_file_output
712715 self ._memory_limit = memory_limit
@@ -869,7 +872,7 @@ def search(
869872
870873 if load_models :
871874 self ._logger .info ("Loading models..." )
872- self ._load_models (dataset . resampling_strategy )
875+ self ._load_models ()
873876 self ._logger .info ("Finished loading models..." )
874877
875878 # Clean up the logger
@@ -906,8 +909,11 @@ def refit(
906909 Returns:
907910 self
908911 """
912+ if self .dataset_name is None :
913+ self .dataset_name = str (uuid .uuid1 (clock_seq = os .getpid ()))
909914
910- self ._logger = self ._get_logger (dataset .dataset_name )
915+ if self ._logger is None :
916+ self ._logger = self ._get_logger (self .dataset_name )
911917
912918 dataset_requirements = get_dataset_requirements (
913919 info = self ._get_required_dataset_properties (dataset ))
@@ -927,7 +933,7 @@ def refit(
927933 })
928934 X .update ({** self .pipeline_options , ** budget_config })
929935 if self .models_ is None or len (self .models_ ) == 0 or self .ensemble_ is None :
930- self ._load_models (dataset . resampling_strategy )
936+ self ._load_models ()
931937
932938 # Refit is not applicable when ensemble_size is set to zero.
933939 if self .ensemble_ is None :
@@ -973,7 +979,11 @@ def fit(self,
973979 Returns:
974980 (BasePipeline): fitted pipeline
975981 """
976- self ._logger = self ._get_logger (dataset .dataset_name )
982+ if self .dataset_name is None :
983+ self .dataset_name = str (uuid .uuid1 (clock_seq = os .getpid ()))
984+
985+ if self ._logger is None :
986+ self ._logger = self ._get_logger (self .dataset_name )
977987
978988 # get dataset properties
979989 dataset_requirements = get_dataset_requirements (
@@ -1025,7 +1035,7 @@ def predict(
10251035 if self ._logger is None :
10261036 self ._logger = self ._get_logger ("Predict-Logger" )
10271037
1028- if self .ensemble_ is None and not self ._load_models (self . resampling_strategy ):
1038+ if self .ensemble_ is None and not self ._load_models ():
10291039 raise ValueError ("No ensemble found. Either fit has not yet "
10301040 "been called or no ensemble was fitted" )
10311041
@@ -1084,9 +1094,6 @@ def score(
10841094 Returns:
10851095 Dict[str, float]: Value of the evaluation metric calculated on the test set.
10861096 """
1087- if isinstance (y_test , pd .Series ):
1088- y_test = y_test .to_numpy (dtype = np .float )
1089-
10901097 if self ._metric is None :
10911098 raise ValueError ("No metric found. Either fit/search has not been called yet "
10921099 "or AutoPyTorch failed to infer a metric from the dataset " )
0 commit comments