Skip to content

Commit 7b3face

Browse files
committed
feat: lower to a looped batched jacobian
1 parent 62c74c8 commit 7b3face

File tree

2 files changed

+82
-36
lines changed

2 files changed

+82
-36
lines changed

ext/LuxReactantExt/batched_jacobian.jl

Lines changed: 81 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -22,33 +22,43 @@ function (f::ApplyWithReshape)(y, x)
2222
return nothing
2323
end
2424

25-
function _batched_jacobian_reverse_impl(f::F, ad::AutoEnzyme, x::AbstractArray) where {F}
25+
function _check_validity_for_batched_jacobian(f::F, x::AbstractArray) where {F}
2626
y = f(x)
2727
@assert y isa AbstractArray
28-
if ndims(y) 1 || size(y, ndims(y)) != size(x, ndims(x))
28+
B = size(y, ndims(y))
29+
if ndims(y) 1 || B != size(x, ndims(x))
2930
throw(AssertionError("`batched_jacobian` only supports batched outputs \
3031
(ndims(y) > 1) && size(y, ndims(y)) == size(x, ndims(x))."))
3132
end
33+
return y, B
34+
end
3235

36+
function _batched_jacobian_reverse_impl(f::F, ad::AutoEnzyme, x::AbstractArray) where {F}
37+
y, B = _check_validity_for_batched_jacobian(f, x)
3338
f′ = ApplyWithReshape(f, size(x))
3439

35-
y = Utils.contiguous(reshape(y, :, size(y, ndims(y))))
36-
dy = repeat(
37-
reshape(
38-
Reactant.promote_to(
39-
TracedRArray{Reactant.unwrapped_eltype(y),2}, LinearAlgebra.I(size(y, 1))
40+
y = Utils.contiguous(reshape(y, :, B))
41+
dy = Utils.contiguous(
42+
repeat(
43+
reshape(
44+
Reactant.promote_to(
45+
TracedRArray{Reactant.unwrapped_eltype(y),2},
46+
LinearAlgebra.I(size(y, 1)),
47+
),
48+
size(y, 1),
49+
1,
50+
size(y, 1),
4051
),
41-
size(y, 1),
4252
1,
43-
size(y, 1),
53+
size(y, 2),
54+
1,
4455
),
45-
1,
46-
size(y, 2),
47-
1,
4856
)
49-
dy = Utils.contiguous(dy)
5057

51-
x = Utils.contiguous(reshape(x, :, size(x, ndims(x))))
58+
x = Utils.contiguous(reshape(x, :, B))
59+
60+
# TODO: replace once https://github.com/LuxDL/Lux.jl/issues/1523 is fixed
61+
#=
5262
dx = similar(x, size(x, 1), size(x, 2), size(y, 1))
5363
fill!(dx, false)
5464
@@ -60,35 +70,71 @@ function _batched_jacobian_reverse_impl(f::F, ad::AutoEnzyme, x::AbstractArray)
6070
)
6171
6272
return permutedims(dx, (3, 1, 2))
73+
=#
74+
75+
# Our loop to batch pass should automatically batch this loop and current has better
76+
# coverage than the above. Though we should fix the above to ensure we never have a
77+
# loop in the final result.
78+
dx = similar(x, size(y, 1), size(x, 1), size(x, 2))
79+
@trace track_numbers = false for i in 1:size(y, 1)
80+
dxᵢ = Enzyme.make_zero(x)
81+
Enzyme.autodiff(
82+
ad.mode,
83+
Utils.annotate_enzyme_function(ad, f′),
84+
Duplicated(y, dy[:, :, i]),
85+
Duplicated(x, dxᵢ),
86+
)
87+
dx[i, :, :] = dxᵢ
88+
end
89+
return dx
6390
end
6491

6592
function _batched_jacobian_forward_impl(f::F, ad::AutoEnzyme, x::AbstractArray) where {F}
93+
y, B = _check_validity_for_batched_jacobian(f, x)
94+
y = Utils.contiguous(reshape(y, :, B)) # will be DCEd away
95+
6696
f′ = ApplyWithReshape(f, size(x))
6797
x = Utils.contiguous(reshape(x, :, size(x, ndims(x))))
6898

69-
bx = repeat(
70-
reshape(
71-
Reactant.promote_to(
72-
TracedRArray{Reactant.unwrapped_eltype(x),2}, LinearAlgebra.I(size(x, 1))
99+
bx = Utils.contiguous(
100+
repeat(
101+
reshape(
102+
Reactant.promote_to(
103+
TracedRArray{Reactant.unwrapped_eltype(x),2},
104+
LinearAlgebra.I(size(x, 1)),
105+
),
106+
size(x, 1),
107+
1,
108+
size(x, 1),
73109
),
74-
size(x, 1),
75110
1,
76-
size(x, 1),
111+
size(x, 2),
112+
1,
77113
),
78-
1,
79-
size(x, 2),
80-
1,
81-
)
82-
bx = Utils.contiguous(bx)
83-
84-
return stack(
85-
only(
86-
Enzyme.autodiff(
87-
ad.mode,
88-
Utils.annotate_enzyme_function(ad, f′),
89-
Reactant.StackedBatchDuplicated(x, bx),
90-
),
91-
);
92-
dims=2,
93114
)
115+
116+
# TODO: replace once https://github.com/LuxDL/Lux.jl/issues/1523 is fixed
117+
# return stack(
118+
# only(
119+
# Enzyme.autodiff(
120+
# ad.mode,
121+
# Utils.annotate_enzyme_function(ad, f′),
122+
# Reactant.StackedBatchDuplicated(x, bx),
123+
# ),
124+
# );
125+
# dims=2,
126+
# )
127+
128+
dy = similar(y, size(y, 1), size(x, 1), size(x, 2))
129+
@trace track_numbers = false for i in 1:size(x, 1)
130+
dyᵢ = Enzyme.make_zero(y)
131+
Enzyme.autodiff(
132+
ad.mode,
133+
Utils.annotate_enzyme_function(ad, f′),
134+
Duplicated(y, dyᵢ),
135+
Duplicated(x, bx[:, :, i]),
136+
)
137+
dy[:, i, :] = dyᵢ
138+
end
139+
return dy
94140
end

test/reactant/autodiff_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ end
5959
@testitem "AutoDiff APIs: Batched Jacobian" tags = [:reactant] setup = [SharedTestSetup] begin
6060
using Reactant, Lux, Zygote, Random, ForwardDiff, Enzyme
6161

62-
fn(x) = reshape(sum(abs2, x; dims=(1, 2, 3)), 1, :)
62+
fn(x) = reshape(sum(abs2, x; dims=(2, 3)), size(x, 1), :)
6363

6464
rng = Random.default_rng()
6565

0 commit comments

Comments
 (0)