Skip to content
5 changes: 3 additions & 2 deletions src/optimise/train.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Juno
import Zygote: Params, gradient
import Zygote: Params, withgradient

"""
update!(x, x̄)
Expand Down Expand Up @@ -102,8 +102,9 @@ for d in data

"""
function step!(loss, params, opt)
gs = gradient(loss, params)
val, gs = withgradient(loss, params)
update!(opt, params, gs)
return val, gs
end

"""
Expand Down