from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
from autoPyTorch.pipeline.components.training.metrics.utils import calculate_score, get_metrics
from autoPyTorch.utils.backend import Backend, create
-from autoPyTorch.utils.common import FitRequirement, replace_string_bool_to_bool
+from autoPyTorch.utils.common import replace_string_bool_to_bool
from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
from autoPyTorch.utils.logging_ import (
    PicklableClientLogger,
@@ -170,13 +170,14 @@ def __init__(
            os.path.join(os.path.dirname(__file__), '../configs/default_pipeline_options.json'))))

        self.search_space: Optional[ConfigurationSpace] = None
-        self._dataset_requirements: Optional[List[FitRequirement]] = None
        self._metric: Optional[autoPyTorchMetric] = None
        self._logger: Optional[PicklableClientLogger] = None
        self.run_history: Optional[RunHistory] = None
        self.trajectory: Optional[List] = None
        self.dataset_name: Optional[str] = None
        self.cv_models_: Dict = {}
+        self.num_run: int = 1
+        self.experiment_task_name: str = 'runSearch'

        # By default try to use the TCP logging port or get a new port
        self._logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT
@@ -687,7 +688,7 @@ def _run_dummy_predictions(self) -> None:

    def _run_traditional_ml(self,
                            enable_traditional_pipeline: bool,
-                            func_eval_time_limit_secs: Optional[int] = None) -> int:
+                            func_eval_time_limit_secs: Optional[int] = None) -> None:
        """We would like to obtain training time for at least 1 Neural network in SMAC"""

        if enable_traditional_pipeline:
@@ -784,7 +785,6 @@ def _start_smac(self, proc_smac: AutoMLSMBO):
                self._logger.warning(f"Could not save {trajectory_filename} due to {e}...")

    def _run_smac(self,
-                  experiment_task_name: str,
                  dataset: BaseDataset,
                  proc_ensemble: EnsembleBuilderManager,
                  total_walltime_limit: int,
@@ -796,7 +796,7 @@ def _run_smac(self,

        smac_task_name = 'runSMAC'
        self._stopwatch.start_task(smac_task_name)
-        elapsed_time = self._stopwatch.wall_elapsed(experiment_task_name)
+        elapsed_time = self._stopwatch.wall_elapsed(self.experiment_task_name)
        time_left_for_smac = max(0, total_walltime_limit - elapsed_time)

        self._logger.info(f"Run SMAC with {time_left_for_smac:.2f} sec time left")
@@ -831,6 +831,116 @@ def _run_smac(self,

        self._start_smac(proc_smac)

+    def _search_settings(self, dataset: BaseDataset, disable_file_output: List,
+                         optimize_metric: str, memory_limit: Optional[int] = 4096,
+                         total_walltime_limit: int = 100, all_supported_metrics: bool = True
+                         ) -> None:
+
+        """Initialise information needed for the experiment"""
+        self.experiment_task_name = 'runSearch'
+        dataset_requirements = get_dataset_requirements(
+            info=self._get_required_dataset_properties(dataset))
+        dataset_properties = dataset.get_dataset_properties(dataset_requirements)
+
+        self._stopwatch.start_task(self.experiment_task_name)
+        self.dataset_name = dataset.dataset_name
+        self._all_supported_metrics = all_supported_metrics
+        self._disable_file_output = disable_file_output
+        self._memory_limit = memory_limit
+        self._time_for_task = total_walltime_limit
+        self._metric = get_metrics(
+            names=[optimize_metric], dataset_properties=dataset_properties)[0]
+
+        if self._logger is None:
+            self._logger = self._get_logger(self.dataset_name)
+
+        # Save start time to backend
+        self._backend.save_start_time(str(self.seed))
+        self._backend.save_datamanager(dataset)
+
+        # Print debug information to log
+        self._print_debug_info_to_log()
+
+        self.search_space = self.get_search_space(dataset)
+
+        # If no dask client was provided, we create one, so that we can
+        # start an ensemble process in parallel to the SMBO optimisation
+        if (
+            self._dask_client is None and (self.ensemble_size > 0 or self.n_jobs is not None and self.n_jobs > 1)
+        ):
+            self._create_dask_client()
+        else:
+            self._is_dask_client_internally_created = False
+
+    def _adapt_time_resource_allocation(self,
+                                        total_walltime_limit: int,
+                                        func_eval_time_limit_secs: Optional[int] = None
+                                        ) -> int:
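+        """Cap func_eval_time_limit_secs so that a single run fits within the
+        remaining wallclock budget and at least two models can be trained for
+        the ensemble; returns the (possibly reduced) per-run time limit."""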
+
+        # Handle time resource allocation
+        elapsed_time = self._stopwatch.wall_elapsed(self.experiment_task_name)
+        time_left_for_modelfit = int(max(0, total_walltime_limit - elapsed_time))
+        if func_eval_time_limit_secs is None or func_eval_time_limit_secs > time_left_for_modelfit:
+            self._logger.warning(
+                'Time limit for a single run is higher than total time '
+                'limit. Capping the limit for a single run to the total '
+                'time given to SMAC (%f)' % time_left_for_modelfit
+            )
+            func_eval_time_limit_secs = time_left_for_modelfit
+
+        # Make sure that at least 2 models are created for the ensemble process
+        num_models = time_left_for_modelfit // func_eval_time_limit_secs
+        if num_models < 2:
+            func_eval_time_limit_secs = time_left_for_modelfit // 2
+            self._logger.warning(
+                "Capping the func_eval_time_limit_secs to {} to have "
+                "time for at least 2 models to ensemble.".format(
+                    func_eval_time_limit_secs
+                )
+            )
+
+        return func_eval_time_limit_secs
+
+    def _save_ensemble_performance_history(self, proc_ensemble: EnsembleBuilderManager) -> None:
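+        """Collect any ensemble-builder results that are still pending and
+        persist the ensemble performance history to the backend as JSON."""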
+        if len(proc_ensemble.futures) > 0:
+            # Also add ensemble runs that did not finish within smac time
+            # and add them into the ensemble history
+            self._logger.info("Ensemble script still running, waiting for it to finish.")
+            result = proc_ensemble.futures.pop().result()
+            if result:
+                ensemble_history, _, _, _ = result
+                self.ensemble_performance_history.extend(ensemble_history)
+            self._logger.info("Ensemble script finished, continue shutdown.")
+
+        # save the ensemble performance history file
+        if len(self.ensemble_performance_history) > 0:
+            pd.DataFrame(self.ensemble_performance_history).to_json(
+                os.path.join(self._backend.internals_directory, 'ensemble_history.json'))
+
+    def _finish_experiment(self, proc_ensemble: EnsembleBuilderManager,
+                           load_models: bool) -> None:
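+        """Shut down the ensemble builder and the dask infrastructure,
+        optionally load the trained models, and clean up the logger."""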
+
+        # Wait until the ensemble process is finished to avoid shutting down
+        # while the ensemble builder tries to access the data
+        self._logger.info("Start Shutdown")
+
+        if proc_ensemble is not None:
+            self.ensemble_performance_history = list(proc_ensemble.history)
+            self._save_ensemble_performance_history(proc_ensemble)
+
+        self._logger.info("Closing the dask infrastructure")
+        self._close_dask_client()
+        self._logger.info("Finished closing the dask infrastructure")
+
+        if load_models:
+            self._logger.info("Loading models...")
+            self._load_models()
+            self._logger.info("Finished loading models...")
+
+        # Clean up the logger
+        self._logger.info("Starting to clean up the logger")
+        self._clean_logger()
+
    def _search(
        self,
        optimize_metric: str,
@@ -927,69 +1037,20 @@ def _search(
            raise ValueError("Incompatible dataset entered for current task,"
                             "expected dataset to have task type :{} got "
                             ":{}".format(self.task_type, dataset.task_type))
-        if precision not in [16, 32, 64]:
-            raise ValueError(f"precision must be either [16, 32, 64], but got {precision}")
-
-        # Initialise information needed for the experiment
-        experiment_task_name = 'runSearch'
-        dataset_requirements = get_dataset_requirements(
-            info=self._get_required_dataset_properties(dataset))
-        self._dataset_requirements = dataset_requirements
-        dataset_properties = dataset.get_dataset_properties(dataset_requirements)
-        self._stopwatch.start_task(experiment_task_name)
-        self.dataset_name = dataset.dataset_name
-        if self._logger is None:
-            self._logger = self._get_logger(self.dataset_name)
-        self._all_supported_metrics = all_supported_metrics
-        self._disable_file_output = disable_file_output
-        self._memory_limit = memory_limit
-        self._time_for_task = total_walltime_limit
-        # Save start time to backend
-        self._backend.save_start_time(str(self.seed))
-
-        self._backend.save_datamanager(dataset)
-
-        # Print debug information to log
-        self._print_debug_info_to_log()
-
-        self._metric = get_metrics(
-            names=[optimize_metric], dataset_properties=dataset_properties)[0]
-
-        self.search_space = self.get_search_space(dataset)
-
        if self.task_type is None:
            raise ValueError("Cannot interpret task type from the dataset")
+        if precision not in [16, 32, 64]:
+            raise ValueError(f"precision must be either [16, 32, 64], but got {precision}")

-        # If no dask client was provided, we create one, so that we can
-        # start a ensemble process in parallel to smbo optimize
-        if (
-            self._dask_client is None and (self.ensemble_size > 0 or self.n_jobs is not None and self.n_jobs > 1)
-        ):
-            self._create_dask_client()
-        else:
-            self._is_dask_client_internally_created = False
-
-        # Handle time resource allocation
-        elapsed_time = self._stopwatch.wall_elapsed(experiment_task_name)
-        time_left_for_modelfit = int(max(0, total_walltime_limit - elapsed_time))
-        if func_eval_time_limit_secs is None or func_eval_time_limit_secs > time_left_for_modelfit:
-            self._logger.warning(
-                'Time limit for a single run is higher than total time '
-                'limit. Capping the limit for a single run to the total '
-                'time given to SMAC (%f)' % time_left_for_modelfit
-            )
-            func_eval_time_limit_secs = time_left_for_modelfit
+        self._search_settings(dataset=dataset, disable_file_output=disable_file_output,
+                              optimize_metric=optimize_metric, memory_limit=memory_limit,
+                              all_supported_metrics=all_supported_metrics,
+                              total_walltime_limit=total_walltime_limit)

-        # Make sure that at least 2 models are created for the ensemble process
-        num_models = time_left_for_modelfit // func_eval_time_limit_secs
-        if num_models < 2:
-            func_eval_time_limit_secs = time_left_for_modelfit // 2
-            self._logger.warning(
-                "Capping the func_eval_time_limit_secs to {} to have "
-                "time for a least 2 models to ensemble.".format(
-                    func_eval_time_limit_secs
-                )
-            )
+        func_eval_time_limit_secs = self._adapt_time_resource_allocation(
+            total_walltime_limit=total_walltime_limit,
+            func_eval_time_limit_secs=func_eval_time_limit_secs
+        )

        self.num_run = 1
        self._run_dummy_predictions()
@@ -999,47 +1060,13 @@ def _search(
            optimize_metric=optimize_metric,
            total_walltime_limit=total_walltime_limit)

-        self._run_smac(experiment_task_name=experiment_task_name,
-                       budget=budget, budget_type=budget_type, proc_ensemble=proc_ensemble,
+        self._run_smac(budget=budget, budget_type=budget_type, proc_ensemble=proc_ensemble,
                       dataset=dataset, total_walltime_limit=total_walltime_limit,
                       func_eval_time_limit_secs=func_eval_time_limit_secs,
                       get_smac_object_callback=get_smac_object_callback,
                       smac_scenario_args=smac_scenario_args)

-        # Wait until the ensemble process is finished to avoid shutting down
-        # while the ensemble builder tries to access the data
-        self._logger.info("Start Shutdown")
-
-        if proc_ensemble is not None:
-            self.ensemble_performance_history = list(proc_ensemble.history)
-
-            if len(proc_ensemble.futures) > 0:
-                # Also add ensemble runs that did not finish within smac time
-                # and add them into the ensemble history
-                self._logger.info("Ensemble script still running, waiting for it to finish.")
-                result = proc_ensemble.futures.pop().result()
-                if result:
-                    ensemble_history, _, _, _ = result
-                    self.ensemble_performance_history.extend(ensemble_history)
-                self._logger.info("Ensemble script finished, continue shutdown.")
-
-        # save the ensemble performance history file
-        if len(self.ensemble_performance_history) > 0:
-            pd.DataFrame(self.ensemble_performance_history).to_json(
-                os.path.join(self._backend.internals_directory, 'ensemble_history.json'))
-
-        self._logger.info("Closing the dask infrastructure")
-        self._close_dask_client()
-        self._logger.info("Finished closing the dask infrastructure")
-
-        if load_models:
-            self._logger.info("Loading models...")
-            self._load_models()
-            self._logger.info("Finished loading models...")
-
-        # Clean up the logger
-        self._logger.info("Starting to clean up the logger")
-        self._clean_logger()
+        self._finish_experiment(proc_ensemble=proc_ensemble, load_models=load_models)

        return self
