@@ -490,37 +490,23 @@ def __init__(self, backend: Backend,
490490 ))
491491
492492 self .additional_metrics : Optional [List [autoPyTorchMetric ]] = None
493+ metrics_dict : Optional [Dict [str , List [str ]]] = None
493494 if all_supported_metrics :
494495 self .additional_metrics = get_metrics (dataset_properties = self .dataset_properties ,
495496 all_supported_metrics = all_supported_metrics )
497+ # Update fit dictionary with metrics passed to the evaluator
498+ metrics_dict = {'additional_metrics' : []}
499+ metrics_dict ['additional_metrics' ].append (self .metric .name )
500+ for metric in self .additional_metrics :
501+ metrics_dict ['additional_metrics' ].append (metric .name )
496502
497- self .fit_dictionary : Dict [str , Any ] = {'dataset_properties' : self .dataset_properties }
498503 self ._init_params = init_params
499- self .fit_dictionary .update ({
500- 'X_train' : self .X_train ,
501- 'y_train' : self .y_train ,
502- 'X_test' : self .X_test ,
503- 'y_test' : self .y_test ,
504- 'backend' : self .backend ,
505- 'logger_port' : logger_port ,
506- 'optimize_metric' : self .metric .name
507- })
504+
508505 assert self .pipeline_class is not None , "Could not infer pipeline class"
509506 pipeline_config = pipeline_config if pipeline_config is not None \
510507 else self .pipeline_class .get_default_pipeline_options ()
511508 self .budget_type = pipeline_config ['budget_type' ] if budget_type is None else budget_type
512509 self .budget = pipeline_config [self .budget_type ] if budget == 0 else budget
513- self .fit_dictionary = {** pipeline_config , ** self .fit_dictionary }
514-
515- # If the budget is epochs, we want to limit that in the fit dictionary
516- if self .budget_type == 'epochs' :
517- self .fit_dictionary ['epochs' ] = budget
518- self .fit_dictionary .pop ('runtime' , None )
519- elif self .budget_type == 'runtime' :
520- self .fit_dictionary ['runtime' ] = budget
521- self .fit_dictionary .pop ('epochs' , None )
522- else :
523- raise ValueError (f"Unsupported budget type { self .budget_type } provided" )
524510
525511 self .num_run = 0 if num_run is None else num_run
526512
@@ -533,13 +519,65 @@ def __init__(self, backend: Backend,
533519 port = logger_port ,
534520 )
535521
522+ self ._init_fit_dictionary (logger_port = logger_port , pipeline_config = pipeline_config , metrics_dict = metrics_dict )
536523 self .Y_optimization : Optional [np .ndarray ] = None
537524 self .Y_actual_train : Optional [np .ndarray ] = None
538525 self .pipelines : Optional [List [BaseEstimator ]] = None
539526 self .pipeline : Optional [BaseEstimator ] = None
540527 self .logger .debug ("Fit dictionary in Abstract evaluator: {}" .format (dict_repr (self .fit_dictionary )))
541528 self .logger .debug ("Search space updates :{}" .format (self .search_space_updates ))
542529
530+ def _init_fit_dictionary (
531+ self ,
532+ logger_port : int ,
533+ pipeline_config : Dict [str , Any ],
534+ metrics_dict : Optional [Dict [str , List [str ]]] = None ,
535+ ) -> None :
536+ """
537+ Initialises the fit dictionary
538+
539+ Args:
540+ logger_port (int):
541+ Logging is performed using a socket-server scheme to be robust against many
542+ parallel entities that want to write to the same file. This integer states the
543+ socket port for the communication channel.
544+ pipeline_config (Dict[str, Any]):
545+ Defines the content of the pipeline being evaluated. For example, it
546+ contains pipeline specific settings like logging name, or whether or not
547+ to use tensorboard.
548+ metrics_dict (Optional[Dict[str, List[str]]]):
549+ Contains a list of metric names to be evaluated in Trainer with key `additional_metrics`. Defaults to None.
550+
551+ Returns:
552+ None
553+ """
554+
555+ self .fit_dictionary : Dict [str , Any ] = {'dataset_properties' : self .dataset_properties }
556+
557+ if metrics_dict is not None :
558+ self .fit_dictionary .update (metrics_dict )
559+
560+ self .fit_dictionary .update ({
561+ 'X_train' : self .X_train ,
562+ 'y_train' : self .y_train ,
563+ 'X_test' : self .X_test ,
564+ 'y_test' : self .y_test ,
565+ 'backend' : self .backend ,
566+ 'logger_port' : logger_port ,
567+ 'optimize_metric' : self .metric .name
568+ })
569+
570+ self .fit_dictionary .update (pipeline_config )
571+ # If the budget is epochs, we want to limit that in the fit dictionary
572+ if self .budget_type == 'epochs' :
573+ self .fit_dictionary ['epochs' ] = self .budget
574+ self .fit_dictionary .pop ('runtime' , None )
575+ elif self .budget_type == 'runtime' :
576+ self .fit_dictionary ['runtime' ] = self .budget
577+ self .fit_dictionary .pop ('epochs' , None )
578+ else :
579+ raise ValueError (f"budget type must be `epochs` or `runtime`, but got { self .budget_type } " )
580+
543581 def _get_pipeline (self ) -> BaseEstimator :
544582 """
545583 Implements a pipeline object based on the self.configuration attribute.
0 commit comments