@@ -64,7 +64,8 @@ def _add_group(self, in_features: int, out_features: int,
             out_features (int): output dimensionality for the current block
             blocks_per_group (int): Number of ResNet blocks per group
             last_block_index (int): block index for shake regularization
-            dropout (bool): whether or not use dropout
+            dropout (None, float): dropout value for the group. If None,
+                no dropout is applied.
         """
         blocks = list()
         for i in range(blocks_per_group):
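The signature change above replaces the old boolean dropout flag with an optional per-group rate. A minimal sketch of the new contract (a hypothetical `build_group` helper, not the AutoPyTorch implementation):

```python
from typing import Optional

import torch.nn as nn


def build_group(in_features: int, out_features: int,
                blocks_per_group: int,
                dropout: Optional[float]) -> nn.Sequential:
    """dropout is either None (disabled) or a float rate for the whole group."""
    blocks = []
    for _ in range(blocks_per_group):
        layers = [nn.Linear(in_features, out_features), nn.ReLU()]
        # gate on `is not None`, mirroring the `self.dropout is not None`
        # check introduced further down in this commit
        if dropout is not None:
            layers.append(nn.Dropout(dropout))
        blocks.append(nn.Sequential(*layers))
        in_features = out_features
    return nn.Sequential(*blocks)
```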
@@ -180,9 +181,7 @@ def get_hyperparameter_search_space(
 
         if skip_connection_flag:
 
-            shake_drop_prob_flag = False
-            if 'shake-drop' in multi_branch_choice.value_range:
-                shake_drop_prob_flag = True
+            shake_drop_prob_flag = 'shake-drop' in multi_branch_choice.value_range
 
             mb_choice = get_hyperparameter(multi_branch_choice, CategoricalHyperparameter)
             cs.add_hyperparameter(mb_choice)
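The three-line flag assignment collapses into a single membership test. In context, the flag gates whether a shake-drop probability hyperparameter is registered at all; here is a self-contained sketch with plain ConfigSpace objects (the surrounding setup and the `max_shake_drop_probability` name are illustrative assumptions, not the exact search-space code):

```python
import ConfigSpace as CS
from ConfigSpace.hyperparameters import (
    CategoricalHyperparameter,
    UniformFloatHyperparameter,
)

cs = CS.ConfigurationSpace()
value_range = ('none', 'shake-shake', 'shake-drop')

# the membership test already yields a bool; no need for the
# False / if / True assignment it replaces
shake_drop_prob_flag = 'shake-drop' in value_range

mb_choice = CategoricalHyperparameter('multi_branch_choice', list(value_range))
cs.add_hyperparameter(mb_choice)

if shake_drop_prob_flag:
    # only meaningful when 'shake-drop' is a reachable choice
    cs.add_hyperparameter(
        UniformFloatHyperparameter('max_shake_drop_probability', 0.0, 1.0))
```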
@@ -290,13 +289,21 @@ def _build_block(self, in_features: int, out_features: int) -> nn.Module:
             if self.config['use_batch_norm']:
                 layers.append(nn.BatchNorm1d(in_features))
             layers.append(self.activation())
+        elif not self.config['use_skip_connection']:
+            # if start_norm is not None and skip connection is False,
+            # start_norm would never be applied to the first layer of the
+            # block, so we have to account for this case here.
+            if self.config['use_batch_norm']:
+                layers.append(nn.BatchNorm1d(in_features))
+            layers.append(self.activation())
+
         layers.append(nn.Linear(in_features, out_features))
 
         if self.config['use_batch_norm']:
             layers.append(nn.BatchNorm1d(out_features))
         layers.append(self.activation())
 
-        if self.config["use_dropout"]:
+        if self.dropout is not None:
             layers.append(nn.Dropout(self.dropout))
         layers.append(nn.Linear(out_features, out_features))
 
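Putting the hunk together, the block builder now has three normalization paths: first block of a group (`start_norm is None`), a later block without a skip connection (the new `elif`), and a later block with one (normalization handled by `start_norm` on the shortcut path). A standalone sketch of the resulting layout, with a plain config dict and fixed `ReLU` standing in for the real class:

```python
from typing import Optional

import torch.nn as nn


def build_block(in_features: int, out_features: int, config: dict,
                start_norm: Optional[nn.Module] = None,
                dropout: Optional[float] = None) -> nn.Sequential:
    layers = []
    if start_norm is None:
        # first block of a group: normalize/activate the raw input here
        if config['use_batch_norm']:
            layers.append(nn.BatchNorm1d(in_features))
        layers.append(nn.ReLU())
    elif not config['use_skip_connection']:
        # start_norm only runs on the shortcut path; with no skip
        # connection it would be skipped, so apply norm/activation inline
        if config['use_batch_norm']:
            layers.append(nn.BatchNorm1d(in_features))
        layers.append(nn.ReLU())

    layers.append(nn.Linear(in_features, out_features))
    if config['use_batch_norm']:
        layers.append(nn.BatchNorm1d(out_features))
    layers.append(nn.ReLU())
    if dropout is not None:  # replaces the old config["use_dropout"] flag
        layers.append(nn.Dropout(dropout))
    layers.append(nn.Linear(out_features, out_features))
    return nn.Sequential(*layers)
```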
@@ -321,6 +328,7 @@ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
         if self.config["use_skip_connection"]:
             residual = self.shortcut(x)
 
+        # TODO: make the below code better
         if self.config["use_skip_connection"]:
             if self.config["multi_branch_choice"] == 'shake-shake':
                 x1 = self.layers(x)
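The hunk cuts off at the first shake-shake branch. For reference, shake-shake evaluates two copies of the block and mixes them with a random convex combination; the sketch below shows only the forward-pass mixing (the full technique also draws a separate coefficient for the backward pass via a custom autograd function, omitted here):

```python
import torch
import torch.nn as nn


def shake_shake_forward(branch_a: nn.Module, branch_b: nn.Module,
                        shortcut: nn.Module, x: torch.Tensor) -> torch.Tensor:
    x1 = branch_a(x)
    x2 = branch_b(x)
    # fresh random mixing coefficient on every forward pass
    alpha = torch.rand(1, device=x.device)
    out = alpha * x1 + (1.0 - alpha) * x2
    return out + shortcut(x)
```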