44
55
66import numpy as np
7+ import scipy .sparse
78import torch
89import torch .nn as nn
910import copy
2122from autoPyTorch .utils .config .config_file_parser import ConfigFileParser
2223
2324class AutoNet ():
25+ """Find an optimal neural network given a ML-task using BOHB"""
2426 preset_folder_name = None
2527
2628 def __init__ (self , config_preset = "medium_cs" , pipeline = None , ** autonet_config ):
@@ -34,6 +36,7 @@ def __init__(self, config_preset="medium_cs", pipeline=None, **autonet_config):
3436 self .base_config = autonet_config
3537 self .autonet_config = None
3638 self .fit_result = None
39+ self .dataset_info = None
3740
3841 if config_preset is not None :
3942 parser = self .get_autonet_config_file_parser ()
@@ -70,10 +73,11 @@ def get_current_autonet_config(self):
7073 return self .pipeline .get_pipeline_config (** self .base_config )
7174
7275 def get_hyperparameter_search_space (self , X_train = None , Y_train = None , X_valid = None , Y_valid = None , ** autonet_config ):
73- """Return hyperparameter search space of Auto-PyTorch. Does depend on the dataset and the configuration.!
76+ """Return hyperparameter search space of Auto-PyTorch. Does depend on the dataset and the configuration!
77+ You can either pass the dataset and the configuration or use dataset and configuration of last fit call.
7478
7579 Keyword Arguments:
76- X_train {array} -- Training data.
80+ X_train {array} -- Training data. ConfigSpace depends on Training data.
7781 Y_train {array} -- Targets of training data.
7882 X_valid {array} -- Validation data. Will be ignored if cv_splits > 1. (default: {None})
7983 Y_valid {array} -- Validation data. Will be ignored if cv_splits > 1. (default: {None})
@@ -82,8 +86,8 @@ def get_hyperparameter_search_space(self, X_train=None, Y_train=None, X_valid=No
8286 Returns:
8387 ConfigurationSpace -- The configuration space that should be optimized.
8488 """
85-
86- dataset_info = None
89+ X_train , Y_train , X_valid , Y_valid = self . check_data_array_types ( X_train , Y_train , X_valid , Y_valid )
90+ dataset_info = self . dataset_info
8791 pipeline_config = dict (self .base_config , ** autonet_config ) if autonet_config else \
8892 self .get_current_autonet_config ()
8993 if X_train is not None and Y_train is not None :
@@ -129,21 +133,22 @@ def fit(self, X_train, Y_train, X_valid=None, Y_valid=None, refit=True, **autone
129133
130134 Returns:
131135 optimized_hyperparameter_config -- The best found hyperparameter config.
132- final_metric_score -- The final score of the specified train metric.
133136 **autonet_config -- Configure AutoNet for your needs. You can also configure AutoNet in the constructor(). Call print_help() for more info.
134137 """
138+ X_train , Y_train , X_valid , Y_valid = self .check_data_array_types (X_train , Y_train , X_valid , Y_valid )
135139 self .autonet_config = self .pipeline .get_pipeline_config (** dict (self .base_config , ** autonet_config ))
136140
137141 self .fit_result = self .pipeline .fit_pipeline (pipeline_config = self .autonet_config ,
138142 X_train = X_train , Y_train = Y_train , X_valid = X_valid , Y_valid = Y_valid )
143+ self .dataset_info = self .pipeline [CreateDatasetInfo .get_name ()].fit_output ["dataset_info" ]
139144 self .pipeline .clean ()
140145
141146 if not self .fit_result ["optimized_hyperparameter_config" ]:
142147 raise RuntimeError ("No models fit during training, please retry with a larger max_runtime." )
143148
144149 if (refit ):
145150 self .refit (X_train , Y_train , X_valid , Y_valid )
146- return self .fit_result [ "optimized_hyperparameter_config" ], self . fit_result [ 'final_metric_score' ]
151+ return self .fit_result
147152
148153 def refit (self , X_train , Y_train , X_valid = None , Y_valid = None , hyperparameter_config = None , autonet_config = None , budget = None , rescore = False ):
149154 """Refit AutoNet to given hyperparameters. This will skip hyperparameter search.
@@ -163,6 +168,7 @@ def refit(self, X_train, Y_train, X_valid=None, Y_valid=None, hyperparameter_con
163168 Raises:
164169 ValueError -- No hyperparameter config available
165170 """
171+ X_train , Y_train , X_valid , Y_valid = self .check_data_array_types (X_train , Y_train , X_valid , Y_valid )
166172 if (autonet_config is None ):
167173 autonet_config = self .autonet_config
168174 if (autonet_config is None ):
@@ -182,9 +188,8 @@ def refit(self, X_train, Y_train, X_valid=None, Y_valid=None, hyperparameter_con
182188 'budget' : budget ,
183189 'rescore' : rescore }
184190
185- result = self .pipeline .fit_pipeline (pipeline_config = autonet_config , refit = refit_data ,
186- X_train = X_train , Y_train = Y_train , X_valid = X_valid , Y_valid = Y_valid )
187- return result ["final_metric_score" ]
191+ return self .pipeline .fit_pipeline (pipeline_config = autonet_config , refit = refit_data ,
192+ X_train = X_train , Y_train = Y_train , X_valid = X_valid , Y_valid = Y_valid )
188193
189194 def predict (self , X , return_probabilities = False ):
190195 """Predict the targets for a data matrix X.
@@ -200,6 +205,7 @@ def predict(self, X, return_probabilities=False):
200205 """
201206
202207 # run predict pipeline
208+ X , = self .check_data_array_types (X )
203209 autonet_config = self .autonet_config or self .base_config
204210 Y_pred = self .pipeline .predict_pipeline (pipeline_config = autonet_config , X = X )['Y' ]
205211
@@ -208,8 +214,8 @@ def predict(self, X, return_probabilities=False):
208214 result = OHE .reverse_transform_y (Y_pred , OHE .fit_output ['y_one_hot_encoder' ])
209215 return result if not return_probabilities else (result , Y_pred )
210216
211- def score (self , X_test , Y_test ):
212- """Calculate the sore on test data using the specified train_metric
217+ def score (self , X_test , Y_test , return_loss_value = False ):
218+ """Calculate the score on test data using the specified optimize_metric
213219
214220 Arguments:
215221 X_test {array} -- The test data matrix.
@@ -220,6 +226,7 @@ def score(self, X_test, Y_test):
220226 """
221227
222228 # run predict pipeline
229+ X_test , Y_test = self .check_data_array_types (X_test , Y_test )
223230 autonet_config = self .autonet_config or self .base_config
224231 self .pipeline .predict_pipeline (pipeline_config = autonet_config , X = X_test )
225232 Y_pred = self .pipeline [OptimizationAlgorithm .get_name ()].predict_output ['Y' ]
@@ -228,5 +235,19 @@ def score(self, X_test, Y_test):
228235 OHE = self .pipeline [OneHotEncoding .get_name ()]
229236 Y_test = OHE .transform_y (Y_test , OHE .fit_output ['y_one_hot_encoder' ])
230237
231- metric = self .pipeline [MetricSelector .get_name ()].fit_output ['train_metric' ]
238+ metric = self .pipeline [MetricSelector .get_name ()].fit_output ['optimize_metric' ]
239+ if return_loss_value :
240+ return metric .get_loss_value (Y_pred , Y_test )
232241 return metric (Y_pred , Y_test )
242+
243+ def check_data_array_types (self , * arrays ):
244+ result = []
245+ for array in arrays :
246+ if array is None or scipy .sparse .issparse (array ):
247+ result .append (array )
248+ continue
249+
250+ result .append (np .asanyarray (array ))
251+ if not result [- 1 ].shape :
252+ raise RuntimeError ("Given data-array is of unexpected type %s. Please pass numpy arrays instead." % type (array ))
253+ return result
0 commit comments