11from ._base import EncoderMixin
2- from timm .models .regnet import RegNet
2+ from timm .models .regnet import RegNet , RegNetCfg
33import torch .nn as nn
44
55
66class RegNetEncoder (RegNet , EncoderMixin ):
77 def __init__ (self , out_channels , depth = 5 , ** kwargs ):
8+ kwargs ["cfg" ] = RegNetCfg (** kwargs ["cfg" ])
89 super ().__init__ (** kwargs )
910 self ._depth = depth
1011 self ._out_channels = out_channels
@@ -141,95 +142,95 @@ def _mcfg(**kwargs):
141142 "pretrained_settings" : pretrained_settings ["timm-regnetx_002" ],
142143 "params" : {
143144 "out_channels" : (3 , 32 , 24 , 56 , 152 , 368 ),
144- "cfg" : _mcfg (w0 = 24 , wa = 36.44 , wm = 2.49 , group_w = 8 , depth = 13 ),
145+ "cfg" : _mcfg (w0 = 24 , wa = 36.44 , wm = 2.49 , group_size = 8 , depth = 13 ),
145146 },
146147 },
147148 "timm-regnetx_004" : {
148149 "encoder" : RegNetEncoder ,
149150 "pretrained_settings" : pretrained_settings ["timm-regnetx_004" ],
150151 "params" : {
151152 "out_channels" : (3 , 32 , 32 , 64 , 160 , 384 ),
152- "cfg" : _mcfg (w0 = 24 , wa = 24.48 , wm = 2.54 , group_w = 16 , depth = 22 ),
153+ "cfg" : _mcfg (w0 = 24 , wa = 24.48 , wm = 2.54 , group_size = 16 , depth = 22 ),
153154 },
154155 },
155156 "timm-regnetx_006" : {
156157 "encoder" : RegNetEncoder ,
157158 "pretrained_settings" : pretrained_settings ["timm-regnetx_006" ],
158159 "params" : {
159160 "out_channels" : (3 , 32 , 48 , 96 , 240 , 528 ),
160- "cfg" : _mcfg (w0 = 48 , wa = 36.97 , wm = 2.24 , group_w = 24 , depth = 16 ),
161+ "cfg" : _mcfg (w0 = 48 , wa = 36.97 , wm = 2.24 , group_size = 24 , depth = 16 ),
161162 },
162163 },
163164 "timm-regnetx_008" : {
164165 "encoder" : RegNetEncoder ,
165166 "pretrained_settings" : pretrained_settings ["timm-regnetx_008" ],
166167 "params" : {
167168 "out_channels" : (3 , 32 , 64 , 128 , 288 , 672 ),
168- "cfg" : _mcfg (w0 = 56 , wa = 35.73 , wm = 2.28 , group_w = 16 , depth = 16 ),
169+ "cfg" : _mcfg (w0 = 56 , wa = 35.73 , wm = 2.28 , group_size = 16 , depth = 16 ),
169170 },
170171 },
171172 "timm-regnetx_016" : {
172173 "encoder" : RegNetEncoder ,
173174 "pretrained_settings" : pretrained_settings ["timm-regnetx_016" ],
174175 "params" : {
175176 "out_channels" : (3 , 32 , 72 , 168 , 408 , 912 ),
176- "cfg" : _mcfg (w0 = 80 , wa = 34.01 , wm = 2.25 , group_w = 24 , depth = 18 ),
177+ "cfg" : _mcfg (w0 = 80 , wa = 34.01 , wm = 2.25 , group_size = 24 , depth = 18 ),
177178 },
178179 },
179180 "timm-regnetx_032" : {
180181 "encoder" : RegNetEncoder ,
181182 "pretrained_settings" : pretrained_settings ["timm-regnetx_032" ],
182183 "params" : {
183184 "out_channels" : (3 , 32 , 96 , 192 , 432 , 1008 ),
184- "cfg" : _mcfg (w0 = 88 , wa = 26.31 , wm = 2.25 , group_w = 48 , depth = 25 ),
185+ "cfg" : _mcfg (w0 = 88 , wa = 26.31 , wm = 2.25 , group_size = 48 , depth = 25 ),
185186 },
186187 },
187188 "timm-regnetx_040" : {
188189 "encoder" : RegNetEncoder ,
189190 "pretrained_settings" : pretrained_settings ["timm-regnetx_040" ],
190191 "params" : {
191192 "out_channels" : (3 , 32 , 80 , 240 , 560 , 1360 ),
192- "cfg" : _mcfg (w0 = 96 , wa = 38.65 , wm = 2.43 , group_w = 40 , depth = 23 ),
193+ "cfg" : _mcfg (w0 = 96 , wa = 38.65 , wm = 2.43 , group_size = 40 , depth = 23 ),
193194 },
194195 },
195196 "timm-regnetx_064" : {
196197 "encoder" : RegNetEncoder ,
197198 "pretrained_settings" : pretrained_settings ["timm-regnetx_064" ],
198199 "params" : {
199200 "out_channels" : (3 , 32 , 168 , 392 , 784 , 1624 ),
200- "cfg" : _mcfg (w0 = 184 , wa = 60.83 , wm = 2.07 , group_w = 56 , depth = 17 ),
201+ "cfg" : _mcfg (w0 = 184 , wa = 60.83 , wm = 2.07 , group_size = 56 , depth = 17 ),
201202 },
202203 },
203204 "timm-regnetx_080" : {
204205 "encoder" : RegNetEncoder ,
205206 "pretrained_settings" : pretrained_settings ["timm-regnetx_080" ],
206207 "params" : {
207208 "out_channels" : (3 , 32 , 80 , 240 , 720 , 1920 ),
208- "cfg" : _mcfg (w0 = 80 , wa = 49.56 , wm = 2.88 , group_w = 120 , depth = 23 ),
209+ "cfg" : _mcfg (w0 = 80 , wa = 49.56 , wm = 2.88 , group_size = 120 , depth = 23 ),
209210 },
210211 },
211212 "timm-regnetx_120" : {
212213 "encoder" : RegNetEncoder ,
213214 "pretrained_settings" : pretrained_settings ["timm-regnetx_120" ],
214215 "params" : {
215216 "out_channels" : (3 , 32 , 224 , 448 , 896 , 2240 ),
216- "cfg" : _mcfg (w0 = 168 , wa = 73.36 , wm = 2.37 , group_w = 112 , depth = 19 ),
217+ "cfg" : _mcfg (w0 = 168 , wa = 73.36 , wm = 2.37 , group_size = 112 , depth = 19 ),
217218 },
218219 },
219220 "timm-regnetx_160" : {
220221 "encoder" : RegNetEncoder ,
221222 "pretrained_settings" : pretrained_settings ["timm-regnetx_160" ],
222223 "params" : {
223224 "out_channels" : (3 , 32 , 256 , 512 , 896 , 2048 ),
224- "cfg" : _mcfg (w0 = 216 , wa = 55.59 , wm = 2.1 , group_w = 128 , depth = 22 ),
225+ "cfg" : _mcfg (w0 = 216 , wa = 55.59 , wm = 2.1 , group_size = 128 , depth = 22 ),
225226 },
226227 },
227228 "timm-regnetx_320" : {
228229 "encoder" : RegNetEncoder ,
229230 "pretrained_settings" : pretrained_settings ["timm-regnetx_320" ],
230231 "params" : {
231232 "out_channels" : (3 , 32 , 336 , 672 , 1344 , 2520 ),
232- "cfg" : _mcfg (w0 = 320 , wa = 69.86 , wm = 2.0 , group_w = 168 , depth = 23 ),
233+ "cfg" : _mcfg (w0 = 320 , wa = 69.86 , wm = 2.0 , group_size = 168 , depth = 23 ),
233234 },
234235 },
235236 # regnety
@@ -238,95 +239,95 @@ def _mcfg(**kwargs):
238239 "pretrained_settings" : pretrained_settings ["timm-regnety_002" ],
239240 "params" : {
240241 "out_channels" : (3 , 32 , 24 , 56 , 152 , 368 ),
241- "cfg" : _mcfg (w0 = 24 , wa = 36.44 , wm = 2.49 , group_w = 8 , depth = 13 , se_ratio = 0.25 ),
242+ "cfg" : _mcfg (w0 = 24 , wa = 36.44 , wm = 2.49 , group_size = 8 , depth = 13 , se_ratio = 0.25 ),
242243 },
243244 },
244245 "timm-regnety_004" : {
245246 "encoder" : RegNetEncoder ,
246247 "pretrained_settings" : pretrained_settings ["timm-regnety_004" ],
247248 "params" : {
248249 "out_channels" : (3 , 32 , 48 , 104 , 208 , 440 ),
249- "cfg" : _mcfg (w0 = 48 , wa = 27.89 , wm = 2.09 , group_w = 8 , depth = 16 , se_ratio = 0.25 ),
250+ "cfg" : _mcfg (w0 = 48 , wa = 27.89 , wm = 2.09 , group_size = 8 , depth = 16 , se_ratio = 0.25 ),
250251 },
251252 },
252253 "timm-regnety_006" : {
253254 "encoder" : RegNetEncoder ,
254255 "pretrained_settings" : pretrained_settings ["timm-regnety_006" ],
255256 "params" : {
256257 "out_channels" : (3 , 32 , 48 , 112 , 256 , 608 ),
257- "cfg" : _mcfg (w0 = 48 , wa = 32.54 , wm = 2.32 , group_w = 16 , depth = 15 , se_ratio = 0.25 ),
258+ "cfg" : _mcfg (w0 = 48 , wa = 32.54 , wm = 2.32 , group_size = 16 , depth = 15 , se_ratio = 0.25 ),
258259 },
259260 },
260261 "timm-regnety_008" : {
261262 "encoder" : RegNetEncoder ,
262263 "pretrained_settings" : pretrained_settings ["timm-regnety_008" ],
263264 "params" : {
264265 "out_channels" : (3 , 32 , 64 , 128 , 320 , 768 ),
265- "cfg" : _mcfg (w0 = 56 , wa = 38.84 , wm = 2.4 , group_w = 16 , depth = 14 , se_ratio = 0.25 ),
266+ "cfg" : _mcfg (w0 = 56 , wa = 38.84 , wm = 2.4 , group_size = 16 , depth = 14 , se_ratio = 0.25 ),
266267 },
267268 },
268269 "timm-regnety_016" : {
269270 "encoder" : RegNetEncoder ,
270271 "pretrained_settings" : pretrained_settings ["timm-regnety_016" ],
271272 "params" : {
272273 "out_channels" : (3 , 32 , 48 , 120 , 336 , 888 ),
273- "cfg" : _mcfg (w0 = 48 , wa = 20.71 , wm = 2.65 , group_w = 24 , depth = 27 , se_ratio = 0.25 ),
274+ "cfg" : _mcfg (w0 = 48 , wa = 20.71 , wm = 2.65 , group_size = 24 , depth = 27 , se_ratio = 0.25 ),
274275 },
275276 },
276277 "timm-regnety_032" : {
277278 "encoder" : RegNetEncoder ,
278279 "pretrained_settings" : pretrained_settings ["timm-regnety_032" ],
279280 "params" : {
280281 "out_channels" : (3 , 32 , 72 , 216 , 576 , 1512 ),
281- "cfg" : _mcfg (w0 = 80 , wa = 42.63 , wm = 2.66 , group_w = 24 , depth = 21 , se_ratio = 0.25 ),
282+ "cfg" : _mcfg (w0 = 80 , wa = 42.63 , wm = 2.66 , group_size = 24 , depth = 21 , se_ratio = 0.25 ),
282283 },
283284 },
284285 "timm-regnety_040" : {
285286 "encoder" : RegNetEncoder ,
286287 "pretrained_settings" : pretrained_settings ["timm-regnety_040" ],
287288 "params" : {
288289 "out_channels" : (3 , 32 , 128 , 192 , 512 , 1088 ),
289- "cfg" : _mcfg (w0 = 96 , wa = 31.41 , wm = 2.24 , group_w = 64 , depth = 22 , se_ratio = 0.25 ),
290+ "cfg" : _mcfg (w0 = 96 , wa = 31.41 , wm = 2.24 , group_size = 64 , depth = 22 , se_ratio = 0.25 ),
290291 },
291292 },
292293 "timm-regnety_064" : {
293294 "encoder" : RegNetEncoder ,
294295 "pretrained_settings" : pretrained_settings ["timm-regnety_064" ],
295296 "params" : {
296297 "out_channels" : (3 , 32 , 144 , 288 , 576 , 1296 ),
297- "cfg" : _mcfg (w0 = 112 , wa = 33.22 , wm = 2.27 , group_w = 72 , depth = 25 , se_ratio = 0.25 ),
298+ "cfg" : _mcfg (w0 = 112 , wa = 33.22 , wm = 2.27 , group_size = 72 , depth = 25 , se_ratio = 0.25 ),
298299 },
299300 },
300301 "timm-regnety_080" : {
301302 "encoder" : RegNetEncoder ,
302303 "pretrained_settings" : pretrained_settings ["timm-regnety_080" ],
303304 "params" : {
304305 "out_channels" : (3 , 32 , 168 , 448 , 896 , 2016 ),
305- "cfg" : _mcfg (w0 = 192 , wa = 76.82 , wm = 2.19 , group_w = 56 , depth = 17 , se_ratio = 0.25 ),
306+ "cfg" : _mcfg (w0 = 192 , wa = 76.82 , wm = 2.19 , group_size = 56 , depth = 17 , se_ratio = 0.25 ),
306307 },
307308 },
308309 "timm-regnety_120" : {
309310 "encoder" : RegNetEncoder ,
310311 "pretrained_settings" : pretrained_settings ["timm-regnety_120" ],
311312 "params" : {
312313 "out_channels" : (3 , 32 , 224 , 448 , 896 , 2240 ),
313- "cfg" : _mcfg (w0 = 168 , wa = 73.36 , wm = 2.37 , group_w = 112 , depth = 19 , se_ratio = 0.25 ),
314+ "cfg" : _mcfg (w0 = 168 , wa = 73.36 , wm = 2.37 , group_size = 112 , depth = 19 , se_ratio = 0.25 ),
314315 },
315316 },
316317 "timm-regnety_160" : {
317318 "encoder" : RegNetEncoder ,
318319 "pretrained_settings" : pretrained_settings ["timm-regnety_160" ],
319320 "params" : {
320321 "out_channels" : (3 , 32 , 224 , 448 , 1232 , 3024 ),
321- "cfg" : _mcfg (w0 = 200 , wa = 106.23 , wm = 2.48 , group_w = 112 , depth = 18 , se_ratio = 0.25 ),
322+ "cfg" : _mcfg (w0 = 200 , wa = 106.23 , wm = 2.48 , group_size = 112 , depth = 18 , se_ratio = 0.25 ),
322323 },
323324 },
324325 "timm-regnety_320" : {
325326 "encoder" : RegNetEncoder ,
326327 "pretrained_settings" : pretrained_settings ["timm-regnety_320" ],
327328 "params" : {
328329 "out_channels" : (3 , 32 , 232 , 696 , 1392 , 3712 ),
329- "cfg" : _mcfg (w0 = 232 , wa = 115.89 , wm = 2.53 , group_w = 232 , depth = 20 , se_ratio = 0.25 ),
330+ "cfg" : _mcfg (w0 = 232 , wa = 115.89 , wm = 2.53 , group_size = 232 , depth = 20 , se_ratio = 0.25 ),
330331 },
331332 },
332333}
0 commit comments