@@ -26,25 +26,31 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
2626 np.ndarray: that processes data
2727 typing.Dict[str, np.ndarray]: arguments to the criterion function
2828 """
29- beta = 1.0
30- lam = self .random_state .beta (beta , beta )
31- batch_size = X .size ()[0 ]
32- index = torch .randperm (batch_size ).cuda () if X .is_cuda else torch .randperm (batch_size )
29+ alpha , beta = 1.0 , 1.0
30+ lam = self .random_state .beta (alpha , beta )
31+ batch_size = X .shape [0 ]
32+ device = torch .device ('cuda' if X .is_cuda else 'cpu' )
33+ batch_indices = torch .randperm (batch_size ).to (device )
3334
3435 r = self .random_state .rand (1 )
3536 if beta <= 0 or r > self .alpha :
36- return X , {'y_a' : y , 'y_b' : y [index ], 'lam' : 1 }
37+ return X , {'y_a' : y , 'y_b' : y [batch_indices ], 'lam' : 1 }
3738
38- size = X .shape [1 ]
39- indices = torch .tensor (self .random_state .choice (range (1 , size ), max (1 , np .int32 (size * lam )),
40- replace = False ))
39+ row_size = X .shape [1 ]
40+ row_indices = torch .tensor (
41+ self .random_state .choice (
42+ range (1 , row_size ),
43+ max (1 , int (row_size * lam )),
44+ replace = False
45+ )
46+ )
4147
42- X [:, indices ] = X [index , :][:, indices ]
48+ X [:, row_indices ] = X [batch_indices , :][:, row_indices ]
4349
4450 # Adjust lam
45- lam = 1 - (( len (indices )) / ( X . size () [1 ]))
51+ lam = 1 - len (row_indices ) / X . shape [1 ]
4652
47- y_a , y_b = y , y [index ]
53+ y_a , y_b = y , y [batch_indices ]
4854
4955 return X , {'y_a' : y_a , 'y_b' : y_b , 'lam' : lam }
5056
0 commit comments