Skip to content

Commit 8d8677e

Browse files
committed
Fix #139: SKResNets were broken after the BlurPool addition. As a bonus, SKResNets now support anti-aliasing (AA) too.
1 parent 353a79a commit 8d8677e

File tree

4 files changed

+16
-10
lines changed

4 files changed

+16
-10
lines changed

timm/models/layers/conv_bn_act.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99

1010
class ConvBnAct(nn.Module):
1111
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, groups=1,
12-
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
12+
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None):
1313
super(ConvBnAct, self).__init__()
1414
padding = get_padding(kernel_size, stride, dilation) # assuming PyTorch style padding for this block
15+
use_aa = aa_layer is not None
1516
self.conv = nn.Conv2d(
16-
in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
17+
in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=1 if use_aa else stride,
1718
padding=padding, dilation=dilation, groups=groups, bias=False)
1819
self.bn = norm_layer(out_channels)
20+
self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None
1921
self.drop_block = drop_block
2022
if act_layer is not None:
2123
self.act = act_layer(inplace=True)
@@ -29,4 +31,6 @@ def forward(self, x):
2931
x = self.drop_block(x)
3032
if self.act is not None:
3133
x = self.act(x)
34+
if self.aa is not None:
35+
x = self.aa(x)
3236
return x

timm/models/layers/selective_kernel.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class SelectiveKernelConv(nn.Module):
5252

5353
def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1,
5454
attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False,
55-
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
55+
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None):
5656
""" Selective Kernel Convolution Module
5757
5858
As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications.
@@ -98,7 +98,8 @@ def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilati
9898
groups = min(out_channels, groups)
9999

100100
conv_kwargs = dict(
101-
stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
101+
stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer,
102+
aa_layer=aa_layer)
102103
self.paths = nn.ModuleList([
103104
ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
104105
for k, d in zip(kernel_size, dilation)])

timm/models/sknet.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,12 @@ class SelectiveKernelBasic(nn.Module):
4646
expansion = 1
4747

4848
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
49-
sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None,
50-
drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None):
49+
sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU,
50+
norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
5151
super(SelectiveKernelBasic, self).__init__()
5252

5353
sk_kwargs = sk_kwargs or {}
54-
conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
54+
conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer)
5555
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
5656
assert base_width == 64, 'BasicBlock doest not support changing base width'
5757
first_planes = planes // reduce_first
@@ -94,11 +94,12 @@ class SelectiveKernelBottleneck(nn.Module):
9494

9595
def __init__(self, inplanes, planes, stride=1, downsample=None,
9696
cardinality=1, base_width=64, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None,
97-
drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None):
97+
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None,
98+
drop_block=None, drop_path=None):
9899
super(SelectiveKernelBottleneck, self).__init__()
99100

100101
sk_kwargs = sk_kwargs or {}
101-
conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
102+
conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer)
102103
width = int(math.floor(planes * (base_width / 64)) * cardinality)
103104
first_planes = width // reduce_first
104105
outplanes = planes * self.expansion

timm/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.1.24'
1+
__version__ = '0.1.26'

0 commit comments

Comments (0)