Commit ddc0f3d

Search space update (#80)
* Added hyperparameter search space updates
* Added test for search space update
* Added hyperparameter search space updates to network, trainer and improved check for search space updates
* Fix mypy, flake8
* Fix tests and silly mistake in base_pipeline
* Fix flake8
* Added _cs_updates to dummy component
* Fixed indentation and isinstance comment
* Fixed silly error
* Addressed comments from Francisco
* Added ValueError for invalid search space updates
* Add tests for setting range of config space
* Fix utils search space update
1 parent 2e7b462 commit ddc0f3d

54 files changed, +872 -323 lines
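
For orientation, a rough usage sketch of the feature this commit introduces. It is illustrative only: the append() call and the names "optimizer" and "AdamOptimizer:lr" are assumptions, not taken from this commit; only the HyperparameterSearchSpaceUpdates class, its import path and the search_space_updates argument appear in the diffs below.

    # Hedged sketch: build a set of search-space updates and hand them to a task.
    # The append() signature and the node/hyperparameter names are assumed, not
    # confirmed by this commit.
    from autoPyTorch.api.tabular_classification import TabularClassificationTask
    from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates

    updates = HyperparameterSearchSpaceUpdates()
    updates.append(node_name="optimizer",              # pipeline node to update (assumed name)
                   hyperparameter="AdamOptimizer:lr",  # hyperparameter within that node (assumed name)
                   value_range=[1e-4, 1e-2],           # restricted value range
                   default_value=1e-3)                 # new default inside that range

    # BaseTask.__init__ (diff below) type-checks this argument and raises ValueError
    # for anything that is not a HyperparameterSearchSpaceUpdates instance.
    task = TabularClassificationTask(search_space_updates=updates)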

autoPyTorch/api/base_task.py

Lines changed: 13 additions & 2 deletions
@@ -46,6 +46,7 @@
 from autoPyTorch.pipeline.components.training.metrics.utils import calculate_score, get_metrics
 from autoPyTorch.utils.backend import Backend, create
 from autoPyTorch.utils.common import FitRequirement, replace_string_bool_to_bool
+from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
 from autoPyTorch.utils.logging_ import (
     PicklableClientLogger,
     get_named_client_logger,
@@ -135,6 +136,7 @@ def __init__(
         include_components: Optional[Dict] = None,
         exclude_components: Optional[Dict] = None,
         backend: Optional[Backend] = None,
+        search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
     ) -> None:
         self.seed = seed
         self.n_jobs = n_jobs
@@ -178,6 +180,13 @@ def __init__(
 
         self.stop_logging_server = None  # type: Optional[multiprocessing.synchronize.Event]
 
+        self.search_space_updates = search_space_updates
+        if search_space_updates is not None:
+            if not isinstance(self.search_space_updates,
+                              HyperparameterSearchSpaceUpdates):
+                raise ValueError("Expected search space updates to be of instance"
+                                 " HyperparameterSearchSpaceUpdates got {}".format(type(self.search_space_updates)))
+
     @abstractmethod
     def _get_required_dataset_properties(self, dataset: BaseDataset) -> Dict[str, Any]:
         """
@@ -252,7 +261,8 @@ def get_search_space(self, dataset: BaseDataset = None) -> ConfigurationSpace:
                 info=self._get_required_dataset_properties(dataset))
             return get_configuration_space(info=dataset.get_dataset_properties(dataset_requirements),
                                            include=self.include_components,
-                                           exclude=self.exclude_components)
+                                           exclude=self.exclude_components,
+                                           search_space_updates=self.search_space_updates)
         raise Exception("No search space initialised and no dataset passed. "
                         "Can't create default search space without the dataset")
 
@@ -816,7 +826,8 @@ def search(
             pipeline_config={**self.pipeline_options, **budget_config},
             ensemble_callback=proc_ensemble,
             logger_port=self._logger_port,
-            start_num_run=num_run
+            start_num_run=num_run,
+            search_space_updates=self.search_space_updates
         )
         try:
             self.run_history, self.trajectory, budget_type = \
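
The get_search_space() change above threads the updates into get_configuration_space(), so they take effect when the ConfigurationSpace is assembled. A minimal ConfigSpace sketch of the underlying idea (apply_update is a hypothetical helper, not autoPyTorch's implementation):

    # Rebuild one hyperparameter with a narrower range and a new default before the
    # space is handed to the optimizer. Purely illustrative; autoPyTorch applies the
    # updates inside get_configuration_space rather than with a helper like this.
    from ConfigSpace.configuration_space import ConfigurationSpace
    from ConfigSpace.hyperparameters import UniformFloatHyperparameter

    def apply_update(cs, name, value_range, default_value):
        new_cs = ConfigurationSpace()
        for hp in cs.get_hyperparameters():
            if hp.name == name:
                hp = UniformFloatHyperparameter(name, lower=value_range[0], upper=value_range[1],
                                                default_value=default_value, log=hp.log)
            new_cs.add_hyperparameter(hp)
        return new_cs

    cs = ConfigurationSpace()
    cs.add_hyperparameter(UniformFloatHyperparameter("lr", lower=1e-5, upper=1e-1,
                                                     default_value=1e-2, log=True))
    print(apply_update(cs, "lr", value_range=(1e-4, 1e-2), default_value=1e-3))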

autoPyTorch/api/tabular_classification.py

Lines changed: 3 additions & 0 deletions
@@ -9,6 +9,7 @@
 from autoPyTorch.datasets.tabular_dataset import TabularDataset
 from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
 from autoPyTorch.utils.backend import Backend
+from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
 
 
 class TabularClassificationTask(BaseTask):
@@ -52,6 +53,7 @@ def __init__(
         include_components: Optional[Dict] = None,
         exclude_components: Optional[Dict] = None,
         backend: Optional[Backend] = None,
+        search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
     ):
         super().__init__(
             seed=seed,
@@ -67,6 +69,7 @@ def __init__(
             include_components=include_components,
             exclude_components=exclude_components,
             backend=backend,
+            search_space_updates=search_space_updates
         )
         self.task_type = TASK_TYPES_TO_STRING[TABULAR_CLASSIFICATION]
 

autoPyTorch/evaluation/abstract_evaluator.py

Lines changed: 8 additions & 2 deletions
@@ -42,6 +42,7 @@
     get_metrics,
 )
 from autoPyTorch.utils.backend import Backend
+from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
 from autoPyTorch.utils.logging_ import PicklableClientLogger, get_named_client_logger
 from autoPyTorch.utils.pipeline import get_dataset_requirements
 
@@ -200,7 +201,9 @@ def __init__(self, backend: Backend,
                  disable_file_output: Union[bool, List[str]] = False,
                  init_params: Optional[Dict[str, Any]] = None,
                  logger_port: Optional[int] = None,
-                 all_supported_metrics: bool = True) -> None:
+                 all_supported_metrics: bool = True,
+                 search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
+                 ) -> None:
 
         self.starttime = time.time()
 
@@ -218,6 +221,7 @@ def __init__(self, backend: Backend,
 
         self.include = include
         self.exclude = exclude
+        self.search_space_updates = search_space_updates
 
         self.X_train, self.y_train = self.datamanager.train_tensors
 
@@ -324,6 +328,7 @@ def __init__(self, backend: Backend,
         self.pipelines: Optional[List[BaseEstimator]] = None
         self.pipeline: Optional[BaseEstimator] = None
         self.logger.debug("Fit dictionary in Abstract evaluator: {}".format(self.fit_dictionary))
+        self.logger.debug("Search space updates :{}".format(self.search_space_updates))
 
     def _get_pipeline(self) -> BaseEstimator:
         assert self.pipeline_class is not None, "Can't return pipeline, pipeline_class not initialised"
@@ -337,7 +342,8 @@ def _get_pipeline(self) -> BaseEstimator:
                                            random_state=np.random.RandomState(self.seed),
                                            include=self.include,
                                            exclude=self.exclude,
-                                           init_params=self._init_params)
+                                           init_params=self._init_params,
+                                           search_space_updates=self.search_space_updates)
         elif isinstance(self.configuration, str):
             pipeline = self.pipeline_class(config=self.configuration,
                                            dataset_properties=self.dataset_properties,

autoPyTorch/evaluation/tae.py

Lines changed: 7 additions & 1 deletion
@@ -25,6 +25,7 @@
 from autoPyTorch.evaluation.utils import empty_queue, extract_learning_curve, read_queue
 from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
 from autoPyTorch.utils.backend import Backend
+from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
 from autoPyTorch.utils.logging_ import PicklableClientLogger, get_named_client_logger
 
 
@@ -111,6 +112,7 @@ def __init__(
         ta: typing.Optional[typing.Callable] = None,
         logger_port: int = None,
         all_supported_metrics: bool = True,
+        search_space_updates: typing.Optional[HyperparameterSearchSpaceUpdates] = None
     ):
 
         eval_function = autoPyTorch.evaluation.train_evaluator.eval_function
@@ -164,6 +166,8 @@ def __init__(
         self.resampling_strategy = dm.resampling_strategy
         self.resampling_strategy_args = dm.resampling_strategy_args
 
+        self.search_space_updates = search_space_updates
+
     def run_wrapper(
         self,
         run_info: RunInfo,
@@ -250,6 +254,7 @@ def run(
         else:
            num_run = config.config_id + self.initial_num_run
 
+        self.logger.debug("Search space updates: {}".format(self.search_space_updates))
         obj_kwargs = dict(
             queue=queue,
             config=config,
@@ -267,7 +272,8 @@ def run(
             budget_type=self.budget_type,
             pipeline_config=self.pipeline_config,
             logger_port=self.logger_port,
-            all_supported_metrics=self.all_supported_metrics
+            all_supported_metrics=self.all_supported_metrics,
+            search_space_updates=self.search_space_updates
         )
 
         info: typing.Optional[typing.List[RunValue]]

autoPyTorch/evaluation/train_evaluator.py

Lines changed: 9 additions & 3 deletions
@@ -20,6 +20,7 @@
 from autoPyTorch.evaluation.utils import subsampler
 from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
 from autoPyTorch.utils.backend import Backend
+from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
 
 __all__ = ['TrainEvaluator', 'eval_function']
 
@@ -48,7 +49,8 @@ def __init__(self, backend: Backend, queue: Queue,
                  init_params: Optional[Dict[str, Any]] = None,
                  logger_port: Optional[int] = None,
                  keep_models: Optional[bool] = None,
-                 all_supported_metrics: bool = True) -> None:
+                 all_supported_metrics: bool = True,
+                 search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None) -> None:
         super().__init__(
             backend=backend,
             queue=queue,
@@ -65,7 +67,8 @@ def __init__(self, backend: Backend, queue: Queue,
             budget_type=budget_type,
             logger_port=logger_port,
             all_supported_metrics=all_supported_metrics,
-            pipeline_config=pipeline_config
+            pipeline_config=pipeline_config,
+            search_space_updates=search_space_updates
         )
 
         self.splits = self.datamanager.splits
@@ -77,6 +80,7 @@ def __init__(self, backend: Backend, queue: Queue,
         self.pipelines: List[Optional[BaseEstimator]] = [None] * self.num_folds
         self.indices: List[Optional[Tuple[Union[np.ndarray, List], Union[np.ndarray, List]]]] = [None] * self.num_folds
 
+        self.logger.debug("Search space updates :{}".format(self.search_space_updates))
         self.keep_models = keep_models
 
     def fit_predict_and_loss(self) -> None:
@@ -320,6 +324,7 @@ def eval_function(
     init_params: Optional[Dict[str, Any]] = None,
     logger_port: Optional[int] = None,
     all_supported_metrics: bool = True,
+    search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None,
    instance: str = None,
 ) -> None:
     evaluator = TrainEvaluator(
@@ -338,6 +343,7 @@ def eval_function(
         budget_type=budget_type,
         logger_port=logger_port,
         all_supported_metrics=all_supported_metrics,
-        pipeline_config=pipeline_config
+        pipeline_config=pipeline_config,
+        search_space_updates=search_space_updates
     )
     evaluator.fit_predict_and_loss()

autoPyTorch/optimizer/smbo.py

Lines changed: 8 additions & 6 deletions
@@ -16,17 +16,17 @@
 from smac.tae.serial_runner import SerialRunner
 from smac.utils.io.traj_logging import TrajEntry
 
-# TODO: Enable when merged Ensemble
-# from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
 from autoPyTorch.datasets.base_dataset import BaseDataset
 from autoPyTorch.datasets.resampling_strategy import (
     CrossValTypes,
     DEFAULT_RESAMPLING_PARAMETERS,
     HoldoutValTypes,
 )
+from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
 from autoPyTorch.evaluation.tae import ExecuteTaFuncWithQueue, get_cost_of_crash
 from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
 from autoPyTorch.utils.backend import Backend
+from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
 from autoPyTorch.utils.logging_ import get_named_client_logger
 from autoPyTorch.utils.stopwatch import StopWatch
 
@@ -101,10 +101,9 @@ def __init__(self,
                 smac_scenario_args: typing.Optional[typing.Dict[str, typing.Any]] = None,
                 get_smac_object_callback: typing.Optional[typing.Callable] = None,
                 all_supported_metrics: bool = True,
-                # TODO: Re-enable when ensemble merged
-                # ensemble_callback: typing.Optional[EnsembleBuilderManager] = None,
-                ensemble_callback: typing.Any = None,
+                ensemble_callback: typing.Optional[EnsembleBuilderManager] = None,
                 logger_port: typing.Optional[int] = None,
+                search_space_updates: typing.Optional[HyperparameterSearchSpaceUpdates] = None
                 ):
        """
        Interface to SMAC. This method calls the SMAC optimize method, and allows
@@ -194,6 +193,8 @@ def __init__(self,
 
         self.ensemble_callback = ensemble_callback
 
+        self.search_space_updates = search_space_updates
+
         dataset_name_ = "" if dataset_name is None else dataset_name
         if logger_port is None:
             self.logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT
@@ -254,7 +255,8 @@ def run_smbo(self, func: typing.Optional[typing.Callable] = None
             ta=func,
             logger_port=self.logger_port,
             all_supported_metrics=self.all_supported_metrics,
-            pipeline_config=self.pipeline_config
+            pipeline_config=self.pipeline_config,
+            search_space_updates=self.search_space_updates
         )
         ta = ExecuteTaFuncWithQueue
         self.logger.info("Created TA")
