@@ -128,6 +128,8 @@ julia> Flux.params(c1) |> length
128128"""
129129function Conv (w:: AbstractArray{T,N} , b = true , σ = identity;
130130 stride = 1 , pad = 0 , dilation = 1 , groups = 1 ) where {T,N}
131+
132+ @assert size (w, N) % groups == 0 " Output channel dimension must be divisible by groups."
131133 stride = expand (Val (N- 2 ), stride)
132134 dilation = expand (Val (N- 2 ), dilation)
133135 pad = calc_padding (Conv, pad, size (w)[1 : N- 2 ], dilation, stride)
@@ -151,12 +153,12 @@ channels from `in` to `out`.
151153
152154Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling
153155distribution.
154-
155- See also: [`depthwiseconvfilter`](@ref)
156156"""
157157function convfilter (filter:: NTuple{N,Integer} , ch:: Pair{<:Integer,<:Integer} ;
158158 init = glorot_uniform, groups = 1 ) where N
159159 cin, cout = ch
160+ @assert cin % groups == 0 " Input channel dimension must be divisible by groups."
161+ @assert cout % groups == 0 " Output channel dimension must be divisible by groups."
160162 init (filter... , cin÷ groups, cout)
161163end
162164
@@ -298,91 +300,37 @@ end
298300
299301"""
300302 DepthwiseConv(filter, in => out, σ=identity; stride=1, pad=0, dilation=1, [bias, init])
303+ DepthwiseConv(weight::AbstractArray, [bias, activation; stride, pad, dilation])
304+
305+ Return a depthwise convolutional layer, that is a [`Conv`](@ref) layer with number of
306+ groups equal to the number of input channels.
301307
302- Depthwise convolutional layer. `filter` is a tuple of integers
303- specifying the size of the convolutional kernel, while
304- `in` and `out` specify the number of input and output channels.
305-
306- Note that `out` must be an integer multiple of `in`.
307-
308- Parameters are controlled by additional keywords, with defaults
309- `init=glorot_uniform` and `bias=true`.
310-
311- See also [`Conv`](@ref) for more detailed description of keywords.
308+ See [`Conv`](@ref) for a description of the arguments.
312309
313310# Examples
311+
314312```jldoctest
315313julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images
316314
317315julia> lay = DepthwiseConv((5,5), 3 => 6, relu; bias=false)
318- DepthwiseConv ((5, 5), 3 => 6, relu, bias=false) # 150 parameters
316+ Conv ((5, 5), 3 => 6, relu, groups=3, bias=false) # 150 parameters
319317
320318julia> lay(xs) |> size
321319(96, 96, 6, 50)
322320
323- julia> DepthwiseConv((5,5), 3 => 9, stride=2, pad=2)(xs) |> size
321+ julia> DepthwiseConv((5, 5), 3 => 9, stride=2, pad=2)(xs) |> size
324322(50, 50, 9, 50)
325323```
326324"""
327- struct DepthwiseConv{N,M,F,A,V}
328- σ:: F
329- weight:: A
330- bias:: V
331- stride:: NTuple{N,Int}
332- pad:: NTuple{M,Int}
333- dilation:: NTuple{N,Int}
325+ function DepthwiseConv (k:: NTuple{<:Any,Integer} , ch:: Pair{<:Integer,<:Integer} , σ = identity;
326+ stride = 1 , pad = 0 , dilation = 1 , bias = true , init = glorot_uniform)
327+ Conv (k, ch, σ; groups= ch. first, stride, pad, dilation, bias, init)
334328end
335329
336- """
337- DepthwiseConv(weight::AbstractArray, [bias, activation; stride, pad, dilation])
338-
339- Constructs a layer with the given weight and bias arrays.
340- Accepts the same keywords as the `DepthwiseConv((4,4), 3 => 6, relu)` method.
341- """
342330function DepthwiseConv (w:: AbstractArray{T,N} , bias = true , σ = identity;
343- stride = 1 , pad = 0 , dilation = 1 ) where {T,N}
344- stride = expand (Val (N- 2 ), stride)
345- dilation = expand (Val (N- 2 ), dilation)
346- pad = calc_padding (DepthwiseConv, pad, size (w)[1 : N- 2 ], dilation, stride)
347- b = create_bias (w, bias, prod (size (w)[N- 1 : end ]))
348- return DepthwiseConv (σ, w, b, stride, pad, dilation)
349- end
350-
351- function DepthwiseConv (k:: NTuple{N,Integer} , ch:: Pair{<:Integer,<:Integer} , σ = identity;
352- init = glorot_uniform, stride = 1 , pad = 0 , dilation = 1 ,
353- bias = true ) where N
354- @assert ch[2 ] % ch[1 ] == 0 " Output channels must be integer multiple of input channels"
355- weight = depthwiseconvfilter (k, ch, init = init)
356- return DepthwiseConv (weight, bias, σ; stride, pad, dilation)
357- end
358-
359- @functor DepthwiseConv
360-
361- """
362- depthwiseconvfilter(filter::Tuple, in => out)
363-
364- Constructs a depthwise convolutional weight array defined by `filter` and channels
365- from `in` to `out`.
366-
367- Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling
368- distribution.
369-
370- See also: [`convfilter`](@ref)
371- """
372- depthwiseconvfilter (filter:: NTuple{N,Integer} , ch:: Pair{<:Integer,<:Integer} ;
373- init = glorot_uniform) where N = init (filter... , div (ch[2 ], ch[1 ]), ch[1 ])
374-
375- function (c:: DepthwiseConv )(x)
376- σ = NNlib. fast_act (c. σ, x)
377- cdims = DepthwiseConvDims (x, c. weight; stride= c. stride, padding= c. pad, dilation= c. dilation)
378- σ .(depthwiseconv (x, c. weight, cdims) .+ conv_reshape_bias (c))
379- end
380-
381- function Base. show (io:: IO , l:: DepthwiseConv )
382- print (io, " DepthwiseConv(" , size (l. weight)[1 : end - 2 ])
383- print (io, " , " , size (l. weight)[end ], " => " , prod (size (l. weight)[end - 1 : end ]))
384- _print_conv_opt (io, l)
385- print (io, " )" )
331+ stride = 1 , pad = 0 , dilation = 1 ) where {T,N}
332+ w2 = reshape (w, size (w)[1 : end - 2 ]. .. , 1 , :)
333+ Conv (w2, bias, σ; groups = size (w)[end - 1 ], stride, pad, dilation)
386334end
387335
388336
0 commit comments