Skip to content

Commit a717d60

Browse files
Comment fixes
1 parent 0606d34 commit a717d60

File tree

6 files changed

+58
-44
lines changed

6 files changed

+58
-44
lines changed

autoPyTorch/api/base_task.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import time
1010
import typing
1111
import unittest.mock
12+
import uuid
1213
import warnings
1314
from abc import abstractmethod
1415
from typing import Any, Callable, Dict, List, Optional, Union, cast
@@ -707,7 +708,8 @@ def _search(
707708
dataset_properties = dataset.get_dataset_properties(dataset_requirements)
708709
self._stopwatch.start_task(experiment_task_name)
709710
self.dataset_name = dataset.dataset_name
710-
self._logger = self._get_logger(self.dataset_name)
711+
if self._logger is None:
712+
self._logger = self._get_logger(self.dataset_name)
711713
self._all_supported_metrics = all_supported_metrics
712714
self._disable_file_output = disable_file_output
713715
self._memory_limit = memory_limit
@@ -907,8 +909,11 @@ def refit(
907909
Returns:
908910
self
909911
"""
912+
if self.dataset_name is None:
913+
self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
910914

911-
self._logger = self._get_logger(dataset.dataset_name)
915+
if self._logger is None:
916+
self._logger = self._get_logger(self.dataset_name)
912917

913918
dataset_requirements = get_dataset_requirements(
914919
info=self._get_required_dataset_properties(dataset))
@@ -974,7 +979,11 @@ def fit(self,
974979
Returns:
975980
(BasePipeline): fitted pipeline
976981
"""
977-
self._logger = self._get_logger(dataset.dataset_name)
982+
if self.dataset_name is None:
983+
self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
984+
985+
if self._logger is None:
986+
self._logger = self._get_logger(self.dataset_name)
978987

979988
# get dataset properties
980989
dataset_requirements = get_dataset_requirements(

autoPyTorch/api/tabular_classification.py

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
import uuid
13
from typing import Any, Callable, Dict, List, Optional, Union
24

35
import numpy as np
@@ -86,10 +88,6 @@ def __init__(
8688
task_type=TASK_TYPES_TO_STRING[TABULAR_CLASSIFICATION],
8789
)
8890

89-
# Create a validator object to make sure that the data provided by
90-
# the user matches the autopytorch requirements
91-
self.InputValidator = TabularInputValidator(is_classification=True)
92-
9391
def _get_required_dataset_properties(self, dataset: BaseDataset) -> Dict[str, Any]:
9492
if not isinstance(dataset, TabularDataset):
9593
raise ValueError("Dataset is incompatible for the given task,: {}".format(
@@ -105,24 +103,25 @@ def build_pipeline(self, dataset_properties: Dict[str, Any]) -> TabularClassific
105103
return TabularClassificationPipeline(dataset_properties=dataset_properties)
106104

107105
def search(
108-
self,
109-
optimize_metric: str,
110-
X_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
111-
y_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
112-
X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
113-
y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
114-
budget_type: Optional[str] = None,
115-
budget: Optional[float] = None,
116-
total_walltime_limit: int = 100,
117-
func_eval_time_limit: int = 60,
118-
traditional_per_total_budget: float = 0.1,
119-
memory_limit: Optional[int] = 4096,
120-
smac_scenario_args: Optional[Dict[str, Any]] = None,
121-
get_smac_object_callback: Optional[Callable] = None,
122-
all_supported_metrics: bool = True,
123-
precision: int = 32,
124-
disable_file_output: List = [],
125-
load_models: bool = True,
106+
self,
107+
optimize_metric: str,
108+
X_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
109+
y_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
110+
X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
111+
y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
112+
dataset_name: Optional[str] = None,
113+
budget_type: Optional[str] = None,
114+
budget: Optional[float] = None,
115+
total_walltime_limit: int = 100,
116+
func_eval_time_limit: int = 60,
117+
traditional_per_total_budget: float = 0.1,
118+
memory_limit: Optional[int] = 4096,
119+
smac_scenario_args: Optional[Dict[str, Any]] = None,
120+
get_smac_object_callback: Optional[Callable] = None,
121+
all_supported_metrics: bool = True,
122+
precision: int = 32,
123+
disable_file_output: List = [],
124+
load_models: bool = True,
126125
) -> 'BaseTask':
127126
"""
128127
Search for the best pipeline configuration for the given dataset.
@@ -133,9 +132,8 @@ def search(
133132
Args:
134133
X_train, y_train, X_test, y_test: Union[np.ndarray, List, pd.DataFrame]
135134
A pair of features (X_train) and targets (y_train) used to fit a
136-
pipeline. Additionally, a holdout of this paris (X_test, y_test) can
135+
pipeline. Additionally, a holdout of these pairs (X_test, y_test) can
137136
be provided to track the generalization performance of each stage.
138-
Providing X_train, y_train and dataset together is not supported.
139137
optimize_metric (str): name of the metric that is used to
140138
evaluate a pipeline.
141139
budget_type (Optional[str]):
@@ -189,6 +187,18 @@ def search(
189187
self
190188
191189
"""
190+
if dataset_name is None:
191+
dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
192+
193+
# We have to create a logger at this point for the validator
194+
self._logger = self._get_logger(dataset_name)
195+
196+
# Create a validator object to make sure that the data provided by
197+
# the user matches the autopytorch requirements
198+
self.InputValidator = TabularInputValidator(
199+
is_classification=True,
200+
logger_port=self._logger_port,
201+
)
192202

193203
# Fit an input validator to check the provided data
194204
# Also, an encoder is fit to both train and test data,
@@ -227,7 +237,7 @@ def predict(
227237
n_jobs: int = 1
228238
) -> np.ndarray:
229239
if self.InputValidator is None or not self.InputValidator._is_fitted:
230-
raise ValueError("predict() is only supported after calling fit. Kindly call first "
240+
raise ValueError("predict() is only supported after calling search. Kindly call first "
231241
"the estimator fit() method.")
232242

233243
X_test = self.InputValidator.feature_validator.transform(X_test)
@@ -247,7 +257,7 @@ def predict_proba(self,
247257
X_test: Union[np.ndarray, pd.DataFrame, List],
248258
batch_size: Optional[int] = None, n_jobs: int = 1) -> np.ndarray:
249259
if self.InputValidator is None or not self.InputValidator._is_fitted:
250-
raise ValueError("predict() is only supported after calling fit. Kindly call first "
260+
raise ValueError("predict() is only supported after calling search. Kindly call first "
251261
"the estimator fit() method.")
252262
X_test = self.InputValidator.feature_validator.transform(X_test)
253263
return super().predict(X_test, batch_size=batch_size, n_jobs=n_jobs)

autoPyTorch/data/base_feature_validator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class BaseFeatureValidator(BaseEstimator):
3939
Host an encoder object if the data requires transformation (for example,
4040
if provided a categorical column in a pandas DataFrame)
4141
enc_columns (typing.List[str])
42-
List of columns that where encoded.
42+
List of columns that were encoded.
4343
"""
4444
def __init__(self,
4545
logger: typing.Optional[typing.Union[PicklableClientLogger, logging.Logger

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def _fit(
7272

7373
# The column transformer reorders the feature types - we therefore need to change
7474
# it as well
75+
# This means columns are shifted to the right
7576
def comparator(cmp1: str, cmp2: str) -> int:
7677
if (
7778
cmp1 == 'categorical' and cmp2 == 'categorical'

autoPyTorch/datasets/tabular_dataset.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,16 @@ def __init__(self,
8181
# dataset.
8282
# TODO: Consider moving the validator to the pipeline itself when we
8383
# move to using the fit_params on scikit learn 0.24
84-
self.validator = validator
85-
if self.validator is None:
84+
if validator is None:
8685
raise ValueError("A feature validator is required to build a tabular pipeline")
8786

88-
X, Y = self.validator.transform(X, Y)
87+
X, Y = validator.transform(X, Y)
8988
if X_test is not None:
90-
X_test, Y_test = self.validator.transform(X_test, Y_test)
91-
self.categorical_columns = self.validator.feature_validator.categorical_columns
92-
self.numerical_columns = self.validator.feature_validator.numerical_columns
93-
self.num_features = self.validator.feature_validator.num_features
94-
self.categories = self.validator.feature_validator.categories
89+
X_test, Y_test = validator.transform(X_test, Y_test)
90+
self.categorical_columns = validator.feature_validator.categorical_columns
91+
self.numerical_columns = validator.feature_validator.numerical_columns
92+
self.num_features = validator.feature_validator.num_features
93+
self.categories = validator.feature_validator.categories
9594

9695
super().__init__(train_tensors=(X, Y), test_tensors=(X_test, Y_test), shuffle=shuffle,
9796
resampling_strategy=resampling_strategy,
@@ -122,8 +121,3 @@ def get_required_dataset_info(self) -> Dict[str, Any]:
122121
'task_type': self.task_type
123122
})
124123
return info
125-
126-
def __getstate__(self) -> Dict[str, Any]:
127-
# Make pickable!
128-
self.validator = None
129-
return self.__dict__

autoPyTorch/pipeline/base_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def fit(self, X: Dict[str, Any], y: Optional[np.ndarray] = None,
111111
"""Fit the selected algorithm to the training data.
112112
Arguments:
113113
X (typing.Dict):
114-
A fit dictionary so that contains information to fit a pipeline
114+
A fit dictionary that contains information to fit a pipeline
115115
TODO: Use fit_params support from 0.24 scikit learn version instead
116116
y (None):
117117
Used for compatibility, but it has no function in our fit strategy

0 commit comments

Comments
 (0)