Skip to content

Commit e8045e7

Browse files
committed
Fix BatchNorm for ResNetV2 non-GN models, add more ResNetV2 model defs for future experimentation, fix zero-init of the last residual layer for pre-act blocks.
1 parent 02aaa78 commit e8045e7

File tree

1 file changed

+91
-21
lines changed

1 file changed

+91
-21
lines changed

timm/models/resnetv2.py

Lines changed: 91 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@
3838
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
3939
from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
4040
from .registry import register_model
41-
from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d
41+
from .layers import GroupNormAct, BatchNormAct2d, EvoNormBatch2d, EvoNormSample2d,\
42+
ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d
4243

4344

4445
def _cfg(url='', **kwargs):
@@ -107,6 +108,16 @@ def _cfg(url='', **kwargs):
107108
interpolation='bicubic'),
108109
'resnetv2_50d': _cfg(
109110
interpolation='bicubic', first_conv='stem.conv1'),
111+
'resnetv2_50t': _cfg(
112+
interpolation='bicubic', first_conv='stem.conv1'),
113+
'resnetv2_101': _cfg(
114+
interpolation='bicubic'),
115+
'resnetv2_101d': _cfg(
116+
interpolation='bicubic', first_conv='stem.conv1'),
117+
'resnetv2_152': _cfg(
118+
interpolation='bicubic'),
119+
'resnetv2_152d': _cfg(
120+
interpolation='bicubic', first_conv='stem.conv1'),
110121
}
111122

112123

