1 change: 1 addition & 0 deletions src/Flux.jl
@@ -44,6 +44,7 @@ include("functor.jl")
# Pirate error to catch a common mistake.
Functors.functor(::Type{<:MLUtils.DataLoader}, x) = error("`DataLoader` does not support Functors.jl, thus functions like `Flux.gpu` will not act on its contents.")

include("layers/types.jl")
include("layers/stateless.jl")
include("layers/basic.jl")
include("layers/conv.jl")
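The new file `src/layers/types.jl` is collapsed in this view, so its contents are not shown. As a rough, hypothetical sketch (assumptions only, not the PR's actual code) of how the supertypes used throughout this diff (`ContainerLayer`, `SimpleLayer`, `NoTrainLayer` and `PartialTrainLayer{which}`) could replace the per-layer `@functor` and `trainable` definitions:

```julia
# Hypothetical sketch only -- the actual src/layers/types.jl is not expanded in this diff.
# Assumed idea: an abstract-type hierarchy that lets one generic Functors/Optimisers
# definition replace the per-layer `@functor Foo` and `trainable(::Foo)` methods.

using Functors, Optimisers

abstract type AbstractLayer end
abstract type ContainerLayer <: AbstractLayer end            # Chain, Parallel, Maxout, ...
abstract type SimpleLayer    <: AbstractLayer end            # Dense, Conv, Embedding, ...
abstract type NoTrainLayer   <: AbstractLayer end            # Dropout, AlphaDropout
abstract type PartialTrainLayer{which} <: AbstractLayer end  # BatchNorm etc., which = (:β, :γ)

# Every field becomes a Functors child, rebuilt with the default positional constructor;
# this is roughly what the removed `@functor` calls used to generate per layer.
function Functors.functor(::Type{T}, x) where {T<:AbstractLayer}
  names = fieldnames(T)
  children = NamedTuple{names}(getfield.(Ref(x), names))
  return children, y -> Base.typename(T).wrapper(y...)
end

# Trainable parameters: the Optimisers.jl default (all children) suits Simple/Container
# layers; only the two restricted cases need explicit methods.
Optimisers.trainable(::NoTrainLayer) = (;)
Optimisers.trainable(m::PartialTrainLayer{which}) where {which} =
    NamedTuple{which}(getfield.(Ref(m), which))
```

Under such a scheme, `Dropout <: NoTrainLayer` keeps its state out of training and `BatchNorm <: PartialTrainLayer{(:β, :γ)}` trains only the affine parameters, matching the `trainable` methods deleted further down.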
37 changes: 10 additions & 27 deletions src/layers/basic.jl
@@ -1,3 +1,4 @@

"""
Chain(layers...)
Chain(name = layer, ...)
@@ -32,7 +33,7 @@ For large models, there is a special type-unstable path which can reduce compilation
times. This can be used by supplying a vector of layers `Chain([layer1, layer2, ...])`.
This feature is somewhat experimental, beware!
"""
struct Chain{T<:Union{Tuple, NamedTuple, AbstractVector}}
struct Chain{T<:Union{Tuple, NamedTuple, AbstractVector}} <: ContainerLayer
layers::T
end

@@ -46,8 +47,6 @@ end
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
Base.iterate, Base.lastindex, Base.keys, Base.firstindex

@functor Chain

(c::Chain)(x) = _applychain(c.layers, x)

@generated function _applychain(layers::Tuple{Vararg{<:Any,N}}, x) where {N}
@@ -150,7 +149,7 @@ julia> Flux.params(d1) # no trainable bias
Params([[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]])
```
"""
struct Dense{F, M<:AbstractMatrix, B}
struct Dense{F, M<:AbstractMatrix, B} <: SimpleLayer
weight::M
bias::B
σ::F
@@ -165,8 +164,6 @@ function Dense((in, out)::Pair{<:Integer, <:Integer}, σ = identity;
Dense(init(out, in), bias, σ)
end

@functor Dense

function (a::Dense)(x::AbstractVecOrMat)
σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
return σ.(a.weight * x .+ a.bias)
@@ -223,7 +220,7 @@ julia> Flux.params(b)
Params([[1 2 3 4]])
```
"""
struct Scale{F, A<:AbstractArray, B}
struct Scale{F, A<:AbstractArray, B} <: SimpleLayer
scale::A
bias::B
σ::F
@@ -236,8 +233,6 @@ end
Scale(s1::Integer, s23::Integer...; bias = true, init = ones32, _act = identity) = Scale(init(s1, s23...), bias, _act)
Scale(size_act...; bias = true, init = ones32) = Scale(size_act[1:end-1]...; bias, init, _act = size_act[end])

@functor Scale

function (a::Scale)(x::AbstractArray)
σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
σ.(a.scale .* x .+ a.bias)
@@ -285,14 +280,12 @@ julia> Flux.outputsize(m3, (5, 11))
(7, 11)
```
"""
struct Maxout{T<:Tuple}
struct Maxout{T<:Tuple} <: ContainerLayer
layers::T
end
Maxout(layers...) = Maxout(layers)
Maxout(f::Function, n_alts::Integer) = Maxout((f() for _ in 1:n_alts)...)

@functor Maxout

function (mo::Maxout)(input::AbstractArray)
# Perhaps surprisingly, pairwise max broadcast is often faster,
# even with Zygote. See #698 and #1794
@@ -333,13 +326,11 @@ true

See also [`Parallel`](@ref), [`Maxout`](@ref).
"""
struct SkipConnection{T,F}
struct SkipConnection{T,F} <: ContainerLayer
layers::T
connection::F #user can pass arbitrary connections here, such as (a,b) -> a + b
end

@functor SkipConnection

function (skip::SkipConnection)(input)
skip.connection(skip.layers(input), input)
end
@@ -397,7 +388,7 @@ julia> Flux.Bilinear(rand(4,8,16), false, tanh) # first dim of weight is the output
Bilinear((8, 16) => 4, tanh; bias=false) # 512 parameters
```
"""
struct Bilinear{F,A,B}
struct Bilinear{F,A,B} <: SimpleLayer
weight::A
bias::B
σ::F
@@ -408,8 +399,6 @@ struct Bilinear{F,A,B}
end
end

@functor Bilinear

function Bilinear(((in1, in2), out)::Pair{<:Tuple, <:Integer}, σ = identity;
bias = true, init = glorot_uniform)
Bilinear(init(out, in1, in2), bias, σ)
@@ -492,7 +481,7 @@ julia> model2[:β] == model2[2]
true
```
"""
struct Parallel{F, T<:Union{Tuple, NamedTuple}}
struct Parallel{F, T<:Union{Tuple, NamedTuple}} <: ContainerLayer
connection::F
layers::T
end
@@ -507,8 +496,6 @@ function Parallel(connection; kw...)
Parallel(connection, layers)
end

@functor Parallel

(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
(m::Parallel)(xs::Tuple) = m(xs...)

@@ -582,7 +569,7 @@ end

A tuple of length N with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
"""
struct PairwiseFusion{F, T<:Union{Tuple, NamedTuple}}
struct PairwiseFusion{F, T<:Union{Tuple, NamedTuple}} <: ContainerLayer
connection::F
layers::T
end
@@ -628,8 +615,6 @@ end
end
applypairwisefusion(layers::NamedTuple, connection, x) = applypairwisefusion(Tuple(layers), connection, x)

@functor PairwiseFusion

Base.getindex(m::PairwiseFusion, i) = m.layers[i]
Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i])
Base.getindex(m::PairwiseFusion{<:Any, <:NamedTuple}, i::AbstractVector) =
@@ -672,12 +657,10 @@ julia> model(vocab_idxs) == model(x)
true
```
"""
struct Embedding{W}
struct Embedding{W} <: SimpleLayer
weight::W
end

@functor Embedding

Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in))

(m::Embedding)(x::Integer) = m.weight[:, x]
12 changes: 3 additions & 9 deletions src/layers/conv.jl
@@ -117,7 +117,7 @@ julia> Conv((5,5), 3 => 7; stride = 2, dilation = 4)(xs) |> size
(42, 42, 7, 50)
```
"""
struct Conv{N,M,F,A,V}
struct Conv{N,M,F,A,V} <: SimpleLayer
σ::F
weight::A
bias::V
@@ -187,8 +185,6 @@ function convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
init(filter..., cin÷groups, cout)
end

@functor Conv

conv_dims(c::Conv, x::AbstractArray) =
DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups)

@@ -252,7 +250,7 @@ julia> ConvTranspose((5,5), 3 => 7, stride=3, pad=SamePad())(xs) |> size
(300, 300, 7, 50)
```
"""
struct ConvTranspose{N,M,F,A,V}
struct ConvTranspose{N,M,F,A,V} <: SimpleLayer
σ::F
weight::A
bias::V
@@ -307,8 +305,6 @@ function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
ConvTranspose(weight, bias, σ; stride, pad, dilation, groups)
end

@functor ConvTranspose

function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
# Calculate size of "input", from ∇conv_data()'s perspective...
combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
@@ -407,7 +403,7 @@ julia> CrossCor((5,5), 3 => 7, stride=3, pad=(2,0))(xs) |> size
(34, 32, 7, 50)
```
"""
struct CrossCor{N,M,F,A,V}
struct CrossCor{N,M,F,A,V} <: SimpleLayer
σ::F
weight::A
bias::V
@@ -453,8 +449,6 @@ function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
return CrossCor(weight, bias, σ; stride, pad, dilation)
end

@functor CrossCor

function crosscor(x, w, ddims::DenseConvDims)
ddims = DenseConvDims(ddims, F=true)
return conv(x, w, ddims)
29 changes: 6 additions & 23 deletions src/layers/normalise.jl
@@ -90,7 +90,7 @@ julia> isapprox(count(==(0), y) / length(y), 0.5, atol=0.1)
true
```
"""
mutable struct Dropout{F,D,R<:AbstractRNG}
mutable struct Dropout{F,D,R<:AbstractRNG} <: NoTrainLayer
p::F
dims::D
active::Union{Bool, Nothing}
@@ -103,9 +103,6 @@ function Dropout(p; dims=:, rng = rng_from_array())
Dropout(p, dims, nothing, rng)
end

@functor Dropout
trainable(a::Dropout) = (;)

function (a::Dropout)(x)
_isactive(a) || return x
return dropout(a.rng, x, a.p; dims=a.dims, active=true)
@@ -146,7 +143,7 @@ julia> isapprox(std(x), std(y), atol=0.2)
true
```
"""
mutable struct AlphaDropout{F,R<:AbstractRNG}
mutable struct AlphaDropout{F,R<:AbstractRNG} <: NoTrainLayer
p::F
active::Union{Bool, Nothing}
rng::R
@@ -158,9 +155,6 @@ end
AlphaDropout(p, active) = AlphaDropout(p, active, rng_from_array())
AlphaDropout(p; rng = rng_from_array()) = AlphaDropout(p, nothing, rng)

@functor AlphaDropout
trainable(a::AlphaDropout) = (;)

function (a::AlphaDropout)(x::AbstractArray{T}) where T
_isactive(a) || return x
p = a.p
@@ -209,7 +203,7 @@ julia> isapprox(std(y, dims=1:3), ones(1, 1, 1, 2), atol=0.1) && std(y, dims=1:3
true
```
"""
struct LayerNorm{F,D,T,N}
struct LayerNorm{F,D,T,N} <: PartialTrainLayer{(:diag,)}
λ::F
diag::D
ϵ::T
@@ -224,8 +218,6 @@ end
LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...)
LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...)

@functor LayerNorm

(a::LayerNorm)(x) = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))

function Base.show(io::IO, l::LayerNorm)
@@ -322,7 +314,7 @@ julia> isapprox(std(m(xs)), 1, atol=0.1) && std(xs) != std(m(xs))
true
```
"""
mutable struct BatchNorm{F,V,N,W}
mutable struct BatchNorm{F,V,N,W} <: PartialTrainLayer{(:β, :γ)}
λ::F # activation function
β::V # bias
γ::V # scale
@@ -352,9 +344,6 @@ function BatchNorm(chs::Int, λ=identity;
nothing, chs)
end

@functor BatchNorm
trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;)
Comment on lines -355 to -356 (Member Author):

Here `PartialTrainLayer{(:β, :γ)}` fixes which fields are trainable, permanently. That's OK, since `β = nothing` when it's not trainable, so it'll be ignored. It also improves type-stability, although I doubt this matters at all.
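A small hypothetical check of that point, assuming something like the `types.jl` sketch above: with `affine = false` the affine fields hold `nothing`, which Zygote and Optimisers.jl treat as having no parameters, so fixing the trainable field names at the type level still collects nothing extra.

```julia
# Hypothetical illustration: a non-affine BatchNorm stores β = γ = nothing, so the
# permanently fixed (:β, :γ) trainable names contribute no parameters.
using Flux

bn = BatchNorm(3, affine = false)
@assert bn.β === nothing && bn.γ === nothing
@assert isempty(Flux.params(bn))   # no trainable parameters are collected
```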


function (BN::BatchNorm)(x)
@assert size(x, ndims(x)-1) == BN.chs
N = ndims(x)
Expand Down Expand Up @@ -412,7 +401,7 @@ julia> isapprox(std(y, dims=1:2), ones(1, 1, 3, 2), atol=0.2) && std(y, dims=1:2
true
```
"""
mutable struct InstanceNorm{F,V,N,W}
mutable struct InstanceNorm{F,V,N,W} <: PartialTrainLayer{(:β, :γ)}
λ::F # activation function
β::V # bias
γ::V # scale
Expand Down Expand Up @@ -442,9 +431,6 @@ function InstanceNorm(chs::Int, λ=identity;
nothing, chs)
end

@functor InstanceNorm
trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;)

function (l::InstanceNorm)(x)
@assert ndims(x) > 2
@assert size(x, ndims(x)-1) == l.chs
Expand Down Expand Up @@ -506,7 +492,7 @@ julia> isapprox(std(y[:, :, 3:4, 2]), 1, atol=0.1) && std(xs[:, :, 3:4, 2]) != s
true
```
"""
mutable struct GroupNorm{F,V,N,W}
mutable struct GroupNorm{F,V,N,W} <: PartialTrainLayer{(:β, :γ)}
G::Int # number of groups
λ::F # activation function
β::V # bias
Expand All @@ -521,9 +507,6 @@ mutable struct GroupNorm{F,V,N,W}
chs::Int # number of channels
end

@functor GroupNorm
trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;)

function GroupNorm(chs::Int, G::Int, λ=identity;
initβ=zeros32, initγ=ones32,
affine=true, track_stats=false,