Skip to content

Commit 2f884a0

Browse files
committed
Add resnest14, resnest26, and two of the abalation grouped resnest50 models
1 parent f4cdc2a commit 2f884a0

File tree

1 file changed

+41
-5
lines changed

1 file changed

+41
-5
lines changed

timm/models/resnest.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
""" ResNeSt Models
22
3-
Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955
3+
Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955
44
5-
Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang1989/ResNeSt
5+
Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang1989/ResNeSt by Hang Zhang
66
77
Modified for torchscript compat, and consistency with timm by Ross Wightman
88
"""
@@ -31,8 +31,10 @@ def _cfg(url='', **kwargs):
3131
}
3232

3333
default_cfgs = {
34+
'resnest14d': _cfg(
35+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth'),
3436
'resnest26d': _cfg(
35-
url=''),
37+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth'),
3638
'resnest50d': _cfg(
3739
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50-528c19ca.pth'),
3840
'resnest101e': _cfg(
@@ -41,6 +43,12 @@ def _cfg(url='', **kwargs):
4143
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest200-75117900.pth', input_size=(3, 320, 320)),
4244
'resnest269e': _cfg(
4345
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest269-0cc87c48.pth', input_size=(3, 416, 416)),
46+
'resnest50d_4s2x40d': _cfg(
47+
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50_fast_4s2x40d-41d14ed0.pth',
48+
interpolation='bicubic'),
49+
'resnest50d_1s4x24d': _cfg(
50+
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50_fast_1s4x24d-d4a4f76f.pth',
51+
interpolation='bicubic')
4452
}
4553

4654

@@ -78,7 +86,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None,
7886
if self.radix >= 1:
7987
self.conv2 = SplitAttnConv2d(
8088
group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
81-
dilation=first_dilation, groups=cardinality, norm_layer=norm_layer, drop_block=drop_block)
89+
dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_block=drop_block)
8290
self.bn2 = None # FIXME revisit, here to satisfy current torchscript fussyness
8391
self.drop_block2 = None
8492
self.act2 = None
@@ -135,9 +143,24 @@ def forward(self, x):
135143
return out
136144

137145

146+
@register_model
147+
def resnest14d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
148+
""" ResNeSt-14d model. Weights ported from GluonCV.
149+
"""
150+
default_cfg = default_cfgs['resnest14d']
151+
model = ResNet(
152+
ResNestBottleneck, [1, 1, 1, 1], num_classes=num_classes, in_chans=in_chans,
153+
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
154+
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
155+
model.default_cfg = default_cfg
156+
if pretrained:
157+
load_pretrained(model, default_cfg, num_classes, in_chans)
158+
return model
159+
160+
138161
@register_model
139162
def resnest26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
140-
""" ResNeSt-26d model.
163+
""" ResNeSt-26d model. Weights ported from GluonCV.
141164
"""
142165
default_cfg = default_cfgs['resnest26d']
143166
model = ResNet(
@@ -212,3 +235,16 @@ def resnest269e(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
212235
if pretrained:
213236
load_pretrained(model, default_cfg, num_classes, in_chans)
214237
return model
238+
239+
240+
@register_model
241+
def resnest50d_1s4x24d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
242+
default_cfg = default_cfgs['resnest50d_1s4x24d']
243+
model = ResNet(
244+
ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
245+
stem_type='deep', stem_width=32, avg_down=True, base_width=24, cardinality=4,
246+
block_args=dict(radix=1, avd=True, avd_first=True), **kwargs)
247+
model.default_cfg = default_cfg
248+
if pretrained:
249+
load_pretrained(model, default_cfg, num_classes, in_chans)
250+
return model

0 commit comments

Comments
 (0)