@@ -46,12 +46,12 @@ class SelectiveKernelBasic(nn.Module):
4646 expansion = 1
4747
4848 def __init__ (self , inplanes , planes , stride = 1 , downsample = None , cardinality = 1 , base_width = 64 ,
49- sk_kwargs = None , reduce_first = 1 , dilation = 1 , first_dilation = None ,
50- drop_block = None , drop_path = None , act_layer = nn . ReLU , norm_layer = nn . BatchNorm2d , attn_layer = None ):
49+ sk_kwargs = None , reduce_first = 1 , dilation = 1 , first_dilation = None , act_layer = nn . ReLU ,
50+ norm_layer = nn . BatchNorm2d , attn_layer = None , aa_layer = None , drop_block = None , drop_path = None ):
5151 super (SelectiveKernelBasic , self ).__init__ ()
5252
5353 sk_kwargs = sk_kwargs or {}
54- conv_kwargs = dict (drop_block = drop_block , act_layer = act_layer , norm_layer = norm_layer )
54+ conv_kwargs = dict (drop_block = drop_block , act_layer = act_layer , norm_layer = norm_layer , aa_layer = aa_layer )
5555 assert cardinality == 1 , 'BasicBlock only supports cardinality of 1'
5656 assert base_width == 64 , 'BasicBlock doest not support changing base width'
5757 first_planes = planes // reduce_first
@@ -94,11 +94,12 @@ class SelectiveKernelBottleneck(nn.Module):
9494
9595 def __init__ (self , inplanes , planes , stride = 1 , downsample = None ,
9696 cardinality = 1 , base_width = 64 , sk_kwargs = None , reduce_first = 1 , dilation = 1 , first_dilation = None ,
97- drop_block = None , drop_path = None , act_layer = nn .ReLU , norm_layer = nn .BatchNorm2d , attn_layer = None ):
97+ act_layer = nn .ReLU , norm_layer = nn .BatchNorm2d , attn_layer = None , aa_layer = None ,
98+ drop_block = None , drop_path = None ):
9899 super (SelectiveKernelBottleneck , self ).__init__ ()
99100
100101 sk_kwargs = sk_kwargs or {}
101- conv_kwargs = dict (drop_block = drop_block , act_layer = act_layer , norm_layer = norm_layer )
102+ conv_kwargs = dict (drop_block = drop_block , act_layer = act_layer , norm_layer = norm_layer , aa_layer = aa_layer )
102103 width = int (math .floor (planes * (base_width / 64 )) * cardinality )
103104 first_planes = width // reduce_first
104105 outplanes = planes * self .expansion
0 commit comments