@@ -499,7 +499,7 @@ def _do_dummy_prediction(self) -> None:
499499 memory_limit = int (math .ceil (memory_limit ))
500500
501501 scenario_mock = unittest .mock .Mock ()
502- scenario_mock .wallclock_limit = self ._time_for_task
502+ scenario_mock .wallclock_limit = self ._total_walltime_limit
503503 # This stats object is a hack - maybe the SMAC stats object should
504504 # already be generated here!
505505 stats = Stats (scenario_mock )
@@ -518,7 +518,7 @@ def _do_dummy_prediction(self) -> None:
518518 all_supported_metrics = self ._all_supported_metrics
519519 )
520520
521- status , cost , runtime , additional_info = ta .run (self .num_run , cutoff = self ._time_for_task )
521+ status , cost , runtime , additional_info = ta .run (self .num_run , cutoff = self ._total_walltime_limit )
522522 if status == StatusType .SUCCESS :
523523 self ._logger .info ("Finished creating dummy predictions." )
524524 else :
@@ -552,8 +552,7 @@ def _do_dummy_prediction(self) -> None:
552552 % (str (status ), str (additional_info ))
553553 )
554554
555- def _do_traditional_prediction (self , time_left : int , func_eval_time_limit_secs : int
556- ) -> None :
555+ def _do_traditional_prediction (self , time_left : int ) -> None :
557556 """
558557 Fits traditional machine learning algorithms to the provided dataset, while
559558 complying with time resource allocation.
@@ -596,8 +595,8 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs:
596595
597596 # Only launch a task if there is time
598597 start_time = time .time ()
599- if time_left >= func_eval_time_limit_secs :
600- self ._logger .info (f"{ n_r } : Started fitting { classifier } with cutoff={ func_eval_time_limit_secs } " )
598+ if time_left >= self . _func_eval_time_limit_secs :
599+ self ._logger .info (f"{ n_r } : Started fitting { classifier } with cutoff={ self . _func_eval_time_limit_secs } " )
601600 scenario_mock = unittest .mock .Mock ()
602601 scenario_mock .wallclock_limit = time_left
603602 # This stats object is a hack - maybe the SMAC stats object should
@@ -621,7 +620,7 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs:
621620 classifier ,
622621 self ._dask_client .submit (
623622 ta .run , config = classifier ,
624- cutoff = func_eval_time_limit_secs ,
623+ cutoff = self . _func_eval_time_limit_secs ,
625624 )
626625 ])
627626
@@ -640,7 +639,7 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs:
640639
641640 # How many workers to wait before starting fitting the next iteration
642641 workers_to_wait = 1
643- if n_r >= total_number_classifiers - 1 or time_left <= func_eval_time_limit_secs :
642+ if n_r >= total_number_classifiers - 1 or time_left <= self . _func_eval_time_limit_secs :
644643 # If on the last iteration, flush out all tasks
645644 workers_to_wait = len (dask_futures )
646645
@@ -675,7 +674,7 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs:
675674 time_left -= int (time .time () - start_time )
676675
677676 # Exit if no more time is available for a new classifier
678- if time_left < func_eval_time_limit_secs :
677+ if time_left < self . _func_eval_time_limit_secs :
679678 self ._logger .warning ("Not enough time to fit all traditional machine learning models."
680679 "Please consider increasing the run time to further improve performance." )
681680 break
@@ -686,36 +685,30 @@ def _run_dummy_predictions(self) -> None:
686685 self ._do_dummy_prediction ()
687686 self ._stopwatch .stop_task (dummy_task_name )
688687
689- def _run_traditional_ml (self ,
690- enable_traditional_pipeline : bool ,
691- func_eval_time_limit_secs : Optional [int ] = None ) -> None :
688+ def _run_traditional_ml (self ) -> None :  # takes no arguments; the time limits are read from attributes on self
692689 """We would like to obtain training time for at least 1 Neural network in SMAC"""
690+ assert self ._logger is not None  # the logger has to be set before this method runs
693691
694- if enable_traditional_pipeline :
695- if STRING_TO_TASK_TYPES [self .task_type ] in REGRESSION_TASKS :
696- self ._logger .warning ("Traditional Pipeline is not enabled for regression. Skipping..." )
697- else :
698- traditional_task_name = 'runTraditional'
699- self ._stopwatch .start_task (traditional_task_name )
700- elapsed_time = self ._stopwatch .wall_elapsed (self .dataset_name )
692+ if STRING_TO_TASK_TYPES [self .task_type ] in REGRESSION_TASKS :  # regression tasks only log a warning and skip the fit
693+ self ._logger .warning ("Traditional Pipeline is not enabled for regression. Skipping..." )
694+ else :
695+ traditional_task_name = 'runTraditional'
696+ self ._stopwatch .start_task (traditional_task_name )
697+ elapsed_time = self ._stopwatch .wall_elapsed (self .dataset_name )  # NOTE(review): read for the dataset_name task, not experiment_task_name - confirm this is the intended stopwatch task
701698
702- time_for_traditional = int (
703- self ._time_for_task - elapsed_time - func_eval_time_limit_secs
704- )
705- self .num_run = self ._do_traditional_prediction (
706- func_eval_time_limit_secs = func_eval_time_limit_secs ,
707- time_left = time_for_traditional ,
708- )
709- self ._stopwatch .stop_task (traditional_task_name )
699+ time_for_traditional = int (  # total limit minus elapsed time minus one per-run limit
700+ self ._total_walltime_limit - elapsed_time - self ._func_eval_time_limit_secs
701+ )
702+ self ._do_traditional_prediction (time_left = time_for_traditional )
703+ self ._stopwatch .stop_task (traditional_task_name )
710704
711- def _run_ensemble (self ,
712- dataset : BaseDataset ,
713- optimize_metric : str ,
714- total_walltime_limit : int ,
705+ def _run_ensemble (self , dataset : BaseDataset , optimize_metric : str ,
715706 precision : int ) -> EnsembleBuilderManager :
716707
708+ assert self ._logger is not None
709+
717710 elapsed_time = self ._stopwatch .wall_elapsed (self .dataset_name )
718- time_left_for_ensembles = max (0 , total_walltime_limit - elapsed_time )
711+ time_left_for_ensembles = max (0 , self . _total_walltime_limit - elapsed_time )
719712 proc_ensemble = None
720713 if time_left_for_ensembles <= 0 and self .ensemble_size > 0 :
721714 raise ValueError ("Could not run ensemble builder because there "
@@ -734,25 +727,20 @@ def _run_ensemble(self,
734727 dataset_name = dataset .dataset_name ,
735728 output_type = STRING_TO_OUTPUT_TYPES [dataset .output_type ],
736729 task_type = STRING_TO_TASK_TYPES [self .task_type ],
737- metrics = [self ._metric ],
738- opt_metric = optimize_metric ,
730+ metrics = [self ._metric ], opt_metric = optimize_metric ,
739731 ensemble_size = self .ensemble_size ,
740732 ensemble_nbest = self .ensemble_nbest ,
741733 max_models_on_disc = self .max_models_on_disc ,
742- seed = self .seed ,
743- max_iterations = None ,
744- read_at_most = sys .maxsize ,
745734 ensemble_memory_limit = self ._memory_limit ,
746- random_state = self .seed ,
747- precision = precision ,
748- logger_port = self ._logger_port ,
735+ seed = self . seed , max_iterations = None , random_state = self .seed ,
736+ read_at_most = sys . maxsize , precision = precision ,
737+ logger_port = self ._logger_port
749738 )
750739 self ._stopwatch .stop_task (ensemble_task_name )
751740
752741 return proc_ensemble
753742
754- def _get_budget_config (self ,
755- budget_type : Optional [str ] = None ,
743+ def _get_budget_config (self , budget_type : Optional [str ] = None ,
756744 budget : Optional [float ] = None ) -> Dict [str , Union [float , str ]]:
757745
758746 budget_config : Dict [str , Union [float , str ]] = {}
@@ -764,13 +752,18 @@ def _get_budget_config(self,
764752
765753 return budget_config
766754
767- def _start_smac (self , proc_smac : AutoMLSMBO ):
755+ def _start_smac (self , proc_smac : AutoMLSMBO ) -> None :
756+ assert self ._logger is not None
757+
768758 try :
769759 self .run_history , self .trajectory , budget_type = \
770760 proc_smac .run_smbo ()
771761 trajectory_filename = os .path .join (
772762 self ._backend .get_smac_output_directory_for_run (self .seed ),
773763 'trajectory.json' )
764+
765+ assert self .trajectory is not None
766+
774767 saveable_trajectory = \
775768 [list (entry [:2 ]) + [entry [2 ].get_dictionary ()] + list (entry [3 :])
776769 for entry in self .trajectory ]
@@ -784,20 +777,17 @@ def _start_smac(self, proc_smac: AutoMLSMBO):
784777 except Exception as e :
785778 self ._logger .warning (f"Could not save { trajectory_filename } due to { e } ..." )
786779
787- def _run_smac (self ,
788- dataset : BaseDataset ,
789- proc_ensemble : EnsembleBuilderManager ,
790- total_walltime_limit : int ,
791- budget_type : Optional [str ] = None ,
792- budget : Optional [float ] = None ,
793- func_eval_time_limit_secs : Optional [int ] = None ,
780+ def _run_smac (self , dataset : BaseDataset , proc_ensemble : EnsembleBuilderManager ,
781+ budget_type : Optional [str ] = None , budget : Optional [float ] = None ,
794782 get_smac_object_callback : Optional [Callable ] = None ,
795783 smac_scenario_args : Optional [Dict [str , Any ]] = None ) -> None :
796784
785+ assert self ._logger is not None
786+
797787 smac_task_name = 'runSMAC'
798788 self ._stopwatch .start_task (smac_task_name )
799789 elapsed_time = self ._stopwatch .wall_elapsed (self .experiment_task_name )
800- time_left_for_smac = max (0 , total_walltime_limit - elapsed_time )
790+ time_left_for_smac = max (0 , self . _total_walltime_limit - elapsed_time )
801791
802792 self ._logger .info (f"Run SMAC with { time_left_for_smac :.2f} sec time left" )
803793 if time_left_for_smac <= 0 :
@@ -808,14 +798,12 @@ def _run_smac(self,
808798 config_space = self .search_space ,
809799 dataset_name = dataset .dataset_name ,
810800 backend = self ._backend ,
811- total_walltime_limit = total_walltime_limit ,
812- func_eval_time_limit_secs = func_eval_time_limit_secs ,
801+ total_walltime_limit = self . _total_walltime_limit ,
802+ func_eval_time_limit_secs = self . _func_eval_time_limit_secs ,
813803 dask_client = self ._dask_client ,
814804 memory_limit = self ._memory_limit ,
815- n_jobs = self .n_jobs ,
816- watcher = self ._stopwatch ,
817- metric = self ._metric ,
818- seed = self .seed ,
805+ n_jobs = self .n_jobs , watcher = self ._stopwatch ,
806+ metric = self ._metric , seed = self .seed ,
819807 include = self .include_components ,
820808 exclude = self .exclude_components ,
821809 disable_file_output = self ._disable_file_output ,
@@ -833,8 +821,9 @@ def _run_smac(self,
833821
834822 def _search_settings (self , dataset : BaseDataset , disable_file_output : List ,
835823 optimize_metric : str , memory_limit : Optional [int ] = 4096 ,
836- total_walltime_limit : int = 100 , all_supported_metrics : bool = True
837- ) -> None :
824+ func_eval_time_limit_secs : Optional [int ] = None ,
825+ total_walltime_limit : int = 100 ,
826+ all_supported_metrics : bool = True ) -> None :
838827
839828 """Initialise information needed for the experiment"""
840829 self .experiment_task_name = 'runSearch'
@@ -847,12 +836,13 @@ def _search_settings(self, dataset: BaseDataset, disable_file_output: List,
847836 self ._all_supported_metrics = all_supported_metrics
848837 self ._disable_file_output = disable_file_output
849838 self ._memory_limit = memory_limit
850- self ._time_for_task = total_walltime_limit
839+ self ._total_walltime_limit = total_walltime_limit
840+ self ._func_eval_time_limit_secs = func_eval_time_limit_secs
851841 self ._metric = get_metrics (
852842 names = [optimize_metric ], dataset_properties = dataset_properties )[0 ]
853843
854844 if self ._logger is None :
855- self ._logger = self ._get_logger (self .dataset_name )
845+ self ._logger = self ._get_logger (str ( self .dataset_name ) )
856846
857847 # Save start time to backend
858848 self ._backend .save_start_time (str (self .seed ))
@@ -872,36 +862,34 @@ def _search_settings(self, dataset: BaseDataset, disable_file_output: List,
872862 else :
873863 self ._is_dask_client_internally_created = False
874864
875- def _adapt_time_resource_allocation (self ,
876- total_walltime_limit : int ,
877- func_eval_time_limit_secs : Optional [int ] = None
878- ) -> int :
865+ def _adapt_time_resource_allocation (self ) -> None :  # caps self._func_eval_time_limit_secs in place using self._total_walltime_limit; returns nothing
866+ assert self ._logger is not None  # the logger has to be set before this method runs
879867
880868 # Handle time resource allocation
881869 elapsed_time = self ._stopwatch .wall_elapsed (self .experiment_task_name )
882- time_left_for_modelfit = int (max (0 , total_walltime_limit - elapsed_time ))
883- if func_eval_time_limit_secs is None or func_eval_time_limit_secs > time_left_for_modelfit :
870+ time_left_for_modelfit = int (max (0 , self . _total_walltime_limit - elapsed_time ))  # clamped at 0, so it can be exactly 0
871+ if self . _func_eval_time_limit_secs is None or self . _func_eval_time_limit_secs > time_left_for_modelfit :
884872 self ._logger .warning (
885873 'Time limit for a single run is higher than total time '
886874 'limit. Capping the limit for a single run to the total '
887875 'time given to SMAC (%f)' % time_left_for_modelfit
888876 )
889- func_eval_time_limit_secs = time_left_for_modelfit
877+ self . _func_eval_time_limit_secs = time_left_for_modelfit
890878
891879 # Make sure that at least 2 models are created for the ensemble process
892- num_models = time_left_for_modelfit // func_eval_time_limit_secs
880+ num_models = (time_left_for_modelfit // self . _func_eval_time_limit_secs ) if self . _func_eval_time_limit_secs else 0  # a zero limit (no time left) would raise ZeroDivisionError; count it as fewer than 2 models
893881 if num_models < 2 :
894- func_eval_time_limit_secs = time_left_for_modelfit // 2
882+ self . _func_eval_time_limit_secs = time_left_for_modelfit // 2
895883 self ._logger .warning (
896884 "Capping the func_eval_time_limit_secs to {} to have "
897885 "time for a least 2 models to ensemble." .format (
898- func_eval_time_limit_secs
886+ self . _func_eval_time_limit_secs
899887 )
900888 )
901889
902- return func_eval_time_limit_secs
903-
904890 def _save_ensemble_performance_history (self , proc_ensemble : EnsembleBuilderManager ) -> None :
891+ assert self ._logger is not None
892+
905893 if len (proc_ensemble .futures ) > 0 :
906894 # Also add ensemble runs that did not finish within smac time
907895 # and add them into the ensemble history
@@ -920,6 +908,7 @@ def _save_ensemble_performance_history(self, proc_ensemble: EnsembleBuilderManag
920908 def _finish_experiment (self , proc_ensemble : EnsembleBuilderManager ,
921909 load_models : bool ) -> None :
922910
911+ assert self ._logger is not None
923912 # Wait until the ensemble process is finished to avoid shutting down
924913 # while the ensemble builder tries to access the data
925914 self ._logger .info ("Start Shutdown" )
@@ -941,23 +930,18 @@ def _finish_experiment(self, proc_ensemble: EnsembleBuilderManager,
941930 self ._logger .info ("Starting to clean up the logger" )
942931 self ._clean_logger ()
943932
944- def _search (
945- self ,
946- optimize_metric : str ,
947- dataset : BaseDataset ,
948- budget_type : Optional [str ] = None ,
949- budget : Optional [float ] = None ,
950- total_walltime_limit : int = 100 ,
951- func_eval_time_limit_secs : Optional [int ] = None ,
952- enable_traditional_pipeline : bool = True ,
953- memory_limit : Optional [int ] = 4096 ,
954- smac_scenario_args : Optional [Dict [str , Any ]] = None ,
955- get_smac_object_callback : Optional [Callable ] = None ,
956- all_supported_metrics : bool = True ,
957- precision : int = 32 ,
958- disable_file_output : List = [],
959- load_models : bool = True ,
960- ) -> 'BaseTask' :
933+ def _search (self , optimize_metric : str ,
934+ dataset : BaseDataset , budget_type : Optional [str ] = None ,
935+ budget : Optional [float ] = None ,
936+ total_walltime_limit : int = 100 ,
937+ func_eval_time_limit_secs : Optional [int ] = None ,
938+ enable_traditional_pipeline : bool = True ,
939+ memory_limit : Optional [int ] = 4096 ,
940+ smac_scenario_args : Optional [Dict [str , Any ]] = None ,
941+ get_smac_object_callback : Optional [Callable ] = None ,
942+ all_supported_metrics : bool = True ,
943+ precision : int = 32 , disable_file_output : List = [],
944+ load_models : bool = True ) -> 'BaseTask' :
961945 """
962946 Search for the best pipeline configuration for the given dataset.
963947
@@ -1045,25 +1029,21 @@ def _search(
10451029 self ._search_settings (dataset = dataset , disable_file_output = disable_file_output ,
10461030 optimize_metric = optimize_metric , memory_limit = memory_limit ,
10471031 all_supported_metrics = all_supported_metrics ,
1032+ func_eval_time_limit_secs = func_eval_time_limit_secs ,
10481033 total_walltime_limit = total_walltime_limit )
10491034
1050- func_eval_time_limit_secs = self ._adapt_time_resource_allocation (
1051- total_walltime_limit = total_walltime_limit ,
1052- func_eval_time_limit_secs = func_eval_time_limit_secs
1053- )
1054-
1035+ self ._adapt_time_resource_allocation ()
10551036 self .num_run = 1
10561037 self ._run_dummy_predictions ()
1057- self ._run_traditional_ml (enable_traditional_pipeline = enable_traditional_pipeline ,
1058- func_eval_time_limit_secs = func_eval_time_limit_secs )
1038+
1039+ if enable_traditional_pipeline :  # fit the traditional models only when the caller enabled them (the guard was inverted)
1040+ self ._run_traditional_ml ()
1041+
10591042 proc_ensemble = self ._run_ensemble (dataset = dataset , precision = precision ,
1060- optimize_metric = optimize_metric ,
1061- total_walltime_limit = total_walltime_limit )
1043+ optimize_metric = optimize_metric )
10621044
10631045 self ._run_smac (budget = budget , budget_type = budget_type , proc_ensemble = proc_ensemble ,
1064- dataset = dataset , total_walltime_limit = total_walltime_limit ,
1065- func_eval_time_limit_secs = func_eval_time_limit_secs ,
1066- get_smac_object_callback = get_smac_object_callback ,
1046+ dataset = dataset , get_smac_object_callback = get_smac_object_callback ,
10671047 smac_scenario_args = smac_scenario_args )
10681048
10691049 self ._finish_experiment (proc_ensemble = proc_ensemble , load_models = load_models )
0 commit comments