Skip to content

Commit ec1a508

Browse files
committed
[fix] Resolve all the mypy issues
1 parent ff1d7d3 commit ec1a508

File tree

1 file changed

+82
-102
lines changed

1 file changed

+82
-102
lines changed

autoPyTorch/api/base_task.py

Lines changed: 82 additions & 102 deletions
Original file line number | Diff line number | Diff line change
@@ -499,7 +499,7 @@ def _do_dummy_prediction(self) -> None:
499499
memory_limit = int(math.ceil(memory_limit))
500500

501501
scenario_mock = unittest.mock.Mock()
502-
scenario_mock.wallclock_limit = self._time_for_task
502+
scenario_mock.wallclock_limit = self._total_walltime_limit
503503
# This stats object is a hack - maybe the SMAC stats object should
504504
# already be generated here!
505505
stats = Stats(scenario_mock)
@@ -518,7 +518,7 @@ def _do_dummy_prediction(self) -> None:
518518
all_supported_metrics=self._all_supported_metrics
519519
)
520520

521-
status, cost, runtime, additional_info = ta.run(self.num_run, cutoff=self._time_for_task)
521+
status, cost, runtime, additional_info = ta.run(self.num_run, cutoff=self._total_walltime_limit)
522522
if status == StatusType.SUCCESS:
523523
self._logger.info("Finished creating dummy predictions.")
524524
else:
@@ -552,8 +552,7 @@ def _do_dummy_prediction(self) -> None:
552552
% (str(status), str(additional_info))
553553
)
554554

