Skip to content

Commit b337969

Browse files
committed
rebase and fix flake
1 parent b73c992 commit b337969

File tree

5 files changed

+93
-20
lines changed

5 files changed

+93
-20
lines changed

autoPyTorch/api/base_task.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,14 @@ def _get_dataset_input_validator(
299299
y_train: Union[List, pd.DataFrame, np.ndarray],
300300
X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
301301
y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
302+
<<<<<<< HEAD
302303
resampling_strategy: Optional[ResamplingStrategies] = None,
304+
=======
305+
resampling_strategy: Optional[Union[
306+
CrossValTypes,
307+
HoldoutValTypes,
308+
NoResamplingStrategyTypes]] = None,
309+
>>>>>>> rebase and fix flake
303310
resampling_strategy_args: Optional[Dict[str, Any]] = None,
304311
dataset_name: Optional[str] = None,
305312
) -> Tuple[BaseDataset, BaseInputValidator]:
@@ -341,7 +348,14 @@ def get_dataset(
341348
y_train: Union[List, pd.DataFrame, np.ndarray],
342349
X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
343350
y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
351+
<<<<<<< HEAD
344352
resampling_strategy: Optional[ResamplingStrategies] = None,
353+
=======
354+
resampling_strategy: Optional[Union[
355+
CrossValTypes,
356+
HoldoutValTypes,
357+
NoResamplingStrategyTypes]] = None,
358+
>>>>>>> rebase and fix flake
345359
resampling_strategy_args: Optional[Dict[str, Any]] = None,
346360
dataset_name: Optional[str] = None,
347361
) -> BaseDataset:
@@ -1391,7 +1405,14 @@ def fit_pipeline(
13911405
X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
13921406
y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
13931407
dataset_name: Optional[str] = None,
1408+
<<<<<<< HEAD
13941409
resampling_strategy: Optional[Union[HoldoutValTypes, CrossValTypes, NoResamplingStrategyTypes]] = None,
1410+
=======
1411+
resampling_strategy: Optional[Union[
1412+
CrossValTypes,
1413+
HoldoutValTypes,
1414+
NoResamplingStrategyTypes]] = None,
1415+
>>>>>>> rebase and fix flake
13951416
resampling_strategy_args: Optional[Dict[str, Any]] = None,
13961417
run_time_limit_secs: int = 60,
13971418
memory_limit: Optional[int] = None,
@@ -1515,7 +1536,6 @@ def fit_pipeline(
15151536
(BaseDataset):
15161537
Dataset created from the given tensors
15171538
"""
1518-
self.dataset_name = dataset.dataset_name
15191539

15201540
if dataset is None:
15211541
if (

autoPyTorch/api/tabular_classification.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -422,9 +422,6 @@ def search(
422422
dataset_name=dataset_name,
423423
dataset_compression=self._dataset_compression)
424424

425-
if self.dataset is None:
426-
raise ValueError("`dataset` in {} must be initialized, but got None".format(self.__class__.__name__))
427-
428425
return self._search(
429426
dataset=self.dataset,
430427
optimize_metric=optimize_metric,

autoPyTorch/api/tabular_regression.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -423,9 +423,6 @@ def search(
423423
dataset_name=dataset_name,
424424
dataset_compression=self._dataset_compression)
425425

426-
if self.dataset is None:
427-
raise ValueError("`dataset` in {} must be initialized, but got None".format(self.__class__.__name__))
428-
429426
return self._search(
430427
dataset=self.dataset,
431428
optimize_metric=optimize_metric,

autoPyTorch/evaluation/fit_evaluator.py

Lines changed: 71 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
AbstractEvaluator,
1717
fit_and_suppress_warnings
1818
)
19+
from autoPyTorch.evaluation.utils import DisableFileOutputParameters
1920
from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
2021
from autoPyTorch.utils.common import subsampler
2122
from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
@@ -33,7 +34,7 @@ def __init__(self, backend: Backend, queue: Queue,
3334
num_run: Optional[int] = None,
3435
include: Optional[Dict[str, Any]] = None,
3536
exclude: Optional[Dict[str, Any]] = None,
36-
disable_file_output: Union[bool, List] = False,
37+
disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
3738
init_params: Optional[Dict[str, Any]] = None,
3839
logger_port: Optional[int] = None,
3940
keep_models: Optional[bool] = None,
@@ -241,14 +242,11 @@ def file_output(
241242
)
242243

243244
# Abort if we don't want to output anything.
244-
if hasattr(self, 'disable_file_output'):
245-
if self.disable_file_output:
246-
return None, {}
247-
else:
248-
self.disabled_file_outputs = []
245+
if 'all' in self.disable_file_output:
246+
return None, {}
249247

250-
if hasattr(self, 'pipeline') and self.pipeline is not None:
251-
if 'pipeline' not in self.disabled_file_outputs:
248+
if getattr(self, 'pipeline', None) is not None:
249+
if 'pipeline' not in self.disable_file_output:
252250
pipeline = self.pipeline
253251
else:
254252
pipeline = None
@@ -265,11 +263,11 @@ def file_output(
265263
ensemble_predictions=None,
266264
valid_predictions=(
267265
Y_valid_pred if 'y_valid' not in
268-
self.disabled_file_outputs else None
266+
self.disable_file_output else None
269267
),
270268
test_predictions=(
271269
Y_test_pred if 'y_test' not in
272-
self.disabled_file_outputs else None
270+
self.disable_file_output else None
273271
),
274272
)
275273

@@ -287,8 +285,8 @@ def eval_function(
287285
num_run: int,
288286
include: Optional[Dict[str, Any]],
289287
exclude: Optional[Dict[str, Any]],
290-
disable_file_output: Union[bool, List],
291288
output_y_hat_optimization: bool = False,
289+
disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
292290
pipeline_config: Optional[Dict[str, Any]] = None,
293291
budget_type: str = None,
294292
init_params: Optional[Dict[str, Any]] = None,
@@ -297,14 +295,75 @@ def eval_function(
297295
search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None,
298296
instance: str = None,
299297
) -> None:
298+
"""
299+
This closure allows the communication between the ExecuteTaFuncWithQueue and the
300+
pipeline trainer (TrainEvaluator).
301+
302+
Fundamentally, smac calls the ExecuteTaFuncWithQueue.run() method, which internally
303+
builds a TrainEvaluator. The TrainEvaluator builds a pipeline, stores the output files
304+
to disc via the backend, and puts the performance result of the run in the queue.
305+
306+
307+
Attributes:
308+
backend (Backend):
309+
An object to interface with the disk storage. In particular, allows to
310+
access the train and test datasets
311+
queue (Queue):
312+
Each worker available will instantiate an evaluator, and after completion,
313+
it will return the evaluation result via a multiprocessing queue
314+
metric (autoPyTorchMetric):
315+
A scorer object that is able to evaluate how good a pipeline was fit. It
316+
is a wrapper on top of the actual score method (a wrapper on top of scikit
317+
lean accuracy for example) that formats the predictions accordingly.
318+
budget: (float):
319+
The amount of epochs/time a configuration is allowed to run.
320+
budget_type (str):
321+
The budget type, which can be epochs or time
322+
pipeline_config (Optional[Dict[str, Any]]):
323+
Defines the content of the pipeline being evaluated. For example, it
324+
contains pipeline specific settings like logging name, or whether or not
325+
to use tensorboard.
326+
config (Union[int, str, Configuration]):
327+
Determines the pipeline to be constructed.
328+
seed (int):
329+
A integer that allows for reproducibility of results
330+
output_y_hat_optimization (bool):
331+
Whether this worker should output the target predictions, so that they are
332+
stored on disk. Fundamentally, the resampling strategy might shuffle the
333+
Y_train targets, so we store the split in order to re-use them for ensemble
334+
selection.
335+
num_run (Optional[int]):
336+
An identifier of the current configuration being fit. This number is unique per
337+
configuration.
338+
include (Optional[Dict[str, Any]]):
339+
An optional dictionary to include components of the pipeline steps.
340+
exclude (Optional[Dict[str, Any]]):
341+
An optional dictionary to exclude components of the pipeline steps.
342+
disable_file_output (Union[bool, List[str]]):
343+
By default, the model, it's predictions and other metadata is stored on disk
344+
for each finished configuration. This argument allows the user to skip
345+
saving certain file type, for example the model, from being written to disk.
346+
init_params (Optional[Dict[str, Any]]):
347+
Optional argument that is passed to each pipeline step. It is the equivalent of
348+
kwargs for the pipeline steps.
349+
logger_port (Optional[int]):
350+
Logging is performed using a socket-server scheme to be robust against many
351+
parallel entities that want to write to the same file. This integer states the
352+
socket port for the communication channel. If None is provided, a traditional
353+
logger is used.
354+
instance (str):
355+
An instance on which to evaluate the current pipeline. By default we work
356+
with a single instance, being the provided X_train, y_train of a single dataset.
357+
This instance is a compatibility argument for SMAC, that is capable of working
358+
with multiple datasets at the same time.
359+
"""
300360
evaluator = FitEvaluator(
301361
backend=backend,
302362
queue=queue,
303363
metric=metric,
304364
configuration=config,
305365
seed=seed,
306366
num_run=num_run,
307-
output_y_hat_optimization=output_y_hat_optimization,
308367
include=include,
309368
exclude=exclude,
310369
disable_file_output=disable_file_output,

autoPyTorch/evaluation/train_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,10 +424,10 @@ def eval_train_function(
424424
budget: float,
425425
config: Optional[Configuration],
426426
seed: int,
427-
output_y_hat_optimization: bool,
428427
num_run: int,
429428
include: Optional[Dict[str, Any]],
430429
exclude: Optional[Dict[str, Any]],
430+
output_y_hat_optimization: bool,
431431
disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
432432
pipeline_config: Optional[Dict[str, Any]] = None,
433433
budget_type: str = None,

0 commit comments

Comments
 (0)