
Commit ff1d7d3

committed
[refactor] Add _search_settings to make _search() function shorter
1 parent 7186348 commit ff1d7d3
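
For orientation, the sketch below is a minimal, hypothetical outline (not the project's actual code) of the call order that _search() follows after this refactor: experiment setup moves into _search_settings(), per-run time capping into _adapt_time_resource_allocation(), and the shutdown sequence into _finish_experiment(). Method names mirror the diff further down; signatures are trimmed to a few illustrative parameters and bodies are stubbed.

from typing import List, Optional


class SearchFlowSketch:
    """Hypothetical stand-in for BaseTask that keeps only the refactored call order."""

    def _search_settings(self, dataset: object, disable_file_output: List,
                         optimize_metric: str, memory_limit: Optional[int] = 4096,
                         total_walltime_limit: int = 100,
                         all_supported_metrics: bool = True) -> None:
        # In the real class this stores experiment state (metric, limits, logger,
        # search space) and creates the dask client when one is needed.
        self.experiment_task_name = 'runSearch'

    def _adapt_time_resource_allocation(self, total_walltime_limit: int,
                                        func_eval_time_limit_secs: Optional[int] = None) -> int:
        # Caps the per-run limit so at least two models fit into the remaining budget.
        return func_eval_time_limit_secs or total_walltime_limit // 2

    def _run_dummy_predictions(self) -> None: ...

    def _run_traditional_ml(self, enable_traditional_pipeline: bool,
                            func_eval_time_limit_secs: Optional[int] = None) -> None: ...

    def _run_smac(self, **kwargs: object) -> None: ...

    def _finish_experiment(self, proc_ensemble: object, load_models: bool) -> None: ...

    def _search(self, dataset: object, optimize_metric: str,
                total_walltime_limit: int = 100,
                func_eval_time_limit_secs: Optional[int] = None,
                load_models: bool = True) -> "SearchFlowSketch":
        # Setup, time allocation, model runs, then shutdown: the refactored sequence.
        self._search_settings(dataset=dataset, disable_file_output=[],
                              optimize_metric=optimize_metric,
                              total_walltime_limit=total_walltime_limit)
        func_eval_time_limit_secs = self._adapt_time_resource_allocation(
            total_walltime_limit=total_walltime_limit,
            func_eval_time_limit_secs=func_eval_time_limit_secs)
        self.num_run = 1
        self._run_dummy_predictions()
        self._run_traditional_ml(enable_traditional_pipeline=True,
                                 func_eval_time_limit_secs=func_eval_time_limit_secs)
        self._run_smac(func_eval_time_limit_secs=func_eval_time_limit_secs)
        self._finish_experiment(proc_ensemble=None, load_models=load_models)
        return self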

File tree

1 file changed (+127 lines, -100 lines)

autoPyTorch/api/base_task.py

Lines changed: 127 additions & 100 deletions
@@ -47,7 +47,7 @@
 from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
 from autoPyTorch.pipeline.components.training.metrics.utils import calculate_score, get_metrics
 from autoPyTorch.utils.backend import Backend, create
-from autoPyTorch.utils.common import FitRequirement, replace_string_bool_to_bool
+from autoPyTorch.utils.common import replace_string_bool_to_bool
 from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
 from autoPyTorch.utils.logging_ import (
     PicklableClientLogger,
@@ -170,13 +170,14 @@ def __init__(
             os.path.join(os.path.dirname(__file__), '../configs/default_pipeline_options.json'))))

         self.search_space: Optional[ConfigurationSpace] = None
-        self._dataset_requirements: Optional[List[FitRequirement]] = None
         self._metric: Optional[autoPyTorchMetric] = None
         self._logger: Optional[PicklableClientLogger] = None
         self.run_history: Optional[RunHistory] = None
         self.trajectory: Optional[List] = None
         self.dataset_name: Optional[str] = None
         self.cv_models_: Dict = {}
+        self.num_run: int = 1
+        self.experiment_task_name: str = 'runSearch'

         # By default try to use the TCP logging port or get a new port
         self._logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT
@@ -687,7 +688,7 @@ def _run_dummy_predictions(self) -> None:

     def _run_traditional_ml(self,
                             enable_traditional_pipeline: bool,
-                            func_eval_time_limit_secs: Optional[int] = None) -> int:
+                            func_eval_time_limit_secs: Optional[int] = None) -> None:
         """We would like to obtain training time for at least 1 Neural network in SMAC"""

         if enable_traditional_pipeline:
@@ -784,7 +785,6 @@ def _start_smac(self, proc_smac: AutoMLSMBO):
             self._logger.warning(f"Could not save {trajectory_filename} due to {e}...")

     def _run_smac(self,
-                  experiment_task_name: str,
                   dataset: BaseDataset,
                   proc_ensemble: EnsembleBuilderManager,
                   total_walltime_limit: int,
@@ -796,7 +796,7 @@ def _run_smac(self,

         smac_task_name = 'runSMAC'
         self._stopwatch.start_task(smac_task_name)
-        elapsed_time = self._stopwatch.wall_elapsed(experiment_task_name)
+        elapsed_time = self._stopwatch.wall_elapsed(self.experiment_task_name)
         time_left_for_smac = max(0, total_walltime_limit - elapsed_time)

         self._logger.info(f"Run SMAC with {time_left_for_smac:.2f} sec time left")
@@ -831,6 +831,116 @@ def _run_smac(self,

         self._start_smac(proc_smac)

+    def _search_settings(self, dataset: BaseDataset, disable_file_output: List,
+                         optimize_metric: str, memory_limit: Optional[int] = 4096,
+                         total_walltime_limit: int = 100, all_supported_metrics: bool = True
+                         ) -> None:
+
+        """Initialise information needed for the experiment"""
+        self.experiment_task_name = 'runSearch'
+        dataset_requirements = get_dataset_requirements(
+            info=self._get_required_dataset_properties(dataset))
+        dataset_properties = dataset.get_dataset_properties(dataset_requirements)
+
+        self._stopwatch.start_task(self.experiment_task_name)
+        self.dataset_name = dataset.dataset_name
+        self._all_supported_metrics = all_supported_metrics
+        self._disable_file_output = disable_file_output
+        self._memory_limit = memory_limit
+        self._time_for_task = total_walltime_limit
+        self._metric = get_metrics(
+            names=[optimize_metric], dataset_properties=dataset_properties)[0]
+
+        if self._logger is None:
+            self._logger = self._get_logger(self.dataset_name)
+
+        # Save start time to backend
+        self._backend.save_start_time(str(self.seed))
+        self._backend.save_datamanager(dataset)
+
+        # Print debug information to log
+        self._print_debug_info_to_log()
+
+        self.search_space = self.get_search_space(dataset)
+
+        # If no dask client was provided, we create one, so that we can
+        # start a ensemble process in parallel to smbo optimize
+        if (
+            self._dask_client is None and (self.ensemble_size > 0 or self.n_jobs is not None and self.n_jobs > 1)
+        ):
+            self._create_dask_client()
+        else:
+            self._is_dask_client_internally_created = False
+
+    def _adapt_time_resource_allocation(self,
+                                        total_walltime_limit: int,
+                                        func_eval_time_limit_secs: Optional[int] = None
+                                        ) -> int:
+
+        # Handle time resource allocation
+        elapsed_time = self._stopwatch.wall_elapsed(self.experiment_task_name)
+        time_left_for_modelfit = int(max(0, total_walltime_limit - elapsed_time))
+        if func_eval_time_limit_secs is None or func_eval_time_limit_secs > time_left_for_modelfit:
+            self._logger.warning(
+                'Time limit for a single run is higher than total time '
+                'limit. Capping the limit for a single run to the total '
+                'time given to SMAC (%f)' % time_left_for_modelfit
+            )
+            func_eval_time_limit_secs = time_left_for_modelfit
+
+        # Make sure that at least 2 models are created for the ensemble process
+        num_models = time_left_for_modelfit // func_eval_time_limit_secs
+        if num_models < 2:
+            func_eval_time_limit_secs = time_left_for_modelfit // 2
+            self._logger.warning(
+                "Capping the func_eval_time_limit_secs to {} to have "
+                "time for a least 2 models to ensemble.".format(
+                    func_eval_time_limit_secs
+                )
+            )
+
+        return func_eval_time_limit_secs
+
+    def _save_ensemble_performance_history(self, proc_ensemble: EnsembleBuilderManager) -> None:
+        if len(proc_ensemble.futures) > 0:
+            # Also add ensemble runs that did not finish within smac time
+            # and add them into the ensemble history
+            self._logger.info("Ensemble script still running, waiting for it to finish.")
+            result = proc_ensemble.futures.pop().result()
+            if result:
+                ensemble_history, _, _, _ = result
+                self.ensemble_performance_history.extend(ensemble_history)
+            self._logger.info("Ensemble script finished, continue shutdown.")
+
+        # save the ensemble performance history file
+        if len(self.ensemble_performance_history) > 0:
+            pd.DataFrame(self.ensemble_performance_history).to_json(
+                os.path.join(self._backend.internals_directory, 'ensemble_history.json'))
+
+    def _finish_experiment(self, proc_ensemble: EnsembleBuilderManager,
+                           load_models: bool) -> None:
+
+        # Wait until the ensemble process is finished to avoid shutting down
+        # while the ensemble builder tries to access the data
+        self._logger.info("Start Shutdown")
+
+        if proc_ensemble is not None:
+            self.ensemble_performance_history = list(proc_ensemble.history)
+            self._save_ensemble_performance_history(proc_ensemble)
+
+        self._logger.info("Closing the dask infrastructure")
+        self._close_dask_client()
+        self._logger.info("Finished closing the dask infrastructure")
+
+        if load_models:
+            self._logger.info("Loading models...")
+            self._load_models()
+            self._logger.info("Finished loading models...")
+
+        # Clean up the logger
+        self._logger.info("Starting to clean up the logger")
+        self._clean_logger()
+
     def _search(
         self,
         optimize_metric: str,
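
To make the capping logic in _adapt_time_resource_allocation concrete, here is a small free-standing sketch (hypothetical numbers, a plain function instead of a method, no logging) of the two adjustments it applies to func_eval_time_limit_secs.

from typing import Optional


def adapt_time_allocation(total_walltime_limit: int, elapsed_time: float,
                          func_eval_time_limit_secs: Optional[int] = None) -> int:
    """Simplified, standalone version of the capping logic shown in the diff above."""
    time_left_for_modelfit = int(max(0, total_walltime_limit - elapsed_time))

    # Cap a single run to the wallclock time that is actually left for model fitting.
    if func_eval_time_limit_secs is None or func_eval_time_limit_secs > time_left_for_modelfit:
        func_eval_time_limit_secs = time_left_for_modelfit

    # Leave room for at least two models so the ensemble has something to combine.
    if time_left_for_modelfit // func_eval_time_limit_secs < 2:
        func_eval_time_limit_secs = time_left_for_modelfit // 2

    return func_eval_time_limit_secs


# Example: 300s budget, 20s already elapsed, 200s requested per run.
# 280s remain and 280 // 200 == 1 model, so the per-run limit drops to 280 // 2 == 140s.
print(adapt_time_allocation(300, 20.0, 200))  # 140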
@@ -927,69 +1037,20 @@ def _search(
             raise ValueError("Incompatible dataset entered for current task,"
                              "expected dataset to have task type :{} got "
                              ":{}".format(self.task_type, dataset.task_type))
-        if precision not in [16, 32, 64]:
-            raise ValueError(f"precision must be either [16, 32, 64], but got {precision}")
-
-        # Initialise information needed for the experiment
-        experiment_task_name = 'runSearch'
-        dataset_requirements = get_dataset_requirements(
-            info=self._get_required_dataset_properties(dataset))
-        self._dataset_requirements = dataset_requirements
-        dataset_properties = dataset.get_dataset_properties(dataset_requirements)
-        self._stopwatch.start_task(experiment_task_name)
-        self.dataset_name = dataset.dataset_name
-        if self._logger is None:
-            self._logger = self._get_logger(self.dataset_name)
-        self._all_supported_metrics = all_supported_metrics
-        self._disable_file_output = disable_file_output
-        self._memory_limit = memory_limit
-        self._time_for_task = total_walltime_limit
-        # Save start time to backend
-        self._backend.save_start_time(str(self.seed))
-
-        self._backend.save_datamanager(dataset)
-
-        # Print debug information to log
-        self._print_debug_info_to_log()
-
-        self._metric = get_metrics(
-            names=[optimize_metric], dataset_properties=dataset_properties)[0]
-
-        self.search_space = self.get_search_space(dataset)
-
         if self.task_type is None:
             raise ValueError("Cannot interpret task type from the dataset")
+        if precision not in [16, 32, 64]:
+            raise ValueError(f"precision must be either [16, 32, 64], but got {precision}")

-        # If no dask client was provided, we create one, so that we can
-        # start a ensemble process in parallel to smbo optimize
-        if (
-            self._dask_client is None and (self.ensemble_size > 0 or self.n_jobs is not None and self.n_jobs > 1)
-        ):
-            self._create_dask_client()
-        else:
-            self._is_dask_client_internally_created = False
-
-        # Handle time resource allocation
-        elapsed_time = self._stopwatch.wall_elapsed(experiment_task_name)
-        time_left_for_modelfit = int(max(0, total_walltime_limit - elapsed_time))
-        if func_eval_time_limit_secs is None or func_eval_time_limit_secs > time_left_for_modelfit:
-            self._logger.warning(
-                'Time limit for a single run is higher than total time '
-                'limit. Capping the limit for a single run to the total '
-                'time given to SMAC (%f)' % time_left_for_modelfit
-            )
-            func_eval_time_limit_secs = time_left_for_modelfit
+        self._search_settings(dataset=dataset, disable_file_output=disable_file_output,
+                              optimize_metric=optimize_metric, memory_limit=memory_limit,
+                              all_supported_metrics=all_supported_metrics,
+                              total_walltime_limit=total_walltime_limit)

-        # Make sure that at least 2 models are created for the ensemble process
-        num_models = time_left_for_modelfit // func_eval_time_limit_secs
-        if num_models < 2:
-            func_eval_time_limit_secs = time_left_for_modelfit // 2
-            self._logger.warning(
-                "Capping the func_eval_time_limit_secs to {} to have "
-                "time for a least 2 models to ensemble.".format(
-                    func_eval_time_limit_secs
-                )
-            )
+        func_eval_time_limit_secs = self._adapt_time_resource_allocation(
+            total_walltime_limit=total_walltime_limit,
+            func_eval_time_limit_secs=func_eval_time_limit_secs
+        )

         self.num_run = 1
         self._run_dummy_predictions()
@@ -999,47 +1060,13 @@ def _search(
                                  optimize_metric=optimize_metric,
                                  total_walltime_limit=total_walltime_limit)

-        self._run_smac(experiment_task_name=experiment_task_name,
-                       budget=budget, budget_type=budget_type, proc_ensemble=proc_ensemble,
+        self._run_smac(budget=budget, budget_type=budget_type, proc_ensemble=proc_ensemble,
                        dataset=dataset, total_walltime_limit=total_walltime_limit,
                        func_eval_time_limit_secs=func_eval_time_limit_secs,
                        get_smac_object_callback=get_smac_object_callback,
                        smac_scenario_args=smac_scenario_args)

-        # Wait until the ensemble process is finished to avoid shutting down
-        # while the ensemble builder tries to access the data
-        self._logger.info("Start Shutdown")
-
-        if proc_ensemble is not None:
-            self.ensemble_performance_history = list(proc_ensemble.history)
-
-            if len(proc_ensemble.futures) > 0:
-                # Also add ensemble runs that did not finish within smac time
-                # and add them into the ensemble history
-                self._logger.info("Ensemble script still running, waiting for it to finish.")
-                result = proc_ensemble.futures.pop().result()
-                if result:
-                    ensemble_history, _, _, _ = result
-                    self.ensemble_performance_history.extend(ensemble_history)
-                self._logger.info("Ensemble script finished, continue shutdown.")
-
-        # save the ensemble performance history file
-        if len(self.ensemble_performance_history) > 0:
-            pd.DataFrame(self.ensemble_performance_history).to_json(
-                os.path.join(self._backend.internals_directory, 'ensemble_history.json'))
-
-        self._logger.info("Closing the dask infrastructure")
-        self._close_dask_client()
-        self._logger.info("Finished closing the dask infrastructure")
-
-        if load_models:
-            self._logger.info("Loading models...")
-            self._load_models()
-            self._logger.info("Finished loading models...")
-
-        # Clean up the logger
-        self._logger.info("Starting to clean up the logger")
-        self._clean_logger()
+        self._finish_experiment(proc_ensemble=proc_ensemble, load_models=load_models)

         return self
