Skip to content

Commit f4cdc2a

Browse files
committed
Add ResNeSt models
1 parent 8d8677e commit f4cdc2a

File tree

4 files changed

+299
-0
lines changed

4 files changed

+299
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ Included models:
130130
* Instagram trained / ImageNet tuned ResNeXt101-32x8d to 32x48d from from [facebookresearch](https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/)
131131
* Res2Net (https://github.com/gasvn/Res2Net, https://arxiv.org/abs/1904.01169)
132132
* Selective Kernel (SK) Nets (https://arxiv.org/abs/1903.06586)
133+
* ResNeSt (code adapted from https://github.com/zhanghang1989/ResNeSt, paper https://arxiv.org/abs/2004.08955)
133134
* DLA
134135
* Original (https://github.com/ucbdrive/dla, https://arxiv.org/abs/1707.06484)
135136
* Res2Net (https://github.com/gasvn/Res2Net, https://arxiv.org/abs/1904.01169)

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .hrnet import *
1919
from .sknet import *
2020
from .tresnet import *
21+
from .resnest import *
2122

2223
from .registry import *
2324
from .factory import create_model

timm/models/layers/split_attn.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
""" Split Attention Conv2d (for ResNeSt Models)
2+
3+
Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955
4+
5+
Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt
6+
7+
Modified for torchscript compat, performance, and consistency with timm by Ross Wightman
8+
"""
9+
import torch
10+
import torch.nn.functional as F
11+
from torch import nn
12+
13+
14+
class RadixSoftmax(nn.Module):
15+
def __init__(self, radix, cardinality):
16+
super(RadixSoftmax, self).__init__()
17+
self.radix = radix
18+
self.cardinality = cardinality
19+
20+
def forward(self, x):
21+
batch = x.size(0)
22+
if self.radix > 1:
23+
x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
24+
x = F.softmax(x, dim=1)
25+
x = x.reshape(batch, -1)
26+
else:
27+
x = torch.sigmoid(x)
28+
return x
29+
30+
31+
class SplitAttnConv2d(nn.Module):
32+
"""Split-Attention Conv2d
33+
"""
34+
def __init__(self, in_channels, channels, kernel_size, stride=1, padding=0,
35+
dilation=1, groups=1, bias=False, radix=2, reduction_factor=4,
36+
act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs):
37+
super(SplitAttnConv2d, self).__init__()
38+
self.radix = radix
39+
self.cardinality = groups
40+
self.channels = channels
41+
mid_chs = channels * radix
42+
attn_chs = max(in_channels * radix // reduction_factor, 32)
43+
self.conv = nn.Conv2d(
44+
in_channels, mid_chs, kernel_size, stride, padding, dilation,
45+
groups=groups * radix, bias=bias, **kwargs)
46+
self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None
47+
self.act0 = act_layer(inplace=True)
48+
self.fc1 = nn.Conv2d(channels, attn_chs, 1, groups=self.cardinality)
49+
self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None
50+
self.act1 = act_layer(inplace=True)
51+
self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=self.cardinality)
52+
self.drop_block = drop_block
53+
self.rsoftmax = RadixSoftmax(radix, groups)
54+
55+
def forward(self, x):
56+
x = self.conv(x)
57+
if self.bn0 is not None:
58+
x = self.bn0(x)
59+
if self.drop_block is not None:
60+
x = self.drop_block(x)
61+
x = self.act0(x)
62+
63+
B, RC, H, W = x.shape
64+
if self.radix > 1:
65+
x = x.reshape((B, self.radix, RC // self.radix, H, W))
66+
x_gap = torch.sum(x, dim=1)
67+
else:
68+
x_gap = x
69+
x_gap = F.adaptive_avg_pool2d(x_gap, 1)
70+
x_gap = self.fc1(x_gap)
71+
72+
if self.bn1 is not None:
73+
x_gap = self.bn1(x_gap)
74+
x_gap = self.act1(x_gap)
75+
76+
x_attn = self.fc2(x_gap)
77+
x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)
78+
79+
if self.radix > 1:
80+
out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1)
81+
else:
82+
out = x * x_attn
83+
return out.contiguous()

timm/models/resnest.py

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
""" ResNeSt Models
2+
3+
Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955
4+
5+
Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang1989/ResNeSt
6+
7+
Modified for torchscript compat, and consistency with timm by Ross Wightman
8+
"""
9+
import math
10+
import torch
11+
import torch.nn.functional as F
12+
from torch import nn
13+
14+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
15+
from timm.models.layers import DropBlock2d
16+
from .helpers import load_pretrained
17+
from .layers import SelectiveKernelConv, ConvBnAct, create_attn
18+
from .layers.split_attn import SplitAttnConv2d
19+
from .registry import register_model
20+
from .resnet import ResNet
21+
22+
23+
def _cfg(url='', **kwargs):
24+
return {
25+
'url': url,
26+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
27+
'crop_pct': 0.875, 'interpolation': 'bilinear',
28+
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
29+
'first_conv': 'conv1', 'classifier': 'fc',
30+
**kwargs
31+
}
32+
33+
default_cfgs = {
34+
'resnest26d': _cfg(
35+
url=''),
36+
'resnest50d': _cfg(
37+
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50-528c19ca.pth'),
38+
'resnest101e': _cfg(
39+
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest101-22405ba7.pth', input_size=(3, 256, 256)),
40+
'resnest200e': _cfg(
41+
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest200-75117900.pth', input_size=(3, 320, 320)),
42+
'resnest269e': _cfg(
43+
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest269-0cc87c48.pth', input_size=(3, 416, 416)),
44+
}
45+
46+
47+
class ResNestBottleneck(nn.Module):
48+
"""ResNet Bottleneck
49+
"""
50+
# pylint: disable=unused-argument
51+
expansion = 4
52+
53+
def __init__(self, inplanes, planes, stride=1, downsample=None,
54+
radix=1, cardinality=1, base_width=64, avd=False, avd_first=False, is_first=False,
55+
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
56+
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
57+
super(ResNestBottleneck, self).__init__()
58+
assert reduce_first == 1 # not supported
59+
assert attn_layer is None # not supported
60+
assert aa_layer is None # TODO not yet supported
61+
assert drop_path is None # TODO not yet supported
62+
63+
group_width = int(planes * (base_width / 64.)) * cardinality
64+
first_dilation = first_dilation or dilation
65+
if avd and (stride > 1 or is_first):
66+
avd_stride = stride
67+
stride = 1
68+
else:
69+
avd_stride = 0
70+
self.radix = radix
71+
72+
self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
73+
self.bn1 = norm_layer(group_width)
74+
self.drop_block1 = drop_block if drop_block is not None else None
75+
self.act1 = act_layer(inplace=True)
76+
self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None
77+
78+
if self.radix >= 1:
79+
self.conv2 = SplitAttnConv2d(
80+
group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
81+
dilation=first_dilation, groups=cardinality, norm_layer=norm_layer, drop_block=drop_block)
82+
self.bn2 = None # FIXME revisit, here to satisfy current torchscript fussyness
83+
self.drop_block2 = None
84+
self.act2 = None
85+
else:
86+
self.conv2 = nn.Conv2d(
87+
group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
88+
dilation=first_dilation, groups=cardinality, bias=False)
89+
self.bn2 = norm_layer(group_width)
90+
self.drop_block2 = drop_block if drop_block is not None else None
91+
self.act2 = act_layer(inplace=True)
92+
self.avd_last = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and not avd_first else None
93+
94+
self.conv3 = nn.Conv2d(group_width, planes * 4, kernel_size=1, bias=False)
95+
self.bn3 = norm_layer(planes*4)
96+
self.drop_block3 = drop_block if drop_block is not None else None
97+
self.act3 = act_layer(inplace=True)
98+
self.downsample = downsample
99+
100+
def zero_init_last_bn(self):
101+
nn.init.zeros_(self.bn3.weight)
102+
103+
def forward(self, x):
104+
residual = x
105+
106+
out = self.conv1(x)
107+
out = self.bn1(out)
108+
if self.drop_block1 is not None:
109+
out = self.drop_block1(out)
110+
out = self.act1(out)
111+
112+
if self.avd_first is not None:
113+
out = self.avd_first(out)
114+
115+
out = self.conv2(out)
116+
if self.bn2 is not None:
117+
out = self.bn2(out)
118+
if self.drop_block2 is not None:
119+
out = self.drop_block2(out)
120+
out = self.act2(out)
121+
122+
if self.avd_last is not None:
123+
out = self.avd_last(out)
124+
125+
out = self.conv3(out)
126+
out = self.bn3(out)
127+
if self.drop_block3 is not None:
128+
out = self.drop_block3(out)
129+
130+
if self.downsample is not None:
131+
residual = self.downsample(x)
132+
133+
out += residual
134+
out = self.act3(out)
135+
return out
136+
137+
138+
@register_model
139+
def resnest26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
140+
""" ResNeSt-26d model.
141+
"""
142+
default_cfg = default_cfgs['resnest26d']
143+
model = ResNet(
144+
ResNestBottleneck, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
145+
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
146+
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
147+
model.default_cfg = default_cfg
148+
if pretrained:
149+
load_pretrained(model, default_cfg, num_classes, in_chans)
150+
return model
151+
152+
153+
@register_model
154+
def resnest50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
155+
""" ResNeSt-50d model. Matches paper ResNeSt-50 model, https://arxiv.org/abs/2004.08955
156+
Since this codebase supports all possible variations, 'd' for deep stem, stem_width 32, avg in downsample.
157+
"""
158+
default_cfg = default_cfgs['resnest50d']
159+
model = ResNet(
160+
ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
161+
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
162+
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
163+
model.default_cfg = default_cfg
164+
if pretrained:
165+
load_pretrained(model, default_cfg, num_classes, in_chans)
166+
return model
167+
168+
169+
@register_model
170+
def resnest101e(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
171+
""" ResNeSt-101e model. Matches paper ResNeSt-101 model, https://arxiv.org/abs/2004.08955
172+
Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
173+
"""
174+
default_cfg = default_cfgs['resnest101e']
175+
model = ResNet(
176+
ResNestBottleneck, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans,
177+
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
178+
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
179+
model.default_cfg = default_cfg
180+
if pretrained:
181+
load_pretrained(model, default_cfg, num_classes, in_chans)
182+
return model
183+
184+
185+
@register_model
186+
def resnest200e(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
187+
""" ResNeSt-200e model. Matches paper ResNeSt-200 model, https://arxiv.org/abs/2004.08955
188+
Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
189+
"""
190+
default_cfg = default_cfgs['resnest200e']
191+
model = ResNet(
192+
ResNestBottleneck, [3, 24, 36, 3], num_classes=num_classes, in_chans=in_chans,
193+
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
194+
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
195+
model.default_cfg = default_cfg
196+
if pretrained:
197+
load_pretrained(model, default_cfg, num_classes, in_chans)
198+
return model
199+
200+
201+
@register_model
202+
def resnest269e(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
203+
""" ResNeSt-269e model. Matches paper ResNeSt-269 model, https://arxiv.org/abs/2004.08955
204+
Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
205+
"""
206+
default_cfg = default_cfgs['resnest269e']
207+
model = ResNet(
208+
ResNestBottleneck, [3, 30, 48, 8], num_classes=num_classes, in_chans=in_chans,
209+
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
210+
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
211+
model.default_cfg = default_cfg
212+
if pretrained:
213+
load_pretrained(model, default_cfg, num_classes, in_chans)
214+
return model

0 commit comments

Comments
 (0)