Skip to content

Conversation

@lassepe
Copy link

@lassepe lassepe commented Nov 17, 2025

Adding the test case from EnzymeAD/Enzyme-JAX#1627

Copy link
Collaborator

@Pangoraw Pangoraw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the error logs, it seems f_sincos_jac would need an @allowscalar annotation. But since it crashes Julia, I am not sure it is worth merging just yet.

@lassepe
Copy link
Author

lassepe commented Nov 19, 2025

I created this PR in reponse to @wsmoses comment here: EnzymeAD/Enzyme-JAX#1627 (comment)

Maybe I misunderstood what PR you guys actually wanted.

@wsmoses
Copy link
Member

wsmoses commented Nov 20, 2025

Nah @lassepe you're all good. PR is useful, we just can't merge until the issue is fixed!

@lassepe
Copy link
Author

lassepe commented Nov 20, 2025

I can also mark the test as @test_broken

@wsmoses
Copy link
Member

wsmoses commented Nov 20, 2025

well first we need the allowscalar for it to run.

but also the fix has landed upstream so hopefully in about a day when we have a new jll it will pass here!

@lassepe
Copy link
Author

lassepe commented Nov 21, 2025

I've added the @allowscalar as well as the Enzyme.tupstack dispatch here. I've limited that dispatch to RArray to avoid type piracy.

@wsmoses
Copy link
Member

wsmoses commented Nov 21, 2025

@lassepe

Self Qualified Accesses: Error During Test at /home/runner/work/Reactant.jl/Reactant.jl/test/qa.jl:101
  Test threw exception
  Expression: check_no_self_qualified_accesses(Reactant; ignore = (:REACTANT_METHOD_TABLE, :__skip_rewrite_func_set, :__skip_rewrite_func_set_lock, :__skip_rewrite_type_constructor_list, :__skip_rewrite_type_constructor_list_lock)) === nothing
  SelfQualifiedAccessException
  Module `Reactant` has self-qualified accesses:
  - `RArray` was accessed as `Reactant.RArray` inside `Reactant` at `/home/runner/work/Reactant.jl/Reactant.jl/src/Enzyme.jl:73:35`

@wsmoses
Copy link
Member

wsmoses commented Nov 23, 2025

@lassepe your tupstack was transposed, this should resolve:

using Enzyme, Reactant

@inline function Enzyme.tupstack(
    data::Tuple{<:Reactant.RArray, Vararg{<:Reactant.RArray}},
    outshape::Tuple{Vararg{Int}},
    inshape::Tuple{Vararg{Int}},
)
    res = similar(first(data), outshape..., inshape...)
    c = CartesianIndices(inshape)
    tail_dims = map(Returns(:), outshape)
    for (i, val) in enumerate(data)
    	@show i, val, c[i], outshape, inshape
    	@show res[ tail_dims..., Tuple(c[i])...]
    	@show val
        copyto!(@view(res[ tail_dims..., Tuple(c[i])...]), val)
    end
    @show res
    return res
end

f_sincos2(x) = (map(sin, x) + map(cos, reverse(x)))[1:2]

function fwdz(f, x)
	dx = zero(x)
	Reactant.@allowscalar dx[1] = 1
	res = Enzyme.autodiff(Forward, f, Duplicated(x, dx))[1]
	res
end

x = Float64[1, 2, 3, 4]
x_r = Reactant.to_rarray(x)
convert(Array, x_r)

fwdz(f_sincos2, x)

@jit fwdz(f_sincos2, x_r)

@code_hlo optimize=:before_enzyme fwdz(f_sincos2, x_r)

jac2(f, x) = only(Enzyme.jacobian(Enzyme.Forward, f, x))

Enzyme.jacobian(Reverse, f_sincos2, x)[1]

@code_hlo optimize=false jac2(f_sincos2, x_r)

@code_hlo optimize=:before_enzyme jac2(f_sincos2, x_r)

