@@ -33,7 +33,7 @@ _filter_children(f, children::NamedTuple) =
3333_filter_children (f, children) = filter (f, children)
3434
3535"""
36- loadmodel!(dst, src)
36+ loadmodel!(dst, src; filter = _ -> true )
3737
3838Copy all the parameters (trainable and non-trainable) from `src` into `dst`.
3939
@@ -43,9 +43,12 @@ Non-array elements (such as activation functions) are not copied and need not ma
4343Zero bias vectors and `bias=false` are considered equivalent
4444(see extended help for more details).
4545
46+ Specify the predicate function `filter` to control what is recursed.
47+ A child node `x` in either `dst` and `src` is skipped when `filter(x) == false`.
48+
4649# Examples
4750```julia
48- julia> dst = Chain(Dense(Flux.ones32(2, 5, tanh) ), Dense(2 => 1; bias = [1f0]))
51+ julia> dst = Chain(Dense(Flux.ones32(2, 5), Flux.ones32(2), tanh ), Dense(2 => 1; bias = [1f0]))
4952Chain(
5053 Dense(5 => 2, tanh), # 12 parameters
5154 Dense(2 => 1), # 3 parameters
6366
6467julia> iszero(dst[2].bias)
6568true
69+
70+ julia> src = Chain(Dense(5 => 2), Dropout(0.2), Dense(2 => 1))
71+ Chain(
72+ Dense(5 => 2), # 12 parameters
73+ Dropout(0.2),
74+ Dense(2 => 1), # 3 parameters
75+ ) # Total: 4 arrays, 15 parameters, 348 bytes.
76+
77+ julia> Flux.loadmodel!(dst, src; filter = x -> !(x isa Dropout)) # skips loading Dropout
78+ Chain(
79+ Dense(5 => 2, tanh), # 12 parameters
80+ Dense(2 => 1), # 3 parameters
81+ ) # Total: 4 arrays, 15 parameters, 316 bytes.
6682```
6783
6884# Extended help
0 commit comments