@@ -26,14 +26,15 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
2626 np.ndarray: that processes data
2727 typing.Dict[str, np.ndarray]: arguments to the criterion function
2828 """
29- beta = 1.0
30- lam = self .random_state .beta (beta , beta )
31- batch_size , channel , W , H = X .size ()
32- index = torch .randperm (batch_size ).cuda () if X .is_cuda else torch .randperm (batch_size )
29+ alpha , beta = 1.0 , 1.0
30+ lam = self .random_state .beta (alpha , beta )
31+ batch_size , _ , W , H = X .shape
32+ device = torch .device ('cuda' if X .is_cuda else 'cpu' )
33+ batch_indices = torch .randperm (batch_size ).to (device )
3334
3435 r = self .random_state .rand (1 )
3536 if beta <= 0 or r > self .alpha :
36- return X , {'y_a' : y , 'y_b' : y [index ], 'lam' : 1 }
37+ return X , {'y_a' : y , 'y_b' : y [batch_indices ], 'lam' : 1 }
3738
3839 # Draw parameters of a random bounding box
3940 # Where to cut basically
@@ -47,12 +48,13 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
4748 bbx2 = np .clip (cx + cut_w // 2 , 0 , W )
4849 bby2 = np .clip (cy + cut_h // 2 , 0 , H )
4950
50- X [:, :, bbx1 :bbx2 , bby1 :bby2 ] = X [index , :, bbx1 :bbx2 , bby1 :bby2 ]
51+ X [:, :, bbx1 :bbx2 , bby1 :bby2 ] = X [batch_indices , :, bbx1 :bbx2 , bby1 :bby2 ]
5152
5253 # Adjust lam
53- lam = 1 - ((bbx2 - bbx1 ) * (bby2 - bby1 ) / (X .size ()[- 1 ] * X .size ()[- 2 ]))
54+ pixel_size = W * H
55+ lam = 1 - ((bbx2 - bbx1 ) * (bby2 - bby1 ) / pixel_size )
5456
55- y_a , y_b = y , y [index ]
57+ y_a , y_b = y , y [batch_indices ]
5658
5759 return X , {'y_a' : y_a , 'y_b' : y_b , 'lam' : lam }
5860
0 commit comments