3838from timm .data import IMAGENET_INCEPTION_MEAN , IMAGENET_INCEPTION_STD
3939from .helpers import build_model_with_cfg , named_apply , adapt_input_conv
4040from .registry import register_model
41- from .layers import GroupNormAct , ClassifierHead , DropPath , AvgPool2dSame , create_pool2d , StdConv2d , create_conv2d
41+ from .layers import GroupNormAct , BatchNormAct2d , EvoNormBatch2d , EvoNormSample2d ,\
42+ ClassifierHead , DropPath , AvgPool2dSame , create_pool2d , StdConv2d , create_conv2d
4243
4344
4445def _cfg (url = '' , ** kwargs ):
@@ -107,6 +108,16 @@ def _cfg(url='', **kwargs):
107108 interpolation = 'bicubic' ),
108109 'resnetv2_50d' : _cfg (
109110 interpolation = 'bicubic' , first_conv = 'stem.conv1' ),
111+ 'resnetv2_50t' : _cfg (
112+ interpolation = 'bicubic' , first_conv = 'stem.conv1' ),
113+ 'resnetv2_101' : _cfg (
114+ interpolation = 'bicubic' ),
115+ 'resnetv2_101d' : _cfg (
116+ interpolation = 'bicubic' , first_conv = 'stem.conv1' ),
117+ 'resnetv2_152' : _cfg (
118+ interpolation = 'bicubic' ),
119+ 'resnetv2_152d' : _cfg (
120+ interpolation = 'bicubic' , first_conv = 'stem.conv1' ),
110121}
111122
112123
@@ -152,8 +163,8 @@ def __init__(
152163 self .conv3 = conv_layer (mid_chs , out_chs , 1 )
153164 self .drop_path = DropPath (drop_path_rate ) if drop_path_rate > 0 else nn .Identity ()
154165
def zero_init_last(self):
    """Zero the final conv's weights so each residual branch starts as identity.

    Pre-activation blocks have no norm after conv3, so the zero-init is
    applied to the conv weight itself rather than a norm gamma.
    """
    nn.init.zeros_(self.conv3.weight)
157168
158169 def forward (self , x ):
159170 x_preact = self .norm1 (x )
@@ -201,7 +212,7 @@ def __init__(
201212 self .drop_path = DropPath (drop_path_rate ) if drop_path_rate > 0 else nn .Identity ()
202213 self .act3 = act_layer (inplace = True )
203214
def zero_init_last(self):
    """Zero the last norm layer's scale (gamma) so the residual branch
    initially contributes nothing — the block starts as identity."""
    nn.init.zeros_(self.norm3.weight)
206217
207218 def forward (self , x ):
@@ -284,17 +295,20 @@ def create_resnetv2_stem(
284295 in_chs , out_chs = 64 , stem_type = '' , preact = True ,
285296 conv_layer = StdConv2d , norm_layer = partial (GroupNormAct , num_groups = 32 )):
286297 stem = OrderedDict ()
287- assert stem_type in ('' , 'fixed' , 'same' , 'deep' , 'deep_fixed' , 'deep_same' )
298+ assert stem_type in ('' , 'fixed' , 'same' , 'deep' , 'deep_fixed' , 'deep_same' , 'tiered' )
288299
289300 # NOTE conv padding mode can be changed by overriding the conv_layer def
290- if 'deep' in stem_type :
301+ if any ([ s in stem_type for s in ( 'deep' , 'tiered' )]) :
291302 # A 3 deep 3x3 conv stack as in ResNet V1D models
292- mid_chs = out_chs // 2
293- stem ['conv1' ] = conv_layer (in_chs , mid_chs , kernel_size = 3 , stride = 2 )
294- stem ['norm1' ] = norm_layer (mid_chs )
295- stem ['conv2' ] = conv_layer (mid_chs , mid_chs , kernel_size = 3 , stride = 1 )
296- stem ['norm2' ] = norm_layer (mid_chs )
297- stem ['conv3' ] = conv_layer (mid_chs , out_chs , kernel_size = 3 , stride = 1 )
303+ if 'tiered' in stem_type :
304+ stem_chs = (3 * out_chs // 8 , out_chs // 2 ) # 'T' resnets in resnet.py
305+ else :
306+ stem_chs = (out_chs // 2 , out_chs // 2 ) # 'D' ResNets
307+ stem ['conv1' ] = conv_layer (in_chs , stem_chs [0 ], kernel_size = 3 , stride = 2 )
308+ stem ['norm1' ] = norm_layer (stem_chs [0 ])
309+ stem ['conv2' ] = conv_layer (stem_chs [0 ], stem_chs [1 ], kernel_size = 3 , stride = 1 )
310+ stem ['norm2' ] = norm_layer (stem_chs [1 ])
311+ stem ['conv3' ] = conv_layer (stem_chs [1 ], out_chs , kernel_size = 3 , stride = 1 )
298312 if not preact :
299313 stem ['norm3' ] = norm_layer (out_chs )
300314 else :
@@ -326,7 +340,7 @@ def __init__(
326340 num_classes = 1000 , in_chans = 3 , global_pool = 'avg' , output_stride = 32 ,
327341 width_factor = 1 , stem_chs = 64 , stem_type = '' , avg_down = False , preact = True ,
328342 act_layer = nn .ReLU , conv_layer = StdConv2d , norm_layer = partial (GroupNormAct , num_groups = 32 ),
329- drop_rate = 0. , drop_path_rate = 0. , zero_init_last_bn = True ):
343+ drop_rate = 0. , drop_path_rate = 0. , zero_init_last = True ):
330344 super ().__init__ ()
331345 self .num_classes = num_classes
332346 self .drop_rate = drop_rate
@@ -364,10 +378,10 @@ def __init__(
364378 self .head = ClassifierHead (
365379 self .num_features , num_classes , pool_type = global_pool , drop_rate = self .drop_rate , use_conv = True )
366380
367- self .init_weights (zero_init_last_bn = zero_init_last_bn )
381+ self .init_weights (zero_init_last = zero_init_last )
368382
def init_weights(self, zero_init_last=True):
    """Initialize all submodule weights.

    Args:
        zero_init_last: if True, each residual block's ``zero_init_last``
            hook is invoked so blocks start out as identity mappings.
    """
    # Bind the flag once, then walk every named submodule (and self).
    init_fn = partial(_init_weights, zero_init_last=zero_init_last)
    named_apply(init_fn, self)
371385
372386 @torch .jit .ignore ()
373387 def load_pretrained (self , checkpoint_path , prefix = 'resnet/' ):
@@ -393,7 +407,7 @@ def forward(self, x):
393407 return x
394408
395409
396- def _init_weights (module : nn .Module , name : str = '' , zero_init_last_bn = True ):
410+ def _init_weights (module : nn .Module , name : str = '' , zero_init_last = True ):
397411 if isinstance (module , nn .Linear ) or ('head.fc' in name and isinstance (module , nn .Conv2d )):
398412 nn .init .normal_ (module .weight , mean = 0.0 , std = 0.01 )
399413 nn .init .zeros_ (module .bias )
@@ -404,8 +418,8 @@ def _init_weights(module: nn.Module, name: str = '', zero_init_last_bn=True):
404418 elif isinstance (module , (nn .BatchNorm2d , nn .LayerNorm , nn .GroupNorm )):
405419 nn .init .ones_ (module .weight )
406420 nn .init .zeros_ (module .bias )
407- elif zero_init_last_bn and hasattr (module , 'zero_init_last_bn ' ):
408- module .zero_init_last_bn ()
421+ elif zero_init_last and hasattr (module , 'zero_init_last ' ):
422+ module .zero_init_last ()
409423
410424
411425@torch .no_grad ()
@@ -570,12 +584,68 @@ def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs):
def resnetv2_50(pretrained=False, **kwargs):
    """ResNet-V2-50 with a (3, 4, 6, 3) layer config, create_conv2d convs
    and BatchNormAct2d normalization."""
    model_args = dict(layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
    return _create_resnetv2('resnetv2_50', pretrained=pretrained, **model_args, **kwargs)
574588
575589
@register_model
def resnetv2_50d(pretrained=False, **kwargs):
    """ResNet-V2-50-D: (3, 4, 6, 3) layers with a deep 3x3 stem and
    avg-pool downsampling, BatchNormAct2d normalization."""
    model_args = dict(
        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
        stem_type='deep', avg_down=True)
    return _create_resnetv2('resnetv2_50d', pretrained=pretrained, **model_args, **kwargs)
596+
597+
@register_model
def resnetv2_50t(pretrained=False, **kwargs):
    """ResNet-V2-50-T: (3, 4, 6, 3) layers with a tiered 3x3 stem and
    avg-pool downsampling, BatchNormAct2d normalization."""
    model_args = dict(
        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
        stem_type='tiered', avg_down=True)
    return _create_resnetv2('resnetv2_50t', pretrained=pretrained, **model_args, **kwargs)
604+
605+
@register_model
def resnetv2_101(pretrained=False, **kwargs):
    """ResNet-V2-101 with a (3, 4, 23, 3) layer config and BatchNormAct2d
    normalization."""
    model_args = dict(layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
    return _create_resnetv2('resnetv2_101', pretrained=pretrained, **model_args, **kwargs)
611+
612+
@register_model
def resnetv2_101d(pretrained=False, **kwargs):
    """ResNet-V2-101-D: (3, 4, 23, 3) layers with a deep 3x3 stem and
    avg-pool downsampling, BatchNormAct2d normalization."""
    model_args = dict(
        layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
        stem_type='deep', avg_down=True)
    return _create_resnetv2('resnetv2_101d', pretrained=pretrained, **model_args, **kwargs)
619+
620+
@register_model
def resnetv2_152(pretrained=False, **kwargs):
    """ResNet-V2-152 with a (3, 8, 36, 3) layer config and BatchNormAct2d
    normalization."""
    model_args = dict(layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
    return _create_resnetv2('resnetv2_152', pretrained=pretrained, **model_args, **kwargs)
626+
627+
@register_model
def resnetv2_152d(pretrained=False, **kwargs):
    """ResNet-V2-152-D: (3, 8, 36, 3) layers with a deep 3x3 stem and
    avg-pool downsampling, BatchNormAct2d normalization."""
    model_args = dict(
        layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
        stem_type='deep', avg_down=True)
    return _create_resnetv2('resnetv2_152d', pretrained=pretrained, **model_args, **kwargs)
634+
635+
636+ # @register_model
637+ # def resnetv2_50ebd(pretrained=False, **kwargs):
638+ # # FIXME for testing w/ TPU + PyTorch XLA
639+ # return _create_resnetv2(
640+ # 'resnetv2_50d', pretrained=pretrained,
641+ # layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormBatch2d,
642+ # stem_type='deep', avg_down=True, **kwargs)
643+ #
644+ #
645+ # @register_model
646+ # def resnetv2_50esd(pretrained=False, **kwargs):
647+ # # FIXME for testing w/ TPU + PyTorch XLA
648+ # return _create_resnetv2(
649+ # 'resnetv2_50d', pretrained=pretrained,
650+ # layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormSample2d,
651+ # stem_type='deep', avg_down=True, **kwargs)
0 commit comments