2 changes: 2 additions & 0 deletions NEWS.md
@@ -3,6 +3,8 @@
## v0.12.9
* Fixed incorrect output and added GPU compatibility for [AlphaDropout](https://github.com/FluxML/Flux.jl/pull/1781).
* Add trilinear [Upsample layer](https://github.com/FluxML/Flux.jl/pull/1792).
* Add `optimstep!` as a single training step of `train!` to allow for more exotic optimisers (#666)

## v0.12.8
* Optimized inference and gradient calculation of `OneHotMatrix` ([PR](https://github.com/FluxML/Flux.jl/pull/1756))
2 changes: 1 addition & 1 deletion src/optimise/Optimise.jl
@@ -3,7 +3,7 @@ module Optimise
using LinearAlgebra
import ArrayInterface

-export train!, update!,
+export train!, optimstep!, update!,
        Descent, ADAM, Momentum, Nesterov, RMSProp,
        ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM, OADAM, AdaBelief,
        InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,
34 changes: 31 additions & 3 deletions src/optimise/train.jl
@@ -1,5 +1,5 @@
 using Juno
-import Zygote: Params, gradient
+import Zygote: Params, withgradient

"""
update!(x, x̄)
@@ -80,6 +80,35 @@ end
batchmemaybe(x) = tuple(x)
batchmemaybe(x::Tuple) = x

"""
optimstep!(loss, params, opt)
Author:
I suggest optimstep! instead of trainstep! to indicate that this is the optimiser interface and keep the ML jargon to a minimum

@mcabbott (Member), Mar 20, 2022:

One vote for something evoking train! to stress that they are closely related.

If the longer-term plan is to use Optimisers.jl, this may not fit with train! at all -- some recent discussion here: #1902 (comment) . In which case there will be an implicit-style train! & Params story, and an explicit-style gradient and Optimisers.update!. With such a divide, this function wants to be clearly on the train! & Params side.

Maybe it should just be 3-arg train!? Without a data iterator, there is no iteration, that's all:

```julia
train!(loss, ::Params, data, ::AbstractOptimiser)  # calls loss(d...) for d in data
train!(loss, ::Params, ::AbstractOptimiser)        # calls loss() since there is no data
```
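For concreteness, here is a minimal, hedged sketch of how that 3-arg method could be written in terms of the `optimstep!` defined further down in this diff; it is written free-standing to illustrate the suggestion above and is not code from the PR:

```julia
import Zygote: Params
using Flux.Optimise: AbstractOptimiser, optimstep!   # optimstep! is the name used in this diff

# Hypothetical 3-arg train!: with no data iterator there is nothing to loop over,
# so it reduces to a single optimstep! call on the zero-argument loss.
function train!(loss, ps::Params, opt::AbstractOptimiser)
  val, _ = optimstep!(loss, ps, opt)
  return val
end
```

The 4-arg `data` method would keep its current loop and splat each `d` into `loss`, as `train!` already does.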


`optimstep!` uses a `loss` function (taking no inputs) to improve the [Model parameters](@ref) (`params`)
based on a pluggable optimiser (`opt`, see [Optimisers](@ref)). It performs a single step of
the training loop `train!`.

The default implementation of `optimstep!` takes the gradient of `loss`
and calls `Flux.Optimise.update!` to adjust the parameters, but you can overload
`optimstep!` for specific types of `opt`. This is useful when your optimisation routine
does not follow the standard gradient-descent procedure (e.g. gradient-free optimisers);
a sketch of such an overload follows the definition below.

Unlike `train!`, the loss function of `optimstep!` accepts no input.
Instead, `train!` cycles through the data in a loop and calls `optimstep!`:
```julia
for d in data
  optimstep!(ps, opt) do
    loss(d)
  end
end
```
If you are writing [Custom Training loops](@ref), then you should follow this pattern.
"""
function optimstep!(loss, params, opt)
  val, gs = withgradient(loss, params)
  update!(opt, params, gs)
  return val, gs
end
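To illustrate the overload point made in the docstring, here is a hedged sketch of a gradient-free optimiser hooking into `optimstep!`; the `RandomSearch` type and its field are hypothetical and not part of Flux or this PR:

```julia
import Flux.Optimise: optimstep!   # extend the method above (name as defined in this diff)

# Hypothetical gradient-free optimiser: try a random perturbation of the
# parameters and keep it only if it does not increase the loss.
struct RandomSearch
  stepsize::Float64
end

function optimstep!(loss, params, opt::RandomSearch)
  before = loss()
  steps = [opt.stepsize .* randn(eltype(p), size(p)) for p in params]
  for (p, δ) in zip(params, steps)
    p .+= δ                    # apply the trial perturbation in place
  end
  after = loss()
  if after > before            # revert if the trial made things worse
    for (p, δ) in zip(params, steps)
      p .-= δ
    end
    return before, nothing     # mirror the (value, gradients) return shape
  end
  return after, nothing
end
```

Because `train!` only calls `optimstep!`, plugging in such an optimiser requires no change to the training loop itself.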

"""
train!(loss, params, data, opt; cb)

@@ -106,10 +135,9 @@ function train!(loss, ps, data, opt; cb = () -> ())
   cb = runall(cb)
   @progress for d in data
     try
-      gs = gradient(ps) do
+      optimstep!(ps, opt) do
         loss(batchmemaybe(d)...)
       end
-      update!(opt, ps, gs)
       cb()
     catch ex
       if ex isa StopException
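Finally, a hedged sketch of the custom-loop pattern the `optimstep!` docstring describes, assuming the name exported by this diff; the model, data, and hyperparameters below are placeholders for illustration:

```julia
using Flux
using Flux.Optimise: optimstep!   # name as defined in this diff; not in a released Flux

# Placeholder model and data, for illustration only.
model = Dense(10, 1)
xs = [rand(Float32, 10) for _ in 1:100]
ys = [rand(Float32, 1)  for _ in 1:100]

ps  = Flux.params(model)
opt = Descent(0.01)

for epoch in 1:5
  for (x, y) in zip(xs, ys)
    # One optimisation step per example; the do-block closure takes no arguments,
    # matching the zero-argument loss that optimstep! expects.
    val, _ = optimstep!(ps, opt) do
      Flux.Losses.mse(model(x), y)
    end
  end
end
```

Each call returns the loss value, so the loop can log it or stop early without going through a callback.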