555-
def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs: int
556-
) -> None:
555+
def _do_traditional_prediction(self, time_left: int) -> None:
557556
"""
558557
Fits traditional machine learning algorithms to the provided dataset, while
559558
complying with time resource allocation.
@@ -596,8 +595,8 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs:
596595

597596
# Only launch a task if there is time
598597
start_time = time.time()
599-
if time_left >= func_eval_time_limit_secs:
600-
self._logger.info(f"{n_r}: Started fitting {classifier} with cutoff={func_eval_time_limit_secs}")
598+
if time_left >= self._func_eval_time_limit_secs:
599+
self._logger.info(f"{n_r}: Started fitting {classifier} with cutoff={self._func_eval_time_limit_secs}")
601600
scenario_mock = unittest.mock.Mock()
602601
scenario_mock.wallclock_limit = time_left
603602
# This stats object is a hack - maybe the SMAC stats object should
@@ -621,7 +620,7 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs:
621620
classifier,
622621
self._dask_client.submit(
623622
ta.run, config=classifier,
624-
cutoff=func_eval_time_limit_secs,
623+
cutoff=self._func_eval_time_limit_secs,
625624
)
626625
])
627626

@@ -640,7 +639,7 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs:
640639

641640
# How many workers to wait before starting fitting the next iteration
642641
workers_to_wait = 1
643-
if n_r >= total_number_classifiers - 1 or time_left <= func_eval_time_limit_secs:
642+
if n_r >= total_number_classifiers - 1 or time_left <= self._func_eval_time_limit_secs:
644643
# If on the last iteration, flush out all tasks
645644
workers_to_wait = len(dask_futures)
646645

@@ -675,7 +674,7 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs:
675674
time_left -= int(time.time() - start_time)
676675

677676
# Exit if no more time is available for a new classifier
678-
if time_left < func_eval_time_limit_secs:
677+
if time_left < self._func_eval_time_limit_secs:
679678
self._logger.warning("Not enough time to fit all traditional machine learning models."
680679
"Please consider increasing the run time to further improve performance.")
681680
break
@@ -686,36 +685,30 @@ def _run_dummy_predictions(self) -> None:
686685
self._do_dummy_prediction()
687686
self._stopwatch.stop_task(dummy_task_name)
688687

689-
def _run_traditional_ml(self,
690-
enable_traditional_pipeline: bool,
691-
func_eval_time_limit_secs: Optional[int] = None) -> None:
688+
def _run_traditional_ml(self) -> None:
692689
"""We would like to obtain training time for at least 1 Neural network in SMAC"""
690+
assert self._logger is not None
693691

694-
if enable_traditional_pipeline:
695-
if STRING_TO_TASK_TYPES[self.task_type] in REGRESSION_TASKS:
696-
self._logger.warning("Traditional Pipeline is not enabled for regression. Skipping...")
697-
else:
698-
traditional_task_name = 'runTraditional'
699-
self._stopwatch.start_task(traditional_task_name)
700-
elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name)
692+
if STRING_TO_TASK_TYPES[self.task_type] in REGRESSION_TASKS:
693+
self._logger.warning("Traditional Pipeline is not enabled for regression. Skipping...")
694+
else:
695+
traditional_task_name = 'runTraditional'
696+
self._stopwatch.start_task(traditional_task_name)
697+
elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name)
701698

702-
time_for_traditional = int(
703-
self._time_for_task - elapsed_time - func_eval_time_limit_secs
704-
)
705-
self.num_run = self._do_traditional_prediction(
706-
func_eval_time_limit_secs=func_eval_time_limit_secs,
707-
time_left=time_for_traditional,
708-
)
709-
self._stopwatch.stop_task(traditional_task_name)
699+
time_for_traditional = int(
700+
self._total_walltime_limit - elapsed_time - self._func_eval_time_limit_secs
701+
)
702+
self._do_traditional_prediction(time_left=time_for_traditional)
703+
self._stopwatch.stop_task(traditional_task_name)
710704

711-
def _run_ensemble(self,
712-
dataset: BaseDataset,
713-
optimize_metric: str,
714-
total_walltime_limit: int,
705+
def _run_ensemble(self, dataset: BaseDataset, optimize_metric: str,
715706
precision: int) -> EnsembleBuilderManager:
716707

708+
assert self._logger is not None
709+
717710
elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name)
718-
time_left_for_ensembles = max(0, total_walltime_limit - elapsed_time)
711+
time_left_for_ensembles = max(0, self._total_walltime_limit - elapsed_time)
719712
proc_ensemble = None
720713
if time_left_for_ensembles <= 0 and self.ensemble_size > 0:
721714
raise ValueError("Could not run ensemble builder because there "
@@ -734,25 +727,20 @@ def _run_ensemble(self,
734727
dataset_name=dataset.dataset_name,
735728
output_type=STRING_TO_OUTPUT_TYPES[dataset.output_type],
736729
task_type=STRING_TO_TASK_TYPES[self.task_type],
737-
metrics=[self._metric],
738-
opt_metric=optimize_metric,
730+
metrics=[self._metric], opt_metric=optimize_metric,
739731
ensemble_size=self.ensemble_size,
740732
ensemble_nbest=self.ensemble_nbest,
741733
max_models_on_disc=self.max_models_on_disc,
742-
seed=self.seed,
743-
max_iterations=None,
744-
read_at_most=sys.maxsize,
745734
ensemble_memory_limit=self._memory_limit,
746-
random_state=self.seed,
747-
precision=precision,
748-
logger_port=self._logger_port,
735+
seed=self.seed, max_iterations=None, random_state=self.seed,
736+
read_at_most=sys.maxsize, precision=precision,
737+
logger_port=self._logger_port
749738
)
750739
self._stopwatch.stop_task(ensemble_task_name)
751740

752741
return proc_ensemble
753742

754-
def _get_budget_config(self,
755-
budget_type: Optional[str] = None,
743+
def _get_budget_config(self, budget_type: Optional[str] = None,
756744
budget: Optional[float] = None) -> Dict[str, Union[float, str]]:
757745

758746
budget_config: Dict[str, Union[float, str]] = {}
@@ -764,13 +752,18 @@ def _get_budget_config(self,
764752

765753
return budget_config
766754

767-
def _start_smac(self, proc_smac: AutoMLSMBO):
755+
def _start_smac(self, proc_smac: AutoMLSMBO) -> None:
756+
assert self._logger is not None
757+
768758
try:
769759
self.run_history, self.trajectory, budget_type = \
770760
proc_smac.run_smbo()
771761
trajectory_filename = os.path.join(
772762
self._backend.get_smac_output_directory_for_run(self.seed),
773763
'trajectory.json')
764+
765+
assert self.trajectory is not None
766+
774767
saveable_trajectory = \
775768
[list(entry[:2]) + [entry[2].get_dictionary()] + list(entry[3:])
776769
for entry in self.trajectory]
@@ -784,20 +777,17 @@ def _start_smac(self, proc_smac: AutoMLSMBO):
784777
except Exception as e:
785778
self._logger.warning(f"Could not save {trajectory_filename} due to {e}...")
786779

787-
def _run_smac(self,
788-
dataset: BaseDataset,
789-
proc_ensemble: EnsembleBuilderManager,
790-
total_walltime_limit: int,
791-
budget_type: Optional[str] = None,
792-
budget: Optional[float] = None,
793-
func_eval_time_limit_secs: Optional[int] = None,
780+
def _run_smac(self, dataset: BaseDataset, proc_ensemble: EnsembleBuilderManager,
781+
budget_type: Optional[str] = None, budget: Optional[float] = None,
794782
get_smac_object_callback: Optional[Callable] = None,
795783
smac_scenario_args: Optional[Dict[str, Any]] = None) -> None:
796784

785+
assert self._logger is not None
786+
797787
smac_task_name = 'runSMAC'
798788
self._stopwatch.start_task(smac_task_name)
799789
elapsed_time = self._stopwatch.wall_elapsed(self.experiment_task_name)
800-
time_left_for_smac = max(0, total_walltime_limit - elapsed_time)
790+
time_left_for_smac = max(0, self._total_walltime_limit - elapsed_time)
801791

802792
self._logger.info(f"Run SMAC with {time_left_for_smac:.2f} sec time left")
803793
if time_left_for_smac <= 0:
@@ -808,14 +798,12 @@ def _run_smac(self,
808798
config_space=self.search_space,
809799
dataset_name=dataset.dataset_name,
810800
backend=self._backend,
811-
total_walltime_limit=total_walltime_limit,
812-
func_eval_time_limit_secs=func_eval_time_limit_secs,
801+
total_walltime_limit=self._total_walltime_limit,
802+
func_eval_time_limit_secs=self._func_eval_time_limit_secs,
813803
dask_client=self._dask_client,
814804
memory_limit=self._memory_limit,
815-
n_jobs=self.n_jobs,
816-
watcher=self._stopwatch,
817-
metric=self._metric,
818-
seed=self.seed,
805+
n_jobs=self.n_jobs, watcher=self._stopwatch,
806+
metric=self._metric, seed=self.seed,
819807
include=self.include_components,
820808
exclude=self.exclude_components,
821809
disable_file_output=self._disable_file_output,
@@ -833,8 +821,9 @@ def _run_smac(self,
833821

834822
def _search_settings(self, dataset: BaseDataset, disable_file_output: List,
835823
optimize_metric: str, memory_limit: Optional[int] = 4096,
836-
total_walltime_limit: int = 100, all_supported_metrics: bool = True
837-
) -> None:
824+
func_eval_time_limit_secs: Optional[int] = None,
825+
total_walltime_limit: int = 100,
826+
all_supported_metrics: bool = True) -> None:
838827

839828
"""Initialise information needed for the experiment"""
840829
self.experiment_task_name = 'runSearch'
@@ -847,12 +836,13 @@ def _search_settings(self, dataset: BaseDataset, disable_file_output: List,
847836
self._all_supported_metrics = all_supported_metrics
848837
self._disable_file_output = disable_file_output
849838
self._memory_limit = memory_limit
850-
self._time_for_task = total_walltime_limit
839+
self._total_walltime_limit = total_walltime_limit
840+
self._func_eval_time_limit_secs = func_eval_time_limit_secs
851841
self._metric = get_metrics(
852842
names=[optimize_metric], dataset_properties=dataset_properties)[0]
853843

854844
if self._logger is None:
855-
self._logger = self._get_logger(self.dataset_name)
845+
self._logger = self._get_logger(str(self.dataset_name))
856846

857847
# Save start time to backend
858848
self._backend.save_start_time(str(self.seed))
@@ -872,36 +862,34 @@ def _search_settings(self, dataset: BaseDataset, disable_file_output: List,
872862
else:
873863
self._is_dask_client_internally_created = False
874864

875-
def _adapt_time_resource_allocation(self,
876-
total_walltime_limit: int,
877-
func_eval_time_limit_secs: Optional[int] = None
878-
) -> int:
865+
def _adapt_time_resource_allocation(self) -> None:
866+
assert self._logger is not None
879867

880868
# Handle time resource allocation
881869
elapsed_time = self._stopwatch.wall_elapsed(self.experiment_task_name)
882-
time_left_for_modelfit = int(max(0, total_walltime_limit - elapsed_time))
883-
if func_eval_time_limit_secs is None or func_eval_time_limit_secs > time_left_for_modelfit:
870+
time_left_for_modelfit = int(max(0, self._total_walltime_limit - elapsed_time))
871+
if self._func_eval_time_limit_secs is None or self._func_eval_time_limit_secs > time_left_for_modelfit:
884872
self._logger.warning(
885873
'Time limit for a single run is higher than total time '
886874
'limit. Capping the limit for a single run to the total '
887875
'time given to SMAC (%f)' % time_left_for_modelfit
888876
)
889-
func_eval_time_limit_secs = time_left_for_modelfit
877+
self._func_eval_time_limit_secs = time_left_for_modelfit
890878

891879
# Make sure that at least 2 models are created for the ensemble process
892-
num_models = time_left_for_modelfit // func_eval_time_limit_secs
880+
num_models = time_left_for_modelfit // self._func_eval_time_limit_secs
893881
if num_models < 2:
894-
func_eval_time_limit_secs = time_left_for_modelfit // 2
882+
self._func_eval_time_limit_secs = time_left_for_modelfit // 2
895883
self._logger.warning(
896884
"Capping the func_eval_time_limit_secs to {} to have "
897885
"time for a least 2 models to ensemble.".format(
898-
func_eval_time_limit_secs
886+
self._func_eval_time_limit_secs
899887
)
900888
)
901889

902-
return func_eval_time_limit_secs
903-
904890
def _save_ensemble_performance_history(self, proc_ensemble: EnsembleBuilderManager) -> None:
891+
assert self._logger is not None
892+
905893
if len(proc_ensemble.futures) > 0:
906894
# Also add ensemble runs that did not finish within smac time
907895
# and add them into the ensemble history
@@ -920,6 +908,7 @@ def _save_ensemble_performance_history(self, proc_ensemble: EnsembleBuilderManag
920908
def _finish_experiment(self, proc_ensemble: EnsembleBuilderManager,
921909
load_models: bool) -> None:
922910

911+
assert self._logger is not None
923912
# Wait until the ensemble process is finished to avoid shutting down
924913
# while the ensemble builder tries to access the data
925914
self._logger.info("Start Shutdown")
@@ -941,23 +930,18 @@ def _finish_experiment(self, proc_ensemble: EnsembleBuilderManager,
941930
self._logger.info("Starting to clean up the logger")
942931
self._clean_logger()
943932

944-
def _search(
945-
self,
946-
optimize_metric: str,
947-
dataset: BaseDataset,
948-
budget_type: Optional[str] = None,
949-
budget: Optional[float] = None,
950-
total_walltime_limit: int = 100,
951-
func_eval_time_limit_secs: Optional[int] = None,
952-
enable_traditional_pipeline: bool = True,
953-
memory_limit: Optional[int] = 4096,
954-
smac_scenario_args: Optional[Dict[str, Any]] = None,
955-
get_smac_object_callback: Optional[Callable] = None,
956-
all_supported_metrics: bool = True,
957-
precision: int = 32,
958-
disable_file_output: List = [],
959-
load_models: bool = True,
960-
) -> 'BaseTask':
933+
def _search(self, optimize_metric: str,
934+
dataset: BaseDataset, budget_type: Optional[str] = None,
935+
budget: Optional[float] = None,
936+
total_walltime_limit: int = 100,
937+
func_eval_time_limit_secs: Optional[int] = None,
938+
enable_traditional_pipeline: bool = True,
939+
memory_limit: Optional[int] = 4096,
940+
smac_scenario_args: Optional[Dict[str, Any]] = None,
941+
get_smac_object_callback: Optional[Callable] = None,
942+
all_supported_metrics: bool = True,
943+
precision: int = 32, disable_file_output: List = [],
944+
load_models: bool = True) -> 'BaseTask':
961945
"""
962946
Search for the best pipeline configuration for the given dataset.
963947
@@ -1045,25 +1029,21 @@ def _search(
10451029
self._search_settings(dataset=dataset, disable_file_output=disable_file_output,
10461030
optimize_metric=optimize_metric, memory_limit=memory_limit,
10471031
all_supported_metrics=all_supported_metrics,
1032+
func_eval_time_limit_secs=func_eval_time_limit_secs,
10481033
total_walltime_limit=total_walltime_limit)
10491034

1050-
func_eval_time_limit_secs = self._adapt_time_resource_allocation(
1051-
total_walltime_limit=total_walltime_limit,
1052-
func_eval_time_limit_secs=func_eval_time_limit_secs
1053-
)
1054-
1035+
self._adapt_time_resource_allocation()
10551036
self.num_run = 1
10561037
self._run_dummy_predictions()
1057-
self._run_traditional_ml(enable_traditional_pipeline=enable_traditional_pipeline,
1058-
func_eval_time_limit_secs=func_eval_time_limit_secs)
1038+
1039+
if not enable_traditional_pipeline:
1040+
self._run_traditional_ml()
1041+
10591042
proc_ensemble = self._run_ensemble(dataset=dataset, precision=precision,
1060-
optimize_metric=optimize_metric,
1061-
total_walltime_limit=total_walltime_limit)
1043+
optimize_metric=optimize_metric)
10621044

10631045
self._run_smac(budget=budget, budget_type=budget_type, proc_ensemble=proc_ensemble,
1064-
dataset=dataset, total_walltime_limit=total_walltime_limit,
1065-
func_eval_time_limit_secs=func_eval_time_limit_secs,
1066-
get_smac_object_callback=get_smac_object_callback,
1046+
dataset=dataset, get_smac_object_callback=get_smac_object_callback,
10671047
smac_scenario_args=smac_scenario_args)
10681048

10691049
self._finish_experiment(proc_ensemble=proc_ensemble, load_models=load_models)

0 commit comments

Comments
 (0)