@@ -51,23 +51,99 @@ def forward(self, *features):
5151 return super ().forward (features [- 1 ])
5252
5353
class DeepLabV3PlusDecoder(nn.Module):
    """DeepLabV3+ decoder: ASPP over the deepest encoder features, upsampled
    and fused with a projected high-resolution skip feature map.

    Args:
        encoder_channels: per-stage channel counts of the encoder; the last
            entry feeds the ASPP, the fourth-from-last is the skip branch.
        out_channels: channels produced by the ASPP and by the fused output.
        atrous_rates: three dilation rates for the ASPP atrous branches.
        output_stride: encoder output stride, 8 or 16; determines how far the
            ASPP features are upsampled before fusion.

    Raises:
        ValueError: if ``output_stride`` is not 8 or 16.
    """

    def __init__(
        self,
        encoder_channels,
        out_channels=256,
        atrous_rates=(12, 24, 36),
        output_stride=16,
    ):
        super().__init__()
        if output_stride not in {8, 16}:
            raise ValueError("Output stride should be 8 or 16, got {}.".format(output_stride))

        self.out_channels = out_channels
        self.output_stride = output_stride

        # ASPP on the deepest features, then a separable 3x3 refinement conv.
        self.aspp = nn.Sequential(
            ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True),
            SeparableConv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

        # Bring the ASPP output up to the skip feature map's spatial size.
        scale_factor = 2 if output_stride == 8 else 4
        self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor)

        highres_in_channels = encoder_channels[-4]
        highres_out_channels = 48  # proposed by authors of paper
        # 1x1 projection of the high-resolution skip features.
        self.block1 = nn.Sequential(
            nn.Conv2d(highres_in_channels, highres_out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(highres_out_channels),
            nn.ReLU(),
        )
        # Fuse the concatenated ASPP + skip features with a separable 3x3 conv.
        self.block2 = nn.Sequential(
            SeparableConv2d(
                highres_out_channels + out_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, *features):
        """Return fused decoder features from the encoder feature pyramid."""
        aspp_features = self.aspp(features[-1])
        aspp_features = self.up(aspp_features)
        high_res_features = self.block1(features[-4])
        concat_features = torch.cat([aspp_features, high_res_features], dim=1)
        fused_features = self.block2(concat_features)
        return fused_features
106+
class ASPPConv(nn.Sequential):
    """3x3 atrous-convolution branch of ASPP: conv -> BN -> ReLU.

    ``padding == dilation`` keeps the spatial size unchanged for a 3x3 kernel.
    """

    def __init__(self, in_channels, out_channels, dilation):
        super().__init__(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
121+
122+
class ASPPSeparableConv(nn.Sequential):
    """Depthwise-separable 3x3 atrous branch of ASPP: sep-conv -> BN -> ReLU.

    Same contract as ``ASPPConv`` but uses ``SeparableConv2d`` to reduce
    parameters and FLOPs; ``padding == dilation`` preserves spatial size.
    """

    def __init__(self, in_channels, out_channels, dilation):
        super().__init__(
            SeparableConv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
62137
63138
class ASPPPooling(nn.Sequential):
    """Image-pooling branch of ASPP: global average pool -> 1x1 conv -> BN -> ReLU."""

    def __init__(self, in_channels, out_channels):
        # Modernized from super(ASPPPooling, self) for consistency with the
        # rest of the file.
        super().__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
71147
72148 def forward (self , x ):
73149 size = x .shape [- 2 :]
@@ -77,31 +153,68 @@ def forward(self, x):
77153
78154
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling (DeepLab).

    Runs five parallel branches over the input — a 1x1 conv, three atrous 3x3
    convs (optionally depthwise-separable), and global image pooling — then
    projects the concatenated result back down to ``out_channels``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of each branch and of the projected output.
        atrous_rates: exactly three dilation rates for the atrous branches.
        separable: if True, use ``ASPPSeparableConv`` for the atrous branches.
    """

    def __init__(self, in_channels, out_channels, atrous_rates, separable=False):
        # Modernized from super(ASPP, self) for consistency with the file.
        super().__init__()
        modules = []
        # 1x1 conv branch.
        modules.append(
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            )
        )

        rate1, rate2, rate3 = tuple(atrous_rates)
        ASPPConvModule = ASPPConv if not separable else ASPPSeparableConv

        modules.append(ASPPConvModule(in_channels, out_channels, rate1))
        modules.append(ASPPConvModule(in_channels, out_channels, rate2))
        modules.append(ASPPConvModule(in_channels, out_channels, rate3))
        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        # Project the 5 concatenated branch outputs back to out_channels.
        self.project = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        """Apply all branches to ``x``, concatenate on channels, and project."""
        res = []
        for conv in self.convs:
            res.append(conv(x))
        res = torch.cat(res, dim=1)
        return self.project(res)
190+
191+
class SeparableConv2d(nn.Sequential):
    """Depthwise-separable 2D convolution.

    A depthwise conv (``groups=in_channels``, no bias) that filters each input
    channel spatially, followed by a 1x1 pointwise conv that mixes channels.
    ``bias`` applies to the pointwise conv only.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
    ):
        # Per-channel spatial filtering (fixed "dephtwise" typo in the name).
        depthwise_conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=False,
        )
        # 1x1 conv mixing channels.
        pointwise_conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=bias,
        )
        super().__init__(depthwise_conv, pointwise_conv)
0 commit comments