@@ -67,7 +67,7 @@ def __init__(
6767 if isinstance (target , str ):
6868 self .y = self .y .reshape (- 1 , 1 ) # .astype(np.int64)
6969 else :
70- self .y = np .zeros ((self .n , 1 )) # .astype(np.int64)
70+ self .y = np .zeros ((self .n , 1 ), dtype = np . float32 ) # .astype(np.int64)
7171
7272 if task == "classification" :
7373 self .y = self .y .astype (np .int64 )
@@ -502,7 +502,7 @@ def _cache_dataset(self):
502502
503503 def split_train_val (self , train ):
504504 logger .debug (
505- "No validation data provided." f" Using { self .config .validation_split * 100 } % of train data as validation"
505+ f "No validation data provided. Using { self .config .validation_split * 100 } % of train data as validation"
506506 )
507507 val_idx = train .sample (
508508 int (self .config .validation_split * len (train )),
@@ -753,9 +753,7 @@ def _load_dataset_from_cache(self, tag: str = "train"):
753753 try :
754754 dataset = getattr (self , f"_{ tag } _dataset" )
755755 except AttributeError :
756- raise AttributeError (
757- f"{ tag } _dataset not found in memory. Please provide the data for" f" { tag } dataloader"
758- )
756+ raise AttributeError (f"{ tag } _dataset not found in memory. Please provide the data for { tag } dataloader" )
759757 elif self .cache_mode is self .CACHE_MODES .DISK :
760758 try :
761759 # get the torch version
@@ -768,10 +766,10 @@ def _load_dataset_from_cache(self, tag: str = "train"):
768766 dataset = torch .load (self .cache_dir / f"{ tag } _dataset" , weights_only = False )
769767 except FileNotFoundError :
770768 raise FileNotFoundError (
771- f"{ tag } _dataset not found in { self .cache_dir } . Please provide the" f" data for { tag } dataloader"
769+ f"{ tag } _dataset not found in { self .cache_dir } . Please provide the data for { tag } dataloader"
772770 )
773771 elif self .cache_mode is self .CACHE_MODES .INFERENCE :
774- raise RuntimeError ("Cannot load dataset in inference mode. Use" " `prepare_inference_dataloader` instead" )
772+ raise RuntimeError ("Cannot load dataset in inference mode. Use `prepare_inference_dataloader` instead" )
775773 else :
776774 raise ValueError (f"{ self .cache_mode } is not a valid cache mode" )
777775 return dataset
0 commit comments