@jit jac2(f_sincos2, x_r)

however an unrelated optimization pass seems to cause breakage (cc @avik-pal if you've seen this recently)

julia> @code_hlo optimize=:before_enzyme jac2(f_sincos2, x_r)
(i, val, c[i], outshape, inshape) = (1, TracedRArray{Float64,1N}((), size=(2,)), CartesianIndex(1,), (2,), (4,))
res[tail_dims..., Tuple(c[i])...] = TracedRArray{Float64,1N}((), size=(2,))
val = TracedRArray{Float64,1N}((), size=(2,))
(i, val, c[i], outshape, inshape) = (2, TracedRArray{Float64,1N}((), size=(2,)), CartesianIndex(2,), (2,), (4,))
res[tail_dims..., Tuple(c[i])...] = TracedRArray{Float64,1N}((), size=(2,))
val = TracedRArray{Float64,1N}((), size=(2,))
(i, val, c[i], outshape, inshape) = (3, TracedRArray{Float64,1N}((), size=(2,)), CartesianIndex(3,), (2,), (4,))
res[tail_dims..., Tuple(c[i])...] = TracedRArray{Float64,1N}((), size=(2,))
val = TracedRArray{Float64,1N}((), size=(2,))
(i, val, c[i], outshape, inshape) = (4, TracedRArray{Float64,1N}((), size=(2,)), CartesianIndex(4,), (2,), (4,))
res[tail_dims..., Tuple(c[i])...] = TracedRArray{Float64,1N}((), size=(2,))
val = TracedRArray{Float64,1N}((), size=(2,))
res = TracedRArray{Float64,2N}((), size=(2, 4))
loc("sin/sine"("/Users/wmoses/git/Reactant.jl/src/TracedRNumber.jl":518:0)): error: Mismatched ranks of types1 vs 2
LLVM ERROR: Failed to infer result type(s):
"stablehlo.multiply"(...) {} : (tensor<2xf64>, tensor<4x2xf64>) -> ( ??? )

[73753] signal (6): Abort trap: 6
in expression starting at REPL[14]:1
__pthread_kill at /usr/lib/system/libsystem_kernel.dylib (unknown line)
Allocations: 72104031 (Pool: 72018240; Big: 85791); GC: 90
zsh: abort      julia --project

@lassepe
Copy link
Author

lassepe commented Nov 23, 2025

Thanks for catching the transpose bug. If this is the desired order of elements, it may be worth considering the substantially more readable version:

@inline function Enzyme.tupstack(
    data::Tuple{<:Reactant.RArray, Vararg{<:Reactant.RArray}},
    outshape::Tuple{Vararg{Int}},
    inshape::Tuple{Vararg{Int}},
)
    return reshape(stack(data), outshape..., inshape...)
end

This would have the additional benefit of avoiding the risk of unitialized array elements when outshape and inshape are inconsistent with the dimensions of data.

@wsmoses
Copy link
Member

wsmoses commented Nov 24, 2025

[1540201] signal 6 (-6): Aborted                                                                                                               
in expression starting at REPL[15]:1                                                                                                           
pthread_kill at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)                                                                                 
gsignal at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)                                                                                      
abort at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)                                                                                        
unknown function (ip: 0x7d0cfa22881a)                                                                                                          
__assert_fail at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)                                                                                
setInvertedPointer at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp:242                                              
setDiffe at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp:228                                                        
memoryIdentityForwardHandler at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp:172         
createForwardModeTangent at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h:162               
createForwardModeTangent at /proc/self/cwd/bazel-out/k8-dbg/bin/external/enzyme/_virtual_includes/EnzymeOpInterfacesIncGen/MLIR/Interfaces/Auto
DiffOpInterface.h.inc:636                                                                                                                      
createForwardModeTangent at /proc/self/cwd/bazel-out/k8-dbg/bin/external/enzyme/_virtual_includes/EnzymeOpInterfacesIncGen/MLIR/Interfaces/Auto
DiffOpInterface.cpp.inc:63                                                                                                                     
visitChild at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp:318                                                      
CreateForwardDiff at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp:171                                                 
HandleAutoDiff<mlir::enzyme::ForwardDiffOp> at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp:175                        
lowerEnzymeCalls at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp:370                                                   
operator() at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp:413                                                         
operator() at /proc/self/cwd/external/llvm-project/mlir/include/mlir/IR/Visitors.h:304          

@wsmoses
Copy link
Member

wsmoses commented Nov 24, 2025

okay with pending jll I fixed the batch issue.

oddly there's something here still nondeterminstic:

julia> @jit jac2(f_sincos2, x_r)
(i, val, c[i], outshape, inshape) = (1, TracedRArray{Float64,1N}((), size=(2,)), CartesianIndex(1,), (2,), (4,))
res[tail_dims..., Tuple(c[i])...] = TracedRArray{Float64,1N}((), size=(2,))
val = TracedRArray{Float64,1N}((), size=(2,))
(i, val, c[i], outshape, inshape) = (2, TracedRArray{Float64,1N}((), size=(2,)), CartesianIndex(2,), (2,), (4,))
res[tail_dims..., Tuple(c[i])...] = TracedRArray{Float64,1N}((), size=(2,))
val = TracedRArray{Float64,1N}((), size=(2,))
(i, val, c[i], outshape, inshape) = (3, TracedRArray{Float64,1N}((), size=(2,)), CartesianIndex(3,), (2,), (4,))
res[tail_dims..., Tuple(c[i])...] = TracedRArray{Float64,1N}((), size=(2,))
val = TracedRArray{Float64,1N}((), size=(2,))
(i, val, c[i], outshape, inshape) = (4, TracedRArray{Float64,1N}((), size=(2,)), CartesianIndex(4,), (2,), (4,))
res[tail_dims..., Tuple(c[i])...] = TracedRArray{Float64,1N}((), size=(2,))
val = TracedRArray{Float64,1N}((), size=(2,))
res = TracedRArray{Float64,2N}((), size=(2, 4))
2×4 ConcretePJRTArray{Float64,2}:
 -0.653644  -0.0        0.0       -0.841471
 -0.0       -0.989992  -0.909297   0.0

julia> @jit jac2(f_sincos2, x_r)
(i, val, c[i], outshape, inshape) = (1, TracedRArray{Float64,1N}((), size=(2,)), CartesianIndex(1,), (2,), (4,))
res[tail_dims..., Tuple(c[i])...] = TracedRArray{Float64,1N}((), size=(2,))
val = TracedRArray{Float64,1N}((), size=(2,))
(i, val, c[i], outshape, inshape) = (2, TracedRArray{Float64,1N}((), size=(2,)), CartesianIndex(2,), (2,), (4,))
res[tail_dims..., Tuple(c[i])...] = TracedRArray{Float64,1N}((), size=(2,))
val = TracedRArray{Float64,1N}((), size=(2,))
(i, val, c[i], outshape, inshape) = (3, TracedRArray{Float64,1N}((), size=(2,)), CartesianIndex(3,), (2,), (4,))
res[tail_dims..., Tuple(c[i])...] = TracedRArray{Float64,1N}((), size=(2,))
val = TracedRArray{Float64,1N}((), size=(2,))
(i, val, c[i], outshape, inshape) = (4, TracedRArray{Float64,1N}((), size=(2,)), CartesianIndex(4,), (2,), (4,))
res[tail_dims..., Tuple(c[i])...] = TracedRArray{Float64,1N}((), size=(2,))
val = TracedRArray{Float64,1N}((), size=(2,))
res = TracedRArray{Float64,2N}((), size=(2, 4))
2×4 ConcretePJRTArray{Float64,2}:
  0.540302   0.0        0.0      0.756802
 -0.0       -0.416147  -0.14112  0.0

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants