2 changes: 2 additions & 0 deletions NEWS.md
@@ -3,6 +3,8 @@
## v0.12.9
* Fixed incorrect output and added GPU compatibility for [AlphaDropout](https://github.com/FluxML/Flux.jl/pull/1781).
* Add trilinear [Upsample layer](https://github.com/FluxML/Flux.jl/pull/1792).
* Add `step!`, which performs a single training step of `train!`, to allow for more exotic optimisers (#666)

## v0.12.8
* Optimized inference and gradient calculation of OneHotMatrix [PR](https://github.com/FluxML/Flux.jl/pull/1756)
2 changes: 1 addition & 1 deletion src/optimise/Optimise.jl
@@ -3,7 +3,7 @@ module Optimise
using LinearAlgebra
import ArrayInterface

-export train!, update!,
+export train!, step!, update!,
Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM, OADAM, AdaBelief,
InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,
29 changes: 27 additions & 2 deletions src/optimise/train.jl
@@ -80,6 +80,32 @@ end
batchmemaybe(x) = tuple(x)
batchmemaybe(x::Tuple) = x

"""
step!(loss, params, opt)

`step!` uses a `loss` function (with no inputs) to improve the [Model parameters](@ref) (`params`)
based on a pluggable [Optimisers](@ref) (`opt`). It represents a single step in
the training loop `train!`. While there is a default implementation for
optimisers which are based on the `update!` function and only require gradient
information, this `step!` has to be overloaded for more general optimisers.

While the loss function of `train!` still accepts data as input, the loss function
of `step!` accepts no input. `train!` cycles through the data in a loop
roughly like this

```julia
for d in data
step!(ps, opt) do
loss(d)
end
```

"""
function step!(loss, params, opt)
  gs = gradient(loss, params)
  update!(opt, params, gs)
end
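As a purely illustrative sketch of the overloading mentioned in the docstring above (the `LossAwareDescent` type, its `eta` field, and the skip-on-non-finite-loss behaviour are invented for this example and are not part of the PR), a more general optimiser might hook into `step!` roughly like this:

```julia
using Flux

# Hypothetical optimiser that needs the loss value itself, not just gradients.
struct LossAwareDescent
  eta::Float64
end

# Overload `step!` for this optimiser: evaluate the zero-argument loss closure,
# skip the update when the loss is not finite, otherwise take a plain
# gradient-descent step on every parameter.
function Flux.Optimise.step!(loss, params, opt::LossAwareDescent)
  l = loss()
  isfinite(l) || return l
  gs = gradient(loss, params)   # implicit-parameter gradients (a Zygote.Grads)
  for p in params
    gs[p] === nothing && continue
    p .-= opt.eta .* gs[p]
  end
  return l
end
```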

"""
    train!(loss, params, data, opt; cb)

@@ -106,10 +132,9 @@ function train!(loss, ps, data, opt; cb = () -> ())
  cb = runall(cb)
  @progress for d in data
    try
-     gs = gradient(ps) do
+     step!(ps, opt) do
        loss(batchmemaybe(d)...)
      end
-     update!(opt, ps, gs)
      cb()
    catch ex
      if ex isa StopException
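For context, the new single-step API can also be driven by hand outside `train!`. A minimal usage sketch follows; the model, data, and loss below are made up for illustration, and `step!` is taken from `Flux.Optimise` as exported by this PR:

```julia
using Flux
using Flux.Optimise: step!

# Toy model and data, invented for this example.
model = Dense(10, 1)
ps = Flux.params(model)
opt = Descent(0.1)
data = [(rand(Float32, 10), rand(Float32, 1)) for _ in 1:100]

# A hand-rolled training loop: one `step!` per batch, with a
# zero-argument loss closure capturing the current batch.
for (x, y) in data
  step!(ps, opt) do
    Flux.Losses.mse(model(x), y)
  end
end
```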