Skip to content

Commit 9029706

Browse files
authored
feat: batched_jacobian for Reactant (#1507)
* feat: batched_jacobian for Reactant [skip ci] * test: setup * feat: lower to a looped batched jacobian * test: batched jacobian now exists * feat: forward mode will work once EnzymeAD/Reactant.jl#1822 lands * test: add tests for batched jacobian * fix: explicit imports
1 parent 7606373 commit 9029706

File tree

10 files changed

+222
-35
lines changed

10 files changed

+222
-35
lines changed

ext/LuxEnzymeExt/LuxEnzymeExt.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,14 @@ using Enzyme: Enzyme, Active, Const, Duplicated
55
using EnzymeCore: EnzymeCore, Forward, Reverse
66
using Functors: fmap
77
using Setfield: @set!, @set
8-
using Static: StaticBool, False, True
8+
using Static: False, True
99

1010
using Lux: Lux, Utils, AutoDiffInternalImpl
1111
using Lux.Training: TrainingBackendCache, TrainState
1212
using MLDataDevices: isleaf
1313

1414
Lux.is_extension_loaded(::Val{:Enzyme}) = true
1515

16-
normalize_backend(::StaticBool, ad::AutoEnzyme) = ad
17-
normalize_backend(::True, ad::AutoEnzyme{Nothing}) = @set(ad.mode = Forward)
18-
normalize_backend(::False, ad::AutoEnzyme{Nothing}) = @set(ad.mode = Reverse)
19-
20-
annotate_function(::AutoEnzyme{<:Any,Nothing}, f::F) where {F} = f
21-
annotate_function(::AutoEnzyme{<:Any,A}, f::F) where {F,A} = A(f)
22-
2316
struct OOPFunctionWrapper{F}
2417
f::F
2518
end

ext/LuxEnzymeExt/autodiff.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
# VJPs
22

33
function _vector_jacobian_product_impl(f::F, ad::AutoEnzyme, x, v, extra_args...) where {F}
4-
ad = normalize_backend(False(), ad)
4+
ad = Utils.normalize_autoenzyme_mode(Reverse, ad)
55
@assert ADTypes.mode(ad) isa ADTypes.ReverseMode "VJPs are only supported in reverse \
66
mode"
77
dx = fmap(zero, x; exclude=isleaf)
88
Enzyme.autodiff(
99
ad.mode,
10-
annotate_function(ad, OOPFunctionWrapper(f)),
10+
Utils.annotate_enzyme_function(ad, OOPFunctionWrapper(f)),
1111
Duplicated(fmap(similar, v; exclude=isleaf), fmap(copy, v; exclude=isleaf)),
1212
Duplicated(x, dx),
1313
extra_args...,
@@ -30,11 +30,13 @@ end
3030
# JVPs
3131

3232
function _jacobian_vector_product_impl(f::F, ad::AutoEnzyme, x, u, extra_args...) where {F}
33-
ad = normalize_backend(True(), ad)
33+
ad = Utils.normalize_autoenzyme_mode(Forward, ad)
3434
@assert ADTypes.mode(ad) isa ADTypes.ForwardMode "JVPs are only supported in forward \
3535
mode"
3636
return only(
37-
Enzyme.autodiff(ad.mode, annotate_function(ad, f), Duplicated(x, u), extra_args...)
37+
Enzyme.autodiff(
38+
ad.mode, Utils.annotate_enzyme_function(ad, f), Duplicated(x, u), extra_args...
39+
),
3840
)
3941
end
4042

ext/LuxReactantExt/LuxReactantExt.jl

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,16 @@
11
module LuxReactantExt
22

3-
using ADTypes: AutoEnzyme
3+
using ADTypes: ADTypes, AutoEnzyme
44
using Enzyme: Enzyme, Active, Const, Duplicated
5+
using EnzymeCore: EnzymeCore
6+
using LinearAlgebra: LinearAlgebra
57
using Functors: Functors
68
using Preferences: load_preference
79
using Random: Random
810
using Optimisers: Optimisers
911
using Reactant:
10-
Reactant,
11-
Profiler,
12-
@compile,
13-
@code_hlo,
14-
@jit,
15-
@opcall,
16-
AnyTracedRArray,
17-
TracedRArray,
18-
TracedRNumber,
19-
PrecisionConfig
12+
Reactant, Profiler, AnyTracedRArray, TracedRArray, TracedRNumber, PrecisionConfig
13+
using Reactant: @compile, @code_hlo, @jit, @opcall
2014
using ReactantCore: ReactantCore, @trace
2115
using Setfield: @set!
2216
using Static: True, False
@@ -63,6 +57,7 @@ include("training.jl")
6357
include("layers.jl")
6458
include("tracing.jl")
6559
include("saved_model.jl")
60+
include("batched_jacobian.jl")
6661

6762
include("precompile_workloads.jl")
6863

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
function Lux.AutoDiffInternalImpl.batched_jacobian_impl(
2+
f::F, ad::Lux.Training.ReactantBackend, x
3+
) where {F}
4+
ad = Utils.normalize_autoenzyme_mode(EnzymeCore.Forward, ad.ad)
5+
if ADTypes.mode(ad) isa ADTypes.ReverseMode
6+
return _batched_jacobian_reverse_impl(f, ad, x)
7+
else
8+
return _batched_jacobian_forward_impl(f, ad, x)
9+
end
10+
end
11+
12+
struct ApplyWithReshape{F,SZ}
13+
f::F
14+
sz::SZ
15+
end
16+
17+
(f::ApplyWithReshape)(x) = reshape(f.f(reshape(x, f.sz)), :, size(x, ndims(x)))
18+
19+
function (f::ApplyWithReshape)(y, x)
20+
res = f.f(reshape(x, f.sz))
21+
copyto!(y, reshape(res, size(y)))
22+
return nothing
23+
end
24+
25+
function _check_validity_for_batched_jacobian(f::F, x::AbstractArray) where {F}
26+
y = f(x)
27+
@assert y isa AbstractArray
28+
B = size(y, ndims(y))
29+
if ndims(y) == 1 || B != size(x, ndims(x))
30+
throw(AssertionError("`batched_jacobian` only supports batched outputs \
31+
(ndims(y) > 1) && size(y, ndims(y)) == size(x, ndims(x))."))
32+
end
33+
return y, B
34+
end
35+
36+
function _batched_jacobian_reverse_impl(f::F, ad::AutoEnzyme, x::AbstractArray) where {F}
37+
y, B = _check_validity_for_batched_jacobian(f, x)
38+
f′ = ApplyWithReshape(f, size(x))
39+
40+
y = Utils.contiguous(reshape(y, :, B))
41+
dy = Utils.contiguous(
42+
repeat(
43+
reshape(
44+
Reactant.promote_to(
45+
TracedRArray{Reactant.unwrapped_eltype(y),2},
46+
LinearAlgebra.I(size(y, 1)),
47+
),
48+
size(y, 1),
49+
1,
50+
size(y, 1),
51+
),
52+
1,
53+
size(y, 2),
54+
1,
55+
),
56+
)
57+
58+
x = Utils.contiguous(reshape(x, :, B))
59+
60+
# TODO: replace once https://github.com/LuxDL/Lux.jl/issues/1523 is fixed
61+
#=
62+
dx = similar(x, size(x, 1), size(x, 2), size(y, 1))
63+
fill!(dx, false)
64+
65+
Enzyme.autodiff(
66+
ad.mode,
67+
Utils.annotate_enzyme_function(ad, f′),
68+
Reactant.StackedBatchDuplicated(y, dy),
69+
Reactant.StackedBatchDuplicated(x, dx),
70+
)
71+
72+
return permutedims(dx, (3, 1, 2))
73+
=#
74+
75+
# Our loop to batch pass should automatically batch this loop and currently has better
76+
# coverage than the above. Though we should fix the above to ensure we never have a
77+
# loop in the final result.
78+
dx = similar(x, size(y, 1), size(x, 1), size(x, 2))
79+
@trace track_numbers = false for i in 1:size(y, 1)
80+
dxᵢ = Enzyme.make_zero(x)
81+
Enzyme.autodiff(
82+
ad.mode,
83+
Utils.annotate_enzyme_function(ad, f′),
84+
Duplicated,
85+
Duplicated(y, dy[:, :, i]),
86+
Duplicated(x, dxᵢ),
87+
)
88+
dx[i, :, :] = dxᵢ
89+
end
90+
return dx
91+
end
92+
93+
function _batched_jacobian_forward_impl(f::F, ad::AutoEnzyme, x::AbstractArray) where {F}
94+
y, B = _check_validity_for_batched_jacobian(f, x)
95+
y = Utils.contiguous(reshape(y, :, B)) # will be DCEd away
96+
97+
f′ = ApplyWithReshape(f, size(x))
98+
x = Utils.contiguous(reshape(x, :, size(x, ndims(x))))
99+
100+
bx = Utils.contiguous(
101+
repeat(
102+
reshape(
103+
Reactant.promote_to(
104+
TracedRArray{Reactant.unwrapped_eltype(x),2},
105+
LinearAlgebra.I(size(x, 1)),
106+
),
107+
size(x, 1),
108+
1,
109+
size(x, 1),
110+
),
111+
1,
112+
size(x, 2),
113+
1,
114+
),
115+
)
116+
117+
# TODO: replace once https://github.com/LuxDL/Lux.jl/issues/1523 is fixed
118+
# return stack(
119+
# only(
120+
# Enzyme.autodiff(
121+
# ad.mode,
122+
# Utils.annotate_enzyme_function(ad, f′),
123+
# Reactant.StackedBatchDuplicated(x, bx),
124+
# ),
125+
# );
126+
# dims=2,
127+
# )
128+
129+
dy = similar(y, size(y, 1), size(x, 1), size(x, 2))
130+
@trace track_numbers = false for i in 1:size(x, 1)
131+
dy[:, i, :] = only(
132+
Enzyme.autodiff(
133+
ad.mode,
134+
Utils.annotate_enzyme_function(ad, f′),
135+
Duplicated,
136+
Duplicated(x, bx[:, :, i]),
137+
),
138+
)
139+
end
140+
return dy
141+
end

src/autodiff/api.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ the following properties for `y = f(x)`:
9999
100100
| Supported Backends | Packages Needed |
101101
|:------------------ |:--------------- |
102+
| `AutoEnzyme` | `Reactant.jl` |
102103
| `AutoForwardDiff` | |
103104
| `AutoZygote` | `Zygote.jl` |
104105
@@ -126,16 +127,13 @@ function batched_jacobian(::F, backend::AbstractADType, x::AbstractArray) where
126127
throw(ArgumentError("`batched_jacobian` is not implemented for `$(backend)`."))
127128
end
128129

129-
function batched_jacobian(f::F, backend::AutoForwardDiff, x::AbstractArray) where {F}
130-
return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
131-
end
132-
133-
function batched_jacobian(f::F, backend::AutoZygote, x::AbstractArray) where {F}
134-
if !is_extension_loaded(Val(:Zygote))
135-
error("`Zygote.jl` must be loaded for `batched_jacobian` to work with \
136-
`$(backend)`.")
130+
for implemented_backend in (:AutoForwardDiff, :AutoZygote, :AutoEnzyme)
131+
@eval function batched_jacobian(
132+
f::F, backend::$implemented_backend, x::AbstractArray
133+
) where {F}
134+
assert_backend_loaded(:batched_jacobian, backend)
135+
return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
137136
end
138-
return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
139137
end
140138

141139
# Utils

src/autodiff/batched_autodiff.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ end
9090
function batched_jacobian_internal(
9191
f::F, backend::AbstractADType, x::AbstractArray
9292
) where {F}
93-
return batched_jacobian_impl(f, backend, x)
93+
return batched_jacobian_impl(
94+
f, Lux.Training.maybe_wrap_adtype(backend, get_device_type(x)), x
95+
)
9496
end
9597

9698
# ForwardDiff.jl Implementation

src/helpers/training.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ end
162162
@concrete struct ReactantBackend
163163
return_gradients <: StaticBool
164164
sync::Bool
165+
ad <: AutoEnzyme
165166
end
166167

167168
const APPLY_GRAD_DOCSTRING = """
@@ -354,7 +355,7 @@ function maybe_wrap_adtype(
354355
return_gradients::Utils.BoolType=True(),
355356
sync::Bool=false,
356357
)
357-
ad isa AutoEnzyme && return ReactantBackend(static(return_gradients), sync)
358+
ad isa AutoEnzyme && return ReactantBackend(static(return_gradients), sync, ad)
358359
throw(ArgumentError("Computing gradients for models on XLA is supported only with \
359360
Enzyme.jl (`AutoEnzyme`)."))
360361
end

src/utils.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module Utils
22

3+
using ADTypes: ADTypes, AutoEnzyme
34
using ArrayInterface: ArrayInterface
45
using ArgCheck: @argcheck
56
using ChainRulesCore: ChainRulesCore, @non_differentiable, NoTangent
@@ -8,6 +9,7 @@ using EnzymeCore: EnzymeRules
89
using ForwardDiff: Dual
910
using Functors: Functors, fmapstructure
1011
using Random: AbstractRNG
12+
using Setfield: @set
1113
using Static: Static, StaticBool, StaticInteger, StaticSymbol
1214
using StaticArraysCore: SMatrix, SVector
1315

@@ -236,6 +238,12 @@ recursive_unthunk(x) = Functors.fmap(CRC.unthunk, x; exclude=MLDataDevices.islea
236238

237239
convert_eltype(::Type{T}, x::Number) where {T<:Number} = convert(T, x)
238240

241+
normalize_autoenzyme_mode(mode, ad::AutoEnzyme) = ad
242+
normalize_autoenzyme_mode(mode, ad::AutoEnzyme{Nothing}) = @set(ad.mode = mode)
243+
244+
annotate_enzyme_function(::AutoEnzyme{<:Any,Nothing}, f::F) where {F} = f
245+
annotate_enzyme_function(::AutoEnzyme{<:Any,A}, f::F) where {F,A} = A(f)
246+
239247
end
240248

241249
using .Utils:

test/reactant/autodiff_tests.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,51 @@
5555
end
5656
end
5757
end
58+
59+
@testitem "AutoDiff APIs: Batched Jacobian" tags = [:reactant] setup = [SharedTestSetup] begin
60+
using Reactant, Lux, Zygote, Random, Enzyme
61+
62+
rng = Random.default_rng()
63+
64+
models = (
65+
Chain(
66+
Conv((3, 3), 2 => 4, gelu; pad=SamePad()),
67+
Conv((3, 3), 4 => 2, gelu; pad=SamePad()),
68+
FlattenLayer(),
69+
Dense(18 => 2),
70+
),
71+
Chain(Dense(2, 4, gelu), Dense(4, 2)),
72+
)
73+
Xs = (randn(rng, Float32, 3, 3, 2, 4), randn(rng, Float32, 2, 4))
74+
75+
@testset "$(mode)" for (mode, atype, dev, ongpu) in MODES
76+
if mode == "amdgpu"
77+
@warn "Skipping AMDGPU tests for Reactant"
78+
continue
79+
end
80+
81+
if ongpu
82+
Reactant.set_default_backend("gpu")
83+
else
84+
Reactant.set_default_backend("cpu")
85+
end
86+
87+
dev = reactant_device(; force=true)
88+
89+
@testset "$(size(X))" for (model, X) in zip(models, Xs)
90+
ps, st = Lux.setup(rng, model)
91+
X_ra = dev(X)
92+
93+
smodel = StatefulLuxLayer(model, ps, st)
94+
smodel_ra = StatefulLuxLayer(model, dev(ps), dev(st))
95+
96+
J = batched_jacobian(smodel, AutoZygote(), X)
97+
J_ra = @jit batched_jacobian(smodel_ra, AutoEnzyme(; mode=Enzyme.Reverse), X_ra)
98+
J_ra2 = @jit batched_jacobian(
99+
smodel_ra, AutoEnzyme(; mode=Enzyme.Forward), X_ra
100+
)
101+
@test J ≈ J_ra atol = 1.0e-3 rtol = 1.0e-3
102+
@test J ≈ J_ra2 atol = 1.0e-3 rtol = 1.0e-3
103+
end
104+
end
105+
end

test/runtests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ using Lux
7575
x -> x, AutoZygote(), rand(2), rand(2)
7676
)
7777

78-
@test_throws ArgumentError batched_jacobian(x -> x, AutoEnzyme(), rand(2, 2))
7978
@test_throws ErrorException batched_jacobian(x -> x, AutoZygote(), rand(2, 2))
8079
end
8180

0 commit comments

Comments
 (0)