@@ -152,8 +163,8 @@ def __init__(
152163
self.conv3 = conv_layer(mid_chs, out_chs, 1)
153164
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
154165

155-
def zero_init_last_bn(self):
156-
nn.init.zeros_(self.norm3.weight)
166+
def zero_init_last(self):
    # Pre-act block: there is no norm after conv3, so zero the final conv's
    # weight directly. The residual branch then outputs zero at init, making
    # the whole block an identity mapping at the start of training.
    nn.init.zeros_(self.conv3.weight)
157168

158169
def forward(self, x):
159170
x_preact = self.norm1(x)
@@ -201,7 +212,7 @@ def __init__(
201212
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
202213
self.act3 = act_layer(inplace=True)
203214

204-
def zero_init_last_bn(self):
215+
def zero_init_last(self):
    # Zero the last norm layer's scale (gamma) so the residual branch
    # contributes nothing at init and the block starts as an identity mapping.
    nn.init.zeros_(self.norm3.weight)
206217

207218
def forward(self, x):
@@ -284,17 +295,20 @@ def create_resnetv2_stem(
284295
in_chs, out_chs=64, stem_type='', preact=True,
285296
conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
286297
stem = OrderedDict()
287-
assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same')
298+
assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered')
288299

289300
# NOTE conv padding mode can be changed by overriding the conv_layer def
290-
if 'deep' in stem_type:
301+
if any([s in stem_type for s in ('deep', 'tiered')]):
291302
# A 3 deep 3x3 conv stack as in ResNet V1D models
292-
mid_chs = out_chs // 2
293-
stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
294-
stem['norm1'] = norm_layer(mid_chs)
295-
stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
296-
stem['norm2'] = norm_layer(mid_chs)
297-
stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1)
303+
if 'tiered' in stem_type:
304+
stem_chs = (3 * out_chs // 8, out_chs // 2) # 'T' resnets in resnet.py
305+
else:
306+
stem_chs = (out_chs // 2, out_chs // 2) # 'D' ResNets
307+
stem['conv1'] = conv_layer(in_chs, stem_chs[0], kernel_size=3, stride=2)
308+
stem['norm1'] = norm_layer(stem_chs[0])
309+
stem['conv2'] = conv_layer(stem_chs[0], stem_chs[1], kernel_size=3, stride=1)
310+
stem['norm2'] = norm_layer(stem_chs[1])
311+
stem['conv3'] = conv_layer(stem_chs[1], out_chs, kernel_size=3, stride=1)
298312
if not preact:
299313
stem['norm3'] = norm_layer(out_chs)
300314
else:
@@ -326,7 +340,7 @@ def __init__(
326340
num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
327341
width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
328342
act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
329-
drop_rate=0., drop_path_rate=0., zero_init_last_bn=True):
343+
drop_rate=0., drop_path_rate=0., zero_init_last=True):
330344
super().__init__()
331345
self.num_classes = num_classes
332346
self.drop_rate = drop_rate
@@ -364,10 +378,10 @@ def __init__(
364378
self.head = ClassifierHead(
365379
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
366380

367-
self.init_weights(zero_init_last_bn=zero_init_last_bn)
381+
self.init_weights(zero_init_last=zero_init_last)
368382

369-
def init_weights(self, zero_init_last_bn=True):
370-
named_apply(partial(_init_weights, zero_init_last_bn=zero_init_last_bn), self)
383+
def init_weights(self, zero_init_last=True):
    # Walk all submodules (with their qualified names) and apply the module-level
    # _init_weights; zero_init_last controls zeroing of each block's final layer.
    named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
371385

372386
@torch.jit.ignore()
373387
def load_pretrained(self, checkpoint_path, prefix='resnet/'):
@@ -393,7 +407,7 @@ def forward(self, x):
393407
return x
394408

395409

396-
def _init_weights(module: nn.Module, name: str = '', zero_init_last_bn=True):
410+
def _init_weights(module: nn.Module, name: str = '', zero_init_last=True):
397411
if isinstance(module, nn.Linear) or ('head.fc' in name and isinstance(module, nn.Conv2d)):
398412
nn.init.normal_(module.weight, mean=0.0, std=0.01)
399413
nn.init.zeros_(module.bias)
@@ -404,8 +418,8 @@ def _init_weights(module: nn.Module, name: str = '', zero_init_last_bn=True):
404418
elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
405419
nn.init.ones_(module.weight)
406420
nn.init.zeros_(module.bias)
407-
elif zero_init_last_bn and hasattr(module, 'zero_init_last_bn'):
408-
module.zero_init_last_bn()
421+
elif zero_init_last and hasattr(module, 'zero_init_last'):
422+
module.zero_init_last()
409423

410424

411425
@torch.no_grad()
@@ -570,12 +584,68 @@ def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs):
570584
def resnetv2_50(pretrained=False, **kwargs):
    """ResNet-V2-50 using BatchNormAct2d norm layers and standard convs.

    Stage depths [3, 4, 6, 3]; any extra kwargs are forwarded to the model builder.
    """
    model_args = dict(layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
    return _create_resnetv2('resnetv2_50', pretrained=pretrained, **model_args, **kwargs)
574588

575589

576590
@register_model
def resnetv2_50d(pretrained=False, **kwargs):
    """ResNet-V2-50-D: BatchNormAct2d variant with 'deep' 3x3 stem and avg-pool downsample.

    Stage depths [3, 4, 6, 3]; extra kwargs are forwarded to the model builder.
    """
    model_args = dict(
        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
        stem_type='deep', avg_down=True)
    return _create_resnetv2('resnetv2_50d', pretrained=pretrained, **model_args, **kwargs)
596+
597+
598+
@register_model
def resnetv2_50t(pretrained=False, **kwargs):
    """ResNet-V2-50-T: BatchNormAct2d variant with 'tiered' stem and avg-pool downsample.

    Stage depths [3, 4, 6, 3]; extra kwargs are forwarded to the model builder.
    """
    model_args = dict(
        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
        stem_type='tiered', avg_down=True)
    return _create_resnetv2('resnetv2_50t', pretrained=pretrained, **model_args, **kwargs)
604+
605+
606+
@register_model
def resnetv2_101(pretrained=False, **kwargs):
    """ResNet-V2-101 using BatchNormAct2d norm layers and standard convs.

    Stage depths [3, 4, 23, 3]; extra kwargs are forwarded to the model builder.
    """
    model_args = dict(layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
    return _create_resnetv2('resnetv2_101', pretrained=pretrained, **model_args, **kwargs)
611+
612+
613+
@register_model
def resnetv2_101d(pretrained=False, **kwargs):
    """ResNet-V2-101-D: BatchNormAct2d variant with 'deep' stem and avg-pool downsample.

    Stage depths [3, 4, 23, 3]; extra kwargs are forwarded to the model builder.
    """
    model_args = dict(
        layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
        stem_type='deep', avg_down=True)
    return _create_resnetv2('resnetv2_101d', pretrained=pretrained, **model_args, **kwargs)
619+
620+
621+
@register_model
def resnetv2_152(pretrained=False, **kwargs):
    """ResNet-V2-152 using BatchNormAct2d norm layers and standard convs.

    Stage depths [3, 8, 36, 3]; extra kwargs are forwarded to the model builder.
    """
    model_args = dict(layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
    return _create_resnetv2('resnetv2_152', pretrained=pretrained, **model_args, **kwargs)
626+
627+
628+
@register_model
def resnetv2_152d(pretrained=False, **kwargs):
    """ResNet-V2-152-D: BatchNormAct2d variant with 'deep' stem and avg-pool downsample.

    Stage depths [3, 8, 36, 3]; extra kwargs are forwarded to the model builder.
    """
    model_args = dict(
        layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
        stem_type='deep', avg_down=True)
    return _create_resnetv2('resnetv2_152d', pretrained=pretrained, **model_args, **kwargs)
634+
635+
636+
# @register_model
637+
# def resnetv2_50ebd(pretrained=False, **kwargs):
638+
# # FIXME for testing w/ TPU + PyTorch XLA
639+
# return _create_resnetv2(
640+
# 'resnetv2_50d', pretrained=pretrained,
641+
# layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormBatch2d,
642+
# stem_type='deep', avg_down=True, **kwargs)
643+
#
644+
#
645+
# @register_model
646+
# def resnetv2_50esd(pretrained=False, **kwargs):
647+
# # FIXME for testing w/ TPU + PyTorch XLA
648+
# return _create_resnetv2(
649+
# 'resnetv2_50d', pretrained=pretrained,
650+
# layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormSample2d,
651+
# stem_type='deep', avg_down=True, **kwargs)

0 commit comments

Comments
 (0)