Commit c4ca016

Merge pull request #145 from rwightman/resnest
ResNeSt
2 parents 5bd1ad1 + 208e791 commit c4ca016

4 files changed: +353 -35 lines changed


README.md

Lines changed: 4 additions & 35 deletions
@@ -2,6 +2,9 @@
 
 ## What's New
 
+### May 12, 2020
+* Add ResNeSt models (code adapted from https://github.com/zhanghang1989/ResNeSt, paper https://arxiv.org/abs/2004.08955)
+
 ### May 3, 2020
 * Pruned EfficientNet B1, B2, and B3 (https://arxiv.org/abs/2002.08258) contributed by [Yonathan Aflalo](https://github.com/yoniaflalo)
 
@@ -70,41 +73,6 @@
 * Add RandAugment trained EfficientNet-B0 weight with 77.7 top-1. Trained by [Michael Klachko](https://github.com/michaelklachko) with this code and recent hparams (see Training section)
 * Add `avg_checkpoints.py` script for post training weight averaging and update all scripts with header docstrings and shebangs.
 
-### Dec 30, 2019
-* Merge [Dushyant Mehta's](https://github.com/mehtadushy) PR for SelecSLS (Selective Short and Long Range Skip Connections) networks. Good GPU memory consumption and throughput. Original: https://github.com/mehtadushy/SelecSLS-Pytorch
-
-### Dec 28, 2019
-* Add new model weights and training hparams (see Training Hparams section)
-  * `efficientnet_b3` - 81.5 top-1, 95.7 top-5 at default res/crop, 81.9, 95.8 at 320x320 1.0 crop-pct
-    * trained with RandAugment, ended up with an interesting but less than perfect result (see training section)
-  * `seresnext26d_32x4d` - 77.6 top-1, 93.6 top-5
-    * deep stem (32, 32, 64), avgpool downsample
-    * stem/downsample from bag-of-tricks paper
-  * `seresnext26t_32x4d` - 78.0 top-1, 93.7 top-5
-    * deep tiered stem (24, 48, 64), avgpool downsample (a modified 'D' variant)
-    * stem sizing mods from Jeremy Howard and fastai devs discussing ResNet architecture experiments
-
-### Dec 23, 2019
-* Add RandAugment trained MixNet-XL weights with 80.48 top-1.
-* `--dist-bn` argument added to train.py, will distribute BN stats between nodes after each train epoch, before eval
-
-### Dec 4, 2019
-* Added weights from the first training from scratch of an EfficientNet (B2) with my new RandAugment implementation. Much better than my previous B2 and very close to the official AdvProp ones (80.4 top-1, 95.08 top-5).
-
-### Nov 29, 2019
-* Brought EfficientNet and MobileNetV3 up to date with my https://github.com/rwightman/gen-efficientnet-pytorch code. Torchscript and ONNX export compat excluded.
-  * AdvProp weights added
-  * Official TF MobileNetv3 weights added
-  * EfficientNet and MobileNetV3 hook based 'feature extraction' classes added. Will serve as basis for using models as backbones in obj detection/segmentation tasks. Lots more to be done here...
-* HRNet classification models and weights added from https://github.com/HRNet/HRNet-Image-Classification
-* Consistency in global pooling, `reset_classifer`, and `forward_features` across models
-  * `forward_features` always returns unpooled feature maps now
-* Reasonable chance I broke something... let me know
-
-### Nov 22, 2019
-* Add ImageNet training RandAugment implementation alongside AutoAugment. PyTorch Transform compatible format, using PIL. Currently training two EfficientNet models from scratch with promising results... will update.
-* `drop-connect` cmd line arg finally added to `train.py`, no need to hack model fns. Works for efficientnet/mobilenetv3 based models, ignored otherwise.
-
 ## Introduction
 
 For each competition, personal, or freelance project involving images + Convolutional Neural Networks, I build on top of an evolving collection of code and models. This repo contains a (somewhat) cleaned up and pared down iteration of that code. Hopefully it'll be of use to others.
@@ -130,6 +98,7 @@ Included models:
 * Instagram trained / ImageNet tuned ResNeXt101-32x8d to 32x48d from [facebookresearch](https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/)
 * Res2Net (https://github.com/gasvn/Res2Net, https://arxiv.org/abs/1904.01169)
 * Selective Kernel (SK) Nets (https://arxiv.org/abs/1903.06586)
+* ResNeSt (code adapted from https://github.com/zhanghang1989/ResNeSt, paper https://arxiv.org/abs/2004.08955)
 * DLA
 * Original (https://github.com/ucbdrive/dla, https://arxiv.org/abs/1707.06484)
 * Res2Net (https://github.com/gasvn/Res2Net, https://arxiv.org/abs/1904.01169)

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
 from .hrnet import *
 from .sknet import *
 from .tresnet import *
+from .resnest import *
 
 from .registry import *
 from .factory import create_model

timm/models/layers/split_attn.py

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+""" Split Attention Conv2d (for ResNeSt Models)
+
+Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955
+
+Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt
+
+Modified for torchscript compat, performance, and consistency with timm by Ross Wightman
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class RadixSoftmax(nn.Module):
+    def __init__(self, radix, cardinality):
+        super(RadixSoftmax, self).__init__()
+        self.radix = radix
+        self.cardinality = cardinality
+
+    def forward(self, x):
+        batch = x.size(0)
+        if self.radix > 1:
+            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
+            x = F.softmax(x, dim=1)
+            x = x.reshape(batch, -1)
+        else:
+            x = torch.sigmoid(x)
+        return x
+
+
+class SplitAttnConv2d(nn.Module):
+    """Split-Attention Conv2d
+    """
+    def __init__(self, in_channels, channels, kernel_size, stride=1, padding=0,
+                 dilation=1, groups=1, bias=False, radix=2, reduction_factor=4,
+                 act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs):
+        super(SplitAttnConv2d, self).__init__()
+        self.radix = radix
+        self.cardinality = groups
+        self.channels = channels
+        mid_chs = channels * radix
+        attn_chs = max(in_channels * radix // reduction_factor, 32)
+        self.conv = nn.Conv2d(
+            in_channels, mid_chs, kernel_size, stride, padding, dilation,
+            groups=groups * radix, bias=bias, **kwargs)
+        self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None
+        self.act0 = act_layer(inplace=True)
+        self.fc1 = nn.Conv2d(channels, attn_chs, 1, groups=self.cardinality)
+        self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None
+        self.act1 = act_layer(inplace=True)
+        self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=self.cardinality)
+        self.drop_block = drop_block
+        self.rsoftmax = RadixSoftmax(radix, groups)
+
+    def forward(self, x):
+        x = self.conv(x)
+        if self.bn0 is not None:
+            x = self.bn0(x)
+        if self.drop_block is not None:
+            x = self.drop_block(x)
+        x = self.act0(x)
+
+        B, RC, H, W = x.shape
+        if self.radix > 1:
+            x = x.reshape((B, self.radix, RC // self.radix, H, W))
+            x_gap = torch.sum(x, dim=1)
+        else:
+            x_gap = x
+        x_gap = F.adaptive_avg_pool2d(x_gap, 1)
+        x_gap = self.fc1(x_gap)
+        if self.bn1 is not None:
+            x_gap = self.bn1(x_gap)
+        x_gap = self.act1(x_gap)
+        x_attn = self.fc2(x_gap)
+
+        x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)
+        if self.radix > 1:
+            out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1)
+        else:
+            out = x * x_attn
+        return out.contiguous()
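
The index bookkeeping in `RadixSoftmax.forward` is easy to misread, so here is a dependency-free sketch of what the `view` / `transpose` / `softmax` / `reshape` sequence appears to compute for a single sample. It assumes the incoming logits are laid out as (cardinality, radix, channels), which is what the `view` call implies. The function name `radix_softmax` and the use of plain lists are illustrative choices, not part of the commit.

```python
import math


def radix_softmax(logits, radix, cardinality):
    """Softmax over the radix axis for one sample, using plain lists.

    `logits` is flat and assumed to be laid out as (cardinality, radix, channels).
    The result is flat in (radix, cardinality, channels) order, mirroring the
    transpose(1, 2) followed by reshape in the module above.
    """
    if radix == 1:
        # With a single split there is nothing to normalise across,
        # so the module falls back to a sigmoid gate.
        return [1.0 / (1.0 + math.exp(-v)) for v in logits]

    ch = len(logits) // (cardinality * radix)
    out = [0.0] * len(logits)
    for c in range(cardinality):
        for k in range(ch):
            # Gather the `radix` competing logits for this (group, channel).
            idx_in = [(c * radix + r) * ch + k for r in range(radix)]
            m = max(logits[i] for i in idx_in)  # subtract max for stability
            exps = [math.exp(logits[i] - m) for i in idx_in]
            total = sum(exps)
            for r, e in enumerate(exps):
                # Write back in radix-major order.
                out[(r * cardinality + c) * ch + k] = e / total
    return out


weights = radix_softmax([0.0, 0.0, 0.0, 0.0], radix=2, cardinality=1)
print(weights)  # [0.5, 0.5, 0.5, 0.5]
```

The radix-major output order matters because `SplitAttnConv2d.forward` later reshapes the attention to `(B, radix, RC // radix, 1, 1)` and sums over the radix axis, so each output channel receives a convex combination of its `radix` splits.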
