Skip to content

Commit 46af3a9

Browse files
committed
feat: forward mode will work once EnzymeAD/Reactant.jl#1822 lands
1 parent 5843b7a commit 46af3a9

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

ext/LuxReactantExt/batched_jacobian.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -127,14 +127,11 @@ function _batched_jacobian_forward_impl(f::F, ad::AutoEnzyme, x::AbstractArray)
127127

128128
dy = similar(y, size(y, 1), size(x, 1), size(x, 2))
129129
@trace track_numbers = false for i in 1:size(x, 1)
130-
dyᵢ = Enzyme.make_zero(y)
131-
Enzyme.autodiff(
132-
ad.mode,
133-
Utils.annotate_enzyme_function(ad, f′),
134-
Duplicated(y, dyᵢ),
135-
Duplicated(x, bx[:, :, i]),
130+
dy[:, i, :] = only(
131+
Enzyme.autodiff(
132+
ad.mode, Utils.annotate_enzyme_function(ad, f′), Duplicated(x, bx[:, :, i])
133+
),
136134
)
137-
dy[:, i, :] = dyᵢ
138135
end
139136
return dy
140137
end

0 commit comments

Comments
 (0)