 
 
 def drop_block_2d(
-        x, drop_prob: float = 0.1, training: bool = False, block_size: int = 7,
-        gamma_scale: float = 1.0, drop_with_noise: bool = False):
+        x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
+        with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
     """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
 
     DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
     runs with success, but needs further validation and possibly optimization for lower runtime impact.
-
     """
-    if drop_prob == 0. or not training:
-        return x
-    _, _, height, width = x.shape
-    total_size = width * height
-    clipped_block_size = min(block_size, min(width, height))
+    B, C, H, W = x.shape
+    total_size = W * H
+    clipped_block_size = min(block_size, min(W, H))
     # seed_drop_rate, the gamma parameter
-    seed_drop_rate = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
-        (width - block_size + 1) *
-        (height - block_size + 1))
+    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
+        (W - block_size + 1) * (H - block_size + 1))
 
     # Forces the block to be inside the feature map.
-    w_i, h_i = torch.meshgrid(torch.arange(width).to(x.device), torch.arange(height).to(x.device))
-    valid_block = ((w_i >= clipped_block_size // 2) & (w_i < width - (clipped_block_size - 1) // 2)) & \
-        ((h_i >= clipped_block_size // 2) & (h_i < height - (clipped_block_size - 1) // 2))
-    valid_block = torch.reshape(valid_block, (1, 1, height, width)).float()
-
-    uniform_noise = torch.rand_like(x, dtype=torch.float32)
-    block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).float()
+    w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
+    valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
+        ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
+    valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
+
+    if batchwise:
+        # one mask for whole batch, quite a bit faster
+        uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
+    else:
+        uniform_noise = torch.rand_like(x)
+    block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
     block_mask = -F.max_pool2d(
         -block_mask,
-        kernel_size=clipped_block_size,  # block_size, ???
+        kernel_size=clipped_block_size,  # block_size,
         stride=1,
         padding=clipped_block_size // 2)
 
-    if drop_with_noise:
-        normal_noise = torch.randn_like(x)
-        x = x * block_mask + normal_noise * (1 - block_mask)
+    if with_noise:
+        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
+        if inplace:
+            x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
+        else:
+            x = x * block_mask + normal_noise * (1 - block_mask)
+    else:
+        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
+        if inplace:
+            x.mul_(block_mask * normalize_scale)
+        else:
+            x = x * block_mask * normalize_scale
+    return x
+
+
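Editorial note, not part of the commit: a quick standalone check of the gamma (seed drop rate) formula above. With the defaults drop_prob=0.1 and block_size=7 on a hypothetical 28x28 feature map, only about 0.3% of positions seed a block, because the pooling step expands each seed into a 7x7 square.

    drop_prob, block_size, W, H = 0.1, 7, 28, 28
    gamma = drop_prob * (W * H) / block_size ** 2 / ((W - block_size + 1) * (H - block_size + 1))
    print(round(gamma, 5))  # 0.00331, i.e. roughly one seed per 300 positions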
+def drop_block_fast_2d(
+        x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
+        gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
+    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+
+    DropBlock with an experimental gaussian noise option. Simplified from above without concern for valid
+    block mask at edges.
+    """
+    B, C, H, W = x.shape
+    total_size = W * H
+    clipped_block_size = min(block_size, min(W, H))
+    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
+        (W - block_size + 1) * (H - block_size + 1))
+
+    if batchwise:
+        # one mask for whole batch, quite a bit faster
+        block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma
+    else:
+        # mask per batch element
+        block_mask = torch.rand_like(x) < gamma
+    block_mask = F.max_pool2d(
+        block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)
+
+    if with_noise:
+        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
+        if inplace:
+            x.mul_(1. - block_mask).add_(normal_noise * block_mask)
+        else:
+            x = x * (1. - block_mask) + normal_noise * block_mask
     else:
-        normalize_scale = block_mask.numel() / (torch.sum(block_mask) + 1e-7)
-        x = x * block_mask * normalize_scale
+        block_mask = 1 - block_mask
+        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype)
+        if inplace:
+            x.mul_(block_mask * normalize_scale)
+        else:
+            x = x * block_mask * normalize_scale
     return x
 
 
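Editorial note, not part of the commit: both variants rely on the same trick, sampling a sparse Bernoulli "seed" mask with probability gamma and dilating every seed into a block_size x block_size square with a stride-1 max pool (the slower path applies it as a min pool on the keep mask via the double negation). A minimal sketch of that dilation with made-up sizes:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    seeds = (torch.rand(1, 1, 8, 8) < 0.05).float()                   # sparse seed positions
    blocks = F.max_pool2d(seeds, kernel_size=3, stride=1, padding=1)  # dilate to 3x3 squares
    print(int(seeds.sum()), int(blocks.sum()))                        # each seed covers at most 3*3 = 9 cells

The normalize_scale factor in the non-noise branch then rescales surviving activations by numel/sum of the keep mask, the same expected-magnitude correction used by inverted dropout.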
@@ -70,15 +115,28 @@ def __init__(self,
                  drop_prob=0.1,
                  block_size=7,
                  gamma_scale=1.0,
-                 with_noise=False):
+                 with_noise=False,
+                 inplace=False,
+                 batchwise=False,
+                 fast=True):
         super(DropBlock2d, self).__init__()
         self.drop_prob = drop_prob
         self.gamma_scale = gamma_scale
         self.block_size = block_size
         self.with_noise = with_noise
+        self.inplace = inplace
+        self.batchwise = batchwise
+        self.fast = fast  # FIXME finish comparisons of fast vs not
 
     def forward(self, x):
-        return drop_block_2d(x, self.drop_prob, self.training, self.block_size, self.gamma_scale, self.with_noise)
+        if not self.training or not self.drop_prob:
+            return x
+        if self.fast:
+            return drop_block_fast_2d(
+                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
+        else:
+            return drop_block_2d(
+                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
 
 
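Editorial note, not part of the commit: a hypothetical usage of the updated module, assuming the class and imports above. With the training / drop_prob early exit moved into forward, the layer is an identity at eval time like other dropout variants.

    import torch

    drop = DropBlock2d(drop_prob=0.1, block_size=7, fast=True)
    x = torch.randn(2, 64, 28, 28)

    drop.train()
    y = drop(x)   # blocks dropped, surviving activations renormalized

    drop.eval()
    assert torch.equal(drop(x), x)  # no-op outside training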
 def drop_path(x, drop_prob: float = 0., training: bool = False):