@@ -70,28 +70,31 @@ def get_current_autonet_config(self):
7070 return self .pipeline .get_pipeline_config (** self .base_config )
7171
7272 def get_hyperparameter_search_space (self , X_train = None , Y_train = None , X_valid = None , Y_valid = None , ** autonet_config ):
73- """Return hyperparameter search space of Auto-PyTorch. Does depend on the dataset!
73+ """Return hyperparameter search space of Auto-PyTorch. Does depend on the dataset and the configuration. !
7474
7575 Keyword Arguments:
7676 X_train {array} -- Training data.
7777 Y_train {array} -- Targets of training data.
7878 X_valid {array} -- Validation data. Will be ignored if cv_splits > 1. (default: {None})
7979 Y_valid {array} -- Validation data. Will be ignored if cv_splits > 1. (default: {None})
80+ autonet_config{dict} -- if not given and fit already called, config of last fit will be used
8081
8182 Returns:
8283 ConfigurationSpace -- The configuration space that should be optimized.
8384 """
8485
8586 dataset_info = None
87+ pipeline_config = dict (self .base_config , ** autonet_config ) if autonet_config else \
88+ self .get_current_autonet_config ()
8689 if X_train is not None and Y_train is not None :
8790 dataset_info_node = self .pipeline [CreateDatasetInfo .get_name ()]
88- dataset_info = dataset_info_node .fit (pipeline_config = dict ( self . base_config , ** autonet_config ) ,
91+ dataset_info = dataset_info_node .fit (pipeline_config = pipeline_config ,
8992 X_train = X_train ,
9093 Y_train = Y_train ,
9194 X_valid = X_valid ,
9295 Y_valid = Y_valid )["dataset_info" ]
9396
94- return self .pipeline .get_hyperparameter_search_space (dataset_info = dataset_info , ** self . get_current_autonet_config () )
97+ return self .pipeline .get_hyperparameter_search_space (dataset_info = dataset_info , ** pipeline_config )
9598
9699 @classmethod
97100 def get_default_pipeline (cls ):
0 commit comments