@@ -45,12 +45,12 @@ def build_backbone(self, input_shape: Tuple[int, ...]) -> torch.nn.Sequential:
         # n_units for the architecture, since, it is mostly implemented for the
         # output layer, which is part of the head and not of the backbone.
         dropout_shape = get_shaped_neuron_counts(
-            shape=self.config['resnet_shape'],
-            in_feat=0,
-            out_feat=0,
-            max_neurons=self.config["max_dropout"],
-            layer_count=self.config['num_groups'] + 1,
-        )[:-1]
+            self.config['dropout_shape'], 0, 0, 1000, self.config['num_groups']
+        )
+
+        dropout_shape = [
+            dropout / 1000 * self.config["max_dropout"] for dropout in dropout_shape
+        ]

         self.config.update(
             {"dropout_%d" % (i + 1): dropout for i, dropout in enumerate(dropout_shape)}
@@ -136,6 +136,13 @@ def get_hyperparameter_search_space(  # type: ignore[override]
         max_dropout: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="max_dropout",
                                                                            value_range=(0, 0.8),
                                                                            default_value=0.5),
+        dropout_shape: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="dropout_shape",
+                                                                             value_range=('funnel', 'long_funnel',
+                                                                                          'diamond', 'hexagon',
+                                                                                          'brick', 'triangle',
+                                                                                          'stairs'),
+                                                                             default_value='funnel',
+                                                                             ),
         max_shake_drop_probability: HyperparameterSearchSpace = HyperparameterSearchSpace(
             hyperparameter="max_shake_drop_probability",
             value_range=(0, 1),
@@ -165,8 +172,10 @@ def get_hyperparameter_search_space(  # type: ignore[override]

         if dropout_flag:
             max_dropout = get_hyperparameter(max_dropout, UniformFloatHyperparameter)
-            cs.add_hyperparameter(max_dropout)
+            dropout_shape = get_hyperparameter(dropout_shape, CategoricalHyperparameter)
+            cs.add_hyperparameters([dropout_shape, max_dropout])
             cs.add_condition(CS.EqualsCondition(max_dropout, use_dropout, True))
+            cs.add_condition(CS.EqualsCondition(dropout_shape, use_dropout, True))

         skip_connection_flag = False
         if any(use_skip_connection.value_range):
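As a self-contained sketch of the conditional pattern added in this hunk, assuming only the plain ConfigSpace library (AutoPyTorch's `get_hyperparameter` wrapper is bypassed; the names and ranges are taken from this file): both `max_dropout` and the new `dropout_shape` are only active when `use_dropout` is sampled as `True`.

```python
import ConfigSpace as CS
from ConfigSpace.hyperparameters import (
    CategoricalHyperparameter,
    UniformFloatHyperparameter,
)

cs = CS.ConfigurationSpace()

use_dropout = CategoricalHyperparameter("use_dropout", choices=[True, False])
max_dropout = UniformFloatHyperparameter("max_dropout", lower=0.0, upper=0.8,
                                         default_value=0.5)
dropout_shape = CategoricalHyperparameter(
    "dropout_shape",
    choices=['funnel', 'long_funnel', 'diamond', 'hexagon',
             'brick', 'triangle', 'stairs'],
    default_value='funnel',
)

cs.add_hyperparameters([use_dropout, dropout_shape, max_dropout])

# Mirror the two EqualsCondition calls above: the dropout hyperparameters
# are only sampled when use_dropout == True.
cs.add_condition(CS.EqualsCondition(max_dropout, use_dropout, True))
cs.add_condition(CS.EqualsCondition(dropout_shape, use_dropout, True))

print(cs.sample_configuration())
```

Adding `dropout_shape` under the same `use_dropout` condition keeps the configuration space consistent: a shape is never sampled for configurations that disable dropout.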