@@ -433,34 +433,16 @@ def __init__(self, backend: Backend,
433433 self .backend : Backend = backend
434434 self .queue = queue
435435
436- self .datamanager : BaseDataset = self .backend .load_datamanager ()
437-
438- assert self .datamanager .task_type is not None , \
439- "Expected dataset {} to have task_type got None" .format (self .datamanager .__class__ .__name__ )
440- self .task_type = STRING_TO_TASK_TYPES [self .datamanager .task_type ]
441- self .output_type = STRING_TO_OUTPUT_TYPES [self .datamanager .output_type ]
442- self .issparse = self .datamanager .issparse
443-
444436 self .include = include
445437 self .exclude = exclude
446438 self .search_space_updates = search_space_updates
447439
448- self .X_train , self .y_train = self .datamanager .train_tensors
449-
450- if self .datamanager .val_tensors is not None :
451- self .X_valid , self .y_valid = self .datamanager .val_tensors
452- else :
453- self .X_valid , self .y_valid = None , None
454-
455- if self .datamanager .test_tensors is not None :
456- self .X_test , self .y_test = self .datamanager .test_tensors
457- else :
458- self .X_test , self .y_test = None , None
459-
460440 self .metric = metric
461441
462442 self .seed = seed
463443
444+ self ._init_datamanager_info ()
445+
464446 # Flag to save target for ensemble
465447 self .output_y_hat_optimization = output_y_hat_optimization
466448
@@ -497,12 +479,6 @@ def __init__(self, backend: Backend,
497479 else :
498480 raise ValueError ('task {} not available' .format (self .task_type ))
499481 self .predict_function = self ._predict_proba
500- self .dataset_properties = self .datamanager .get_dataset_properties (
501- get_dataset_requirements (info = self .datamanager .get_required_dataset_info (),
502- include = self .include ,
503- exclude = self .exclude ,
504- search_space_updates = self .search_space_updates
505- ))
506482
507483 self .additional_metrics : Optional [List [autoPyTorchMetric ]] = None
508484 metrics_dict : Optional [Dict [str , List [str ]]] = None
@@ -542,6 +518,53 @@ def __init__(self, backend: Backend,
542518 self .logger .debug ("Fit dictionary in Abstract evaluator: {}" .format (dict_repr (self .fit_dictionary )))
543519 self .logger .debug ("Search space updates :{}" .format (self .search_space_updates ))
544520
521+ def _init_datamanager_info (
522+ self ,
523+ ) -> None :
524+ """
525+ Initialises instance attributes that come from the datamanager.
526+ For example,
527+ X_train, y_train, etc.
528+ """
529+
530+ datamanager : BaseDataset = self .backend .load_datamanager ()
531+
532+ assert datamanager .task_type is not None , \
533+ "Expected dataset {} to have task_type got None" .format (datamanager .__class__ .__name__ )
534+ self .task_type = STRING_TO_TASK_TYPES [datamanager .task_type ]
535+ self .output_type = STRING_TO_OUTPUT_TYPES [datamanager .output_type ]
536+ self .issparse = datamanager .issparse
537+
538+ self .X_train , self .y_train = datamanager .train_tensors
539+
540+ if datamanager .val_tensors is not None :
541+ self .X_valid , self .y_valid = datamanager .val_tensors
542+ else :
543+ self .X_valid , self .y_valid = None , None
544+
545+ if datamanager .test_tensors is not None :
546+ self .X_test , self .y_test = datamanager .test_tensors
547+ else :
548+ self .X_test , self .y_test = None , None
549+
550+ self .resampling_strategy = datamanager .resampling_strategy
551+
552+ self .num_classes : Optional [int ] = getattr (datamanager , "num_classes" , None )
553+
554+ self .dataset_properties = datamanager .get_dataset_properties (
555+ get_dataset_requirements (info = datamanager .get_required_dataset_info (),
556+ include = self .include ,
557+ exclude = self .exclude ,
558+ search_space_updates = self .search_space_updates
559+ ))
560+ self .splits = datamanager .splits
561+ if self .splits is None :
562+ raise AttributeError (f"create_splits on { datamanager .__class__ .__name__ } must be called "
563+ f"before the instantiation of { self .__class__ .__name__ } " )
564+
565+ # delete datamanager from memory
566+ del datamanager
567+
545568 def _init_fit_dictionary (
546569 self ,
547570 logger_port : int ,
@@ -988,21 +1011,20 @@ def _ensure_prediction_array_sizes(self, prediction: np.ndarray,
9881011 (np.ndarray):
9891012 The formatted prediction
9901013 """
991- assert self .datamanager .num_classes is not None , "Called function on wrong task"
992- num_classes : int = self .datamanager .num_classes
1014+ assert self .num_classes is not None , "Called function on wrong task"
9931015
9941016 if self .output_type == MULTICLASS and \
995- prediction .shape [1 ] < num_classes :
1017+ prediction .shape [1 ] < self . num_classes :
9961018 if Y_train is None :
9971019 raise ValueError ('Y_train must not be None!' )
9981020 classes = list (np .unique (Y_train ))
9991021
10001022 mapping = dict ()
1001- for class_number in range (num_classes ):
1023+ for class_number in range (self . num_classes ):
10021024 if class_number in classes :
10031025 index = classes .index (class_number )
10041026 mapping [index ] = class_number
1005- new_predictions = np .zeros ((prediction .shape [0 ], num_classes ),
1027+ new_predictions = np .zeros ((prediction .shape [0 ], self . num_classes ),
10061028 dtype = np .float32 )
10071029
10081030 for index in mapping :
0 commit comments