@@ -22,33 +22,43 @@ function (f::ApplyWithReshape)(y, x)
2222 return nothing
2323end
2424
25- function _batched_jacobian_reverse_impl (f:: F , ad :: AutoEnzyme , x:: AbstractArray ) where {F}
25+ function _check_validity_for_batched_jacobian (f:: F , x:: AbstractArray ) where {F}
2626 y = f (x)
2727 @assert y isa AbstractArray
28- if ndims (y) ≤ 1 || size (y, ndims (y)) != size (x, ndims (x))
28+ B = size (y, ndims (y))
29+ if ndims (y) ≤ 1 || B != size (x, ndims (x))
2930 throw (AssertionError (" `batched_jacobian` only supports batched outputs \
3031 (ndims(y) > 1) && size(y, ndims(y)) == size(x, ndims(x))." ))
3132 end
33+ return y, B
34+ end
3235
36+ function _batched_jacobian_reverse_impl (f:: F , ad:: AutoEnzyme , x:: AbstractArray ) where {F}
37+ y, B = _check_validity_for_batched_jacobian (f, x)
3338 f′ = ApplyWithReshape (f, size (x))
3439
35- y = Utils. contiguous (reshape (y, :, size (y, ndims (y))))
36- dy = repeat (
37- reshape (
38- Reactant. promote_to (
39- TracedRArray{Reactant. unwrapped_eltype (y),2 }, LinearAlgebra. I (size (y, 1 ))
40+ y = Utils. contiguous (reshape (y, :, B))
41+ dy = Utils. contiguous (
42+ repeat (
43+ reshape (
44+ Reactant. promote_to (
45+ TracedRArray{Reactant. unwrapped_eltype (y),2 },
46+ LinearAlgebra. I (size (y, 1 )),
47+ ),
48+ size (y, 1 ),
49+ 1 ,
50+ size (y, 1 ),
4051 ),
41- size (y, 1 ),
4252 1 ,
43- size (y, 1 ),
53+ size (y, 2 ),
54+ 1 ,
4455 ),
45- 1 ,
46- size (y, 2 ),
47- 1 ,
4856 )
49- dy = Utils. contiguous (dy)
5057
51- x = Utils. contiguous (reshape (x, :, size (x, ndims (x))))
58+ x = Utils. contiguous (reshape (x, :, B))
59+
60+ # TODO : replace once https://github.com/LuxDL/Lux.jl/issues/1523 is fixed
61+ #=
5262 dx = similar(x, size(x, 1), size(x, 2), size(y, 1))
5363 fill!(dx, false)
5464
@@ -60,35 +70,71 @@ function _batched_jacobian_reverse_impl(f::F, ad::AutoEnzyme, x::AbstractArray)
6070 )
6171
6272 return permutedims(dx, (3, 1, 2))
73+ =#
74+
75+ # Our loop to batch pass should automatically batch this loop and current has better
76+ # coverage than the above. Though we should fix the above to ensure we never have a
77+ # loop in the final result.
78+ dx = similar (x, size (y, 1 ), size (x, 1 ), size (x, 2 ))
79+ @trace track_numbers = false for i in 1 : size (y, 1 )
80+ dxᵢ = Enzyme. make_zero (x)
81+ Enzyme. autodiff (
82+ ad. mode,
83+ Utils. annotate_enzyme_function (ad, f′),
84+ Duplicated (y, dy[:, :, i]),
85+ Duplicated (x, dxᵢ),
86+ )
87+ dx[i, :, :] = dxᵢ
88+ end
89+ return dx
6390end
6491
6592function _batched_jacobian_forward_impl (f:: F , ad:: AutoEnzyme , x:: AbstractArray ) where {F}
93+ y, B = _check_validity_for_batched_jacobian (f, x)
94+ y = Utils. contiguous (reshape (y, :, B)) # will be DCEd away
95+
6696 f′ = ApplyWithReshape (f, size (x))
6797 x = Utils. contiguous (reshape (x, :, size (x, ndims (x))))
6898
69- bx = repeat (
70- reshape (
71- Reactant. promote_to (
72- TracedRArray{Reactant. unwrapped_eltype (x),2 }, LinearAlgebra. I (size (x, 1 ))
99+ bx = Utils. contiguous (
100+ repeat (
101+ reshape (
102+ Reactant. promote_to (
103+ TracedRArray{Reactant. unwrapped_eltype (x),2 },
104+ LinearAlgebra. I (size (x, 1 )),
105+ ),
106+ size (x, 1 ),
107+ 1 ,
108+ size (x, 1 ),
73109 ),
74- size (x, 1 ),
75110 1 ,
76- size (x, 1 ),
111+ size (x, 2 ),
112+ 1 ,
77113 ),
78- 1 ,
79- size (x, 2 ),
80- 1 ,
81- )
82- bx = Utils. contiguous (bx)
83-
84- return stack (
85- only (
86- Enzyme. autodiff (
87- ad. mode,
88- Utils. annotate_enzyme_function (ad, f′),
89- Reactant. StackedBatchDuplicated (x, bx),
90- ),
91- );
92- dims= 2 ,
93114 )
115+
116+ # TODO : replace once https://github.com/LuxDL/Lux.jl/issues/1523 is fixed
117+ # return stack(
118+ # only(
119+ # Enzyme.autodiff(
120+ # ad.mode,
121+ # Utils.annotate_enzyme_function(ad, f′),
122+ # Reactant.StackedBatchDuplicated(x, bx),
123+ # ),
124+ # );
125+ # dims=2,
126+ # )
127+
128+ dy = similar (y, size (y, 1 ), size (x, 1 ), size (x, 2 ))
129+ @trace track_numbers = false for i in 1 : size (x, 1 )
130+ dyᵢ = Enzyme. make_zero (y)
131+ Enzyme. autodiff (
132+ ad. mode,
133+ Utils. annotate_enzyme_function (ad, f′),
134+ Duplicated (y, dyᵢ),
135+ Duplicated (x, bx[:, :, i]),
136+ )
137+ dy[:, i, :] = dyᵢ
138+ end
139+ return dy
94140end
0 commit comments