Skip to content
Closed

symm op #1777

Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
cca8ff3
symm op
snonk Oct 19, 2025
5a4b3a4
fix: missing dep (#1779)
avik-pal Oct 20, 2025
322bd41
fix: mark enzymexla symbols as exported (#1781)
avik-pal Oct 20, 2025
0fd57e1
feat: julia api to access device properties (#1762)
avik-pal Oct 21, 2025
6c6fcc9
[deps] Make CUDA build more robust (#1782)
giordano Oct 21, 2025
3622657
chore: update ENZYMEXLA_COMMIT hash in WORKSPACE
avik-pal Oct 22, 2025
fdc56b4
Regenerate MLIR Bindings (#1784)
github-actions[bot] Oct 23, 2025
8b05cdc
Some 1.12 fixes
wsmoses Oct 24, 2025
e77039c
Apply suggestions from code review
wsmoses Oct 24, 2025
95e2217
Add ARM64 support flag to Bazel configuration
wsmoses Oct 27, 2025
2452151
Update ENZYMEXLA_COMMIT hash in WORKSPACE
avik-pal Oct 27, 2025
8a06166
[CI] Add workflow for automatically updating Enzyme-JAX (#1793)
giordano Oct 28, 2025
96a819d
Regenerate MLIR Bindings (#1795)
github-actions[bot] Oct 28, 2025
5751b52
Static cuda attempt
wsmoses Oct 19, 2025
180ee83
f
wsmoses Oct 19, 2025
adb655f
Update ENZYMEXLA_COMMIT and ml_toolchain_workspace
wsmoses Oct 26, 2025
024548c
fix
wsmoses Oct 26, 2025
a2abd76
fix
wsmoses Oct 26, 2025
1a86d57
fix
wsmoses Oct 28, 2025
fe6ee8e
bump
wsmoses Oct 28, 2025
54f5d5d
Update ENZYMEXLA_COMMIT and ml_toolchain_workspace
wsmoses Oct 28, 2025
f498044
feat: new jll version + new compiler passes (#1791)
avik-pal Oct 28, 2025
6e1a02c
feat: some more 1.12 support
avik-pal Oct 28, 2025
1737a70
Update ENZYMEXLA_COMMIT hash in WORKSPACE
avik-pal Oct 28, 2025
158b986
fix
wsmoses Oct 28, 2025
e5fc60f
symm op
snonk Oct 19, 2025
155630f
update build
snonk Oct 29, 2025
41e2a67
untrack some
snonk Oct 29, 2025
90412ef
merge
snonk Oct 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions deps/ReactantExtra/make-bindings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,26 @@ end
src_dir = joinpath(dirname(dirname(@__DIR__)), "src")

for file in [
    # NOTE(review): every dialect except EnzymeXLA.jl is commented out below.
    # This looks like a local shortcut to regenerate only one binding — confirm
    # it is intentional before merging; otherwise restore the full list.
    # "Builtin.jl",
    # "Arith.jl",
    # "Affine.jl",
    # "Func.jl",
    # "Enzyme.jl",
    "EnzymeXLA.jl",
    # "StableHLO.jl",
    # "CHLO.jl",
    # "VHLO.jl",
    # "Llvm.jl",
    # "Nvvm.jl",
    # "Gpu.jl",
    # "Affine.jl",
    # "TPU.jl",
    # "MosaicGPU.jl",
    # "Triton.jl",
    # "Shardy.jl",
    # "MPI.jl",
    # "MemRef.jl",
    # "SparseTensor.jl",
]
    # Regenerate the MLIR dialect bindings for each listed file.
    build_file(joinpath(src_dir, "mlir", "Dialects", file))
end
Expand Down
28 changes: 28 additions & 0 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,34 @@ end
return TracedRArray{T,N}((), MLIR.IR.result(conv), result_size)
end

"""
    lapack_symm(A, B, C, alpha, beta; side, uplo, location)

Emit an `enzymexla.lapack_symm` op for the BLAS-`symm`-style product
`alpha * A * B + beta * C` (`side == :L`) or `alpha * B * A + beta * C`
(`side == :R`), where only the `uplo` (`:U` or `:L`) triangle of the
symmetric operand is referenced. Returns the result wrapped as a
`TracedRArray` with the shape of `C`.

Throws `ArgumentError` for any `side`/`uplo` value outside the sets above.
"""
@noinline function lapack_symm(
    A::TracedRArray{T},
    B::TracedRArray{T},
    C::TracedRArray{T},
    alpha::TracedRNumber{T},
    beta::TracedRNumber{T};
    side::Symbol,
    uplo::Symbol,
    location=mlir_stacktrace("lapack_symm", @__FILE__, @__LINE__),
) where {T}
    # Reject unknown flags up front instead of silently mapping any
    # unrecognized symbol to :R / :L.
    side in (:L, :R) || throw(ArgumentError("side must be :L or :R, got $side"))
    uplo in (:U, :L) || throw(ArgumentError("uplo must be :U or :L, got $uplo"))

    ctx = MLIR.IR.context()
    ressize = size(C)
    res = MLIR.IR.result(
        enzymexla.lapack_symm(
            A.mlir_data,
            B.mlir_data,
            C.mlir_data,
            alpha.mlir_data,
            beta.mlir_data;
            # unwrapped_eltype (not eltype) strips the tracing wrapper type.
            output=mlir_type(TracedRArray{unwrapped_eltype(C),length(ressize)}, ressize),
            side=enzymexlaLapackSideAttrGet(ctx, side == :L ? 1 : 0),
            uplo=enzymexlaLapackUploAttrGet(ctx, uplo == :U ? 1 : 0),
            location,
        ),
    )
    # Wrap the raw MLIR value in a TracedRArray before returning,
    # consistent with dot_general / conv in this file.
    return TracedRArray{unwrapped_eltype(C),length(ressize)}((), res, ressize)
end

Base.@nospecializeinfer @noinline function dot_general(
@nospecialize(lhs::TracedRArray{T1}),
@nospecialize(rhs::TracedRArray{T2});
Expand Down
32 changes: 32 additions & 0 deletions src/stdlibs/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,38 @@ function overloaded_mul!(
return C
end

"""
    overloaded_mul!(C, A::Symmetric, B, α=true, β=true)

In-place product `C = α * A * B + β * C` for a `Symmetric` left operand,
lowered to the `lapack_symm` op. Mutates and returns `C`.

Throws `DimensionMismatch` when `size(C) != (size(A, 1), size(B, 2))`.
"""
function overloaded_mul!(
    @nospecialize(C::TracedRArray{T,2} where {T}),
    @nospecialize(A::Symmetric),
    @nospecialize(B::AbstractMatrix),
    α::Number=true,
    β::Number=true,
)
    # NOTE(review): LinearAlgebra's 3-arg mul! corresponds to (α, β) = (true, false);
    # a default of β = true makes the defaulted call add into C rather than
    # overwrite it. Confirm against the other overloaded_mul! methods.

    # Capture which triangle of the Symmetric wrapper carries the data
    # before `A` is rebound to the unwrapped traced array below.
    uplo = A.uplo == 'U' ? :U : :L

    # Promote to traced arrays; use the public `parent` accessor rather than
    # reaching into the internal `.data` field.
    A = call_with_reactant(Reactant.promote_to, TracedRArray, parent(A))
    B = call_with_reactant(Reactant.promote_to, TracedRArray, B)

    # Dimension checks
    if size(C) != (size(A, 1), size(B, 2))
        throw(DimensionMismatch("C=$(size(C)), A=$(size(A)), B=$(size(B))"))
    end

    T = Reactant.unwrapped_eltype(C)
    tmp = @opcall lapack_symm(
        T.(materialize_traced_array(A)),
        T.(materialize_traced_array(B)),
        T.(materialize_traced_array(C)),
        Reactant.promote_to(TracedRNumber{T}, α),
        Reactant.promote_to(TracedRNumber{T}, β),
        side=:L,
        uplo=uplo, # was hard-coded :U, which ignored Symmetric(..., :L)
    )

    set_mlir_data!(C, get_mlir_data(tmp)) # TODO remove later, handling in place ops are weird
    return C
end


function LinearAlgebra.triu!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T}
iota_1 = @opcall iota(Int64, [size(X)...]; iota_dimension=1)
iota_2 = @opcall subtract(
Expand Down
31 changes: 31 additions & 0 deletions test/integration/linear_algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -432,3 +432,34 @@ end
1e-2
end
end

@testset "Symmetric Multiplication" begin
    # One parametric loop instead of two copy-pasted per-type testsets.
    @testset "$T" for T in (Float32, Float64)
        A = Symmetric(rand(T, 10, 10))
        B = rand(T, 10, 10)
        C = rand(T, 10, 10)
        A_ra = Reactant.to_rarray(A)
        B_ra = Reactant.to_rarray(B)
        C_ra = Reactant.to_rarray(C)

        alpha = rand(T)
        beta = rand(T)

        # Plain product. (`@test @code_hlo ...` asserted on an MLIR module,
        # which is not a Bool — assert on the executed result instead.)
        @test @jit(A_ra * B_ra) ≈ A * B rtol = sqrt(eps(T))

        # 5-arg mul!: C = alpha*A*B + beta*C.
        expected = alpha .* (A * B) .+ beta .* C
        @test @jit(mul!(C_ra, A_ra, B_ra, alpha, beta)) ≈ expected rtol = sqrt(eps(T))
    end
end