2727
2828import pandas as pd
2929
30- from smac .runhistory .runhistory import DataOrigin , RunHistory , RunInfo , RunValue
30+ from smac .runhistory .runhistory import DataOrigin , RunHistory
3131from smac .stats .stats import Stats
3232from smac .tae import StatusType
3333
@@ -238,7 +238,7 @@ def __init__(
238238 " HyperparameterSearchSpaceUpdates got {}" .format (type (self .search_space_updates )))
239239
240240 @abstractmethod
241- def build_pipeline (self , dataset_properties : Dict [str , Any ]) -> BasePipeline :
241+ def build_pipeline (self , dataset_properties : Dict [str , BaseDatasetPropertiesType ]) -> BasePipeline :
242242 """
243243 Build pipeline according to current task
244244 and for the passed dataset properties
@@ -486,11 +486,16 @@ def _load_models(self) -> bool:
486486 raise ValueError ("Resampling strategy is needed to determine what models to load" )
487487 self .ensemble_ = self ._backend .load_ensemble (self .seed )
488488
489- if isinstance (self ._disable_file_output , List ):
490- disabled_file_outputs = self ._disable_file_output
489+ # TODO: remove this code after `fit_pipeline` is rebased.
490+ if hasattr (self , '_disable_file_output' ):
491+ if isinstance (self ._disable_file_output , List ):
492+ disabled_file_outputs = self ._disable_file_output
493+ disable_file_output = False
494+ elif isinstance (self ._disable_file_output , bool ):
495+ disable_file_output = self ._disable_file_output
496+ disabled_file_outputs = []
497+ else :
491498 disable_file_output = False
492- elif isinstance (self ._disable_file_output , bool ):
493- disable_file_output = self ._disable_file_output
494499 disabled_file_outputs = []
495500
496501 # If no ensemble is loaded, try to get the best performing model
@@ -794,18 +799,15 @@ def run_traditional_ml(
794799 learning algorithm runs over the time limit.
795800 """
796801 assert self ._logger is not None # for mypy compliancy
797- if STRING_TO_TASK_TYPES [self .task_type ] in REGRESSION_TASKS :
798- self ._logger .warning ("Traditional Pipeline is not enabled for regression. Skipping..." )
799- else :
800- traditional_task_name = 'runTraditional'
801- self ._stopwatch .start_task (traditional_task_name )
802- elapsed_time = self ._stopwatch .wall_elapsed (current_task_name )
803- time_for_traditional = int (runtime_limit - elapsed_time )
804- self ._do_traditional_prediction (
805- func_eval_time_limit_secs = func_eval_time_limit_secs ,
806- time_left = time_for_traditional ,
807- )
808- self ._stopwatch .stop_task (traditional_task_name )
802+ traditional_task_name = 'runTraditional'
803+ self ._stopwatch .start_task (traditional_task_name )
804+ elapsed_time = self ._stopwatch .wall_elapsed (current_task_name )
805+ time_for_traditional = int (runtime_limit - elapsed_time )
806+ self ._do_traditional_prediction (
807+ func_eval_time_limit_secs = func_eval_time_limit_secs ,
808+ time_left = time_for_traditional ,
809+ )
810+ self ._stopwatch .stop_task (traditional_task_name )
809811
810812 def _search (
811813 self ,
@@ -1165,22 +1167,7 @@ def _search(
11651167 self ._logger .info ("Starting Shutdown" )
11661168
11671169 if proc_ensemble is not None :
1168- self ._results_manager .ensemble_performance_history = list (proc_ensemble .history )
1169-
1170- if len (proc_ensemble .futures ) > 0 :
1171- # Also add ensemble runs that did not finish within smac time
1172- # and add them into the ensemble history
1173- self ._logger .info ("Ensemble script still running, waiting for it to finish." )
1174- result = proc_ensemble .futures .pop ().result ()
1175- if result :
1176- ensemble_history , _ , _ , _ = result
1177- self ._results_manager .ensemble_performance_history .extend (ensemble_history )
1178- self ._logger .info ("Ensemble script finished, continue shutdown." )
1179-
1180- # save the ensemble performance history file
1181- if len (self .ensemble_performance_history ) > 0 :
1182- pd .DataFrame (self .ensemble_performance_history ).to_json (
1183- os .path .join (self ._backend .internals_directory , 'ensemble_history.json' ))
1170+ self ._collect_results_ensemble (proc_ensemble )
11841171
11851172 if load_models :
11861173 self ._logger .info ("Loading models..." )
@@ -1321,7 +1308,7 @@ def fit(self,
13211308 exclude = self .exclude_components ,
13221309 search_space_updates = self .search_space_updates )
13231310 dataset_properties = dataset .get_dataset_properties (dataset_requirements )
1324- self ._backend .replace_datamanager (dataset )
1311+ self ._backend .save_datamanager (dataset )
13251312
13261313 # build pipeline
13271314 pipeline = self .build_pipeline (dataset_properties )
@@ -1339,7 +1326,6 @@ def fit(self,
13391326 self ._clean_logger ()
13401327 return pipeline
13411328
1342-
13431329 def fit_ensemble (
13441330 self ,
13451331 optimize_metric : Optional [str ] = None ,
@@ -1418,7 +1404,7 @@ def fit_ensemble(
14181404 ensemble_fit_task_name = 'EnsembleFit'
14191405 self ._stopwatch .start_task (ensemble_fit_task_name )
14201406 if enable_traditional_pipeline :
1421- if func_eval_time_limit_secs is None or func_eval_time_limit_secs > time_for_task :
1407+ if func_eval_time_limit_secs > time_for_task :
14221408 self ._logger .warning (
14231409 'Time limit for a single run is higher than total time '
14241410 'limit. Capping the limit for a single run to the total '
@@ -1459,12 +1445,8 @@ def fit_ensemble(
14591445 )
14601446
14611447 manager .build_ensemble (self ._dask_client )
1462- future = manager .futures .pop ()
1463- result = future .result ()
1464- if result is None :
1465- raise ValueError ("Errors occurred while building the ensemble - please"
1466- " check the log file and command line output for error messages." )
1467- self .ensemble_performance_history , _ , _ , _ = result
1448+ if manager is not None :
1449+ self ._collect_results_ensemble (manager )
14681450
14691451 if load_models :
14701452 self ._load_models ()
@@ -1542,6 +1524,31 @@ def _init_ensemble_builder(
15421524
15431525 return proc_ensemble
15441526
1527+ def _collect_results_ensemble (  # Copy the ensemble builder's history into the results manager and save it as JSON.
1528+ self ,
1529+ manager : EnsembleBuilderManager
1530+ ) -> None :
1531+
1532+ if self ._logger is None :
1533+ raise ValueError ("logger should be initialized to fit ensemble" )
1534+
1535+ self ._results_manager .ensemble_performance_history = list (manager .history )  # list() builds a new list rather than storing manager.history itself
1536+
1537+ if len (manager .futures ) > 0 :
1538+ # An ensemble run is still outstanding: take one pending future via pop()
1539+ # and, if it yields a truthy result, append its history entries as well.
1540+ self ._logger .info ("Ensemble script still running, waiting for it to finish." )
1541+ result = manager .futures .pop ().result ()  # presumably blocks until that run finishes; only one future is consumed
1542+ if result :
1543+ ensemble_history , _ , _ , _ = result  # result unpacks into four values; only the first is used
1544+ self ._results_manager .ensemble_performance_history .extend (ensemble_history )
1545+ self ._logger .info ("Ensemble script finished, continue shutdown." )  # NOTE(review): message mentions shutdown, but fit_ensemble also calls this helper
1546+
1547+ # Write the history to <internals_directory>/ensemble_history.json, but only when it is non-empty.
1548+ if len (self .ensemble_performance_history ) > 0 :  # NOTE(review): read via self, written via self._results_manager above; presumably an alias, confirm
1549+ pd .DataFrame (self .ensemble_performance_history ).to_json (
1550+ os .path .join (self ._backend .internals_directory , 'ensemble_history.json' ))
1551+
15451552 def predict (
15461553 self ,
15471554 X_test : np .ndarray ,
0 commit comments