
Conversation

@avik-pal (Collaborator)

using Reactant
# Reactant.set_default_backend("cpu")

x_ra = Reactant.to_rarray(rand(Float32, 4, 4))

# Lower `sum` to an XLA HLO module, without running XLA's optimization passes.
hlo_module = @code_xla before_xla_optimizations=true sum(x_ra)

# Compile the HLO module to LLVM IR for the CPU backend, dumped before
# LLVM's optimization pipeline runs.
result_str_cpu = @ccall Reactant.MLIR.API.mlir_c.CompileMLIRtoLLVMIRWithXLA(
    hlo_module.ptr::Ptr{Cvoid},
    "cpu"::Cstring,
    "llvm-before-optimizations"::Cstring
)::Cstring
println(Base.unsafe_string(result_str_cpu))

# Same module, but targeting the GPU backend and dumped after optimizations.
result_str_cuda = @ccall Reactant.MLIR.API.mlir_c.CompileMLIRtoLLVMIRWithXLA(
    hlo_module.ptr::Ptr{Cvoid},
    "gpu"::Cstring,
    "llvm-after-optimizations"::Cstring
)::Cstring
println(Base.unsafe_string(result_str_cuda))

@avik-pal (Collaborator, Author)

The CPU LLVM IR (before optimizations) looks like this:

; ModuleID = '__compute_module'
source_filename = "__compute_module"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"

@0 = private unnamed_addr constant [4 x i8] zeroinitializer, align 4

!xla_cpu_memory_region_name = !{!0}

!0 = !{!"ir_emitter"}

; ModuleID = '__compute_module_wrapped_reduce_kernel_module'
source_filename = "__compute_module_wrapped_reduce_kernel_module"

%XLA_CPU_KernelCallFrame = type { ptr, ptr, i64, ptr }
%XLA_CPU_KernelArg = type { ptr, i64 }
%kernel_dim3 = type { i64, i64, i64 }

; Function Attrs: uwtable
define ptr @wrapped_reduce(ptr %0) #0 {
  %2 = getelementptr inbounds %XLA_CPU_KernelCallFrame, ptr %0, i32 0, i32 3
  %3 = load ptr, ptr %2, align 8, !invariant.load !2
  %4 = getelementptr inbounds %XLA_CPU_KernelArg, ptr %3, i32 0, i32 0
  %5 = load ptr, ptr %4, align 8, !invariant.load !2, !dereferenceable !3
  %6 = getelementptr inbounds %XLA_CPU_KernelArg, ptr %3, i32 1, i32 0
  %7 = load ptr, ptr %6, align 8, !invariant.load !2, !dereferenceable !4
  %8 = getelementptr inbounds %XLA_CPU_KernelArg, ptr %3, i32 2, i32 0
  %9 = load ptr, ptr %8, align 8, !invariant.load !2, !dereferenceable !4
  %10 = getelementptr inbounds %XLA_CPU_KernelCallFrame, ptr %0, i32 0, i32 1
  %11 = load ptr, ptr %10, align 8
  %12 = getelementptr inbounds %kernel_dim3, ptr %11, i32 0, i32 0
  %13 = load i64, ptr %12, align 4, !invariant.load !2
  %14 = getelementptr inbounds %kernel_dim3, ptr %11, i32 0, i32 1
  %15 = load i64, ptr %14, align 4, !invariant.load !2
  %16 = getelementptr inbounds %kernel_dim3, ptr %11, i32 0, i32 2
  %17 = load i64, ptr %16, align 4, !invariant.load !2
  call void @wrapped_reduce_wrapped(ptr %5, ptr %7, ptr %9, i64 %13, i64 %15, i64 %17)
  ret ptr null
}

; Function Attrs: alwaysinline
define internal void @wrapped_reduce_wrapped(ptr noalias align 64 dereferenceable(64) %0, ptr noalias align 64 dereferenceable(4) %1, ptr noalias align 64 dereferenceable(4) %2, i64 %3, i64 %4, i64 %5) #1 {
  %7 = getelementptr inbounds [1 x float], ptr %1, i32 0, i32 0
  %8 = load float, ptr %7, align 4, !invariant.load !2
  br label %9

9:                                                ; preds = %25, %6
  %10 = phi i64 [ %26, %25 ], [ 0, %6 ]
  %11 = phi float [ %17, %25 ], [ %8, %6 ]
  %12 = icmp slt i64 %10, 4
  br i1 %12, label %13, label %27

13:                                               ; preds = %9
  %14 = mul nsw i64 %10, 4
  br label %15

15:                                               ; preds = %19, %13
  %16 = phi i64 [ %24, %19 ], [ 0, %13 ]
  %17 = phi float [ %23, %19 ], [ %11, %13 ]
  %18 = icmp slt i64 %16, 4
  br i1 %18, label %19, label %25

19:                                               ; preds = %15
  %20 = add nsw i64 %14, %16
  %21 = getelementptr inbounds [16 x float], ptr %0, i32 0, i64 %20
  %22 = load float, ptr %21, align 4, !invariant.load !2
  %23 = fadd reassoc float %17, %22
  %24 = add i64 %16, 1
  br label %15

25:                                               ; preds = %15
  %26 = add i64 %10, 1
  br label %9, !llvm.loop !5

27:                                               ; preds = %9
  %28 = getelementptr inbounds [1 x float], ptr %2, i32 0, i32 0
  store float %11, ptr %28, align 4
  ret void
}

attributes #0 = { uwtable "frame-pointer"="all" "prefer-vector-width"="256" }
attributes #1 = { alwaysinline }

!llvm.module.flags = !{!0}
!xla_cpu_memory_region_name = !{!1}

!0 = !{i32 2, !"Debug Info Version", i32 3}
!1 = !{!"xla_cpu_emitter__loop_fusion_kernel_emitter__hlo_opcode__fusion"}
!2 = !{}
!3 = !{i64 64}
!4 = !{i64 4}
!5 = distinct !{!5, !6}
!6 = !{!"llvm.loop.unroll.disable"}

@wsmoses (Member) commented Nov 21, 2025

Caveat here: the LLVM IR emitted is different from native Julia LLVM IR. CUDA.jl has a similar problem, and there is an LLVM downgrader package for this purpose.
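
(The mismatch is easy to see from Julia itself; the IR above targets whatever LLVM the Reactant/XLA build links against, which is typically newer than Julia's:)

# Julia's bundled LLVM version, to compare against the LLVM version the
# XLA-emitted IR was produced for.
@show Base.libllvm_version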

@wsmoses (Member) commented Nov 21, 2025

cc @maleadt. We're hoping to downgrade a future LLVM's IR to the Julia LLVM version, if you have any thoughts.

Alternatively, @avik-pal, we could JIT the LLVM IR inside Reactant, then send the JIT'd pointer to Julia and ccall it.
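
A minimal sketch of that flow, assuming a hypothetical C entry point CompileAndJITWithXLA; the name, signature, and call-frame handling below are illustrative, not an existing mlir_c API:

# Hypothetical: JIT the HLO module inside Reactant's C library and hand
# back a raw function pointer to the compiled entry kernel.
fptr = @ccall Reactant.MLIR.API.mlir_c.CompileAndJITWithXLA(
    hlo_module.ptr::Ptr{Cvoid},
    "cpu"::Cstring
)::Ptr{Cvoid}

# Call the pointer like any C function. A real call would need a properly
# populated XLA_CPU_KernelCallFrame (see the IR dump above); this Ref is
# just a placeholder with the right size.
call_frame = Ref{NTuple{4,UInt64}}()
@ccall $fptr(call_frame::Ptr{Cvoid})::Ptr{Cvoid}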

@avik-pal (Collaborator, Author)

> Alternatively, @avik-pal, we could JIT the LLVM IR inside Reactant, then send the JIT'd pointer to Julia and ccall it.

This won't work for serialization, right? I think @ChrisRackauckas (correct me if I'm wrong here) wanted more of an AoT compile, where the final version (the user-facing generated file) doesn't depend on Reactant.

@wsmoses (Member) commented Nov 21, 2025

It should be workable with serialization, but a bit more setup is required. It would be equivalent to how CUDA.jl/Enzyme.jl perform external compilation and embed the result into Julia.
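
Sketched out, that could look like the following; the clang invocation and file names are illustrative, and the runtime half is what a Reactant-free generated file would contain:

# Offline, at trace/compile time: persist the XLA-emitted LLVM IR and
# compile it to an ordinary shared library with a matching toolchain.
write("reduce_kernel.ll", Base.unsafe_string(result_str_cpu))
run(`clang -O2 -shared -fPIC reduce_kernel.ll -o libreduce_kernel.so`)

# Later, in the generated file, with no Reactant dependency: load the
# library and resolve the kernel entry point named in the IR dump above.
using Libdl
lib = Libdl.dlopen(joinpath(@__DIR__, "libreduce_kernel.so"))
kernel = Libdl.dlsym(lib, :wrapped_reduce)
# Invoking it still means building the XLA_CPU_KernelCallFrame by hand,
# which is part of the extra setup mentioned above.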

@wsmoses (Member) commented Nov 21, 2025

So long as the JIT'd code doesn't use the XLA runtime, there would be no issue.

@wsmoses (Member) commented Nov 21, 2025

Note also that you're going to need all of the Julia calling-convention handling from Enzyme.jl (see https://github.com/EnzymeAD/Enzyme.jl/blob/main/src/compiler.jl and "PrimalErrorThunk", which is the "just call the original code" calling-convention fixup). I've now finished adapting this to Julia 1.12, in principle at least, so that's good.

@maleadt commented Nov 24, 2025

> We're hoping to downgrade a future LLVM's IR to the Julia LLVM version, if you have any thoughts

The LLVM downgrader could be used for that. It's expected to be built for the LLVM version you ingest IR from, and you need to backport the bitcode writer for the target version you want to emit (although you can of course safely emit IR for an older version, which Julia's LLVM will auto-upgrade).
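
The auto-upgrade path is easy to demonstrate: Base.llvmcall accepts a (module IR, entry point) tuple and parses it with Julia's LLVM, which runs the usual IR auto-upgrade. A toy example with version-stable IR (a stand-in, not actual XLA output):

# Julia's LLVM parses this older-dialect IR string and auto-upgrades it.
addf(a::Float32, b::Float32) = Base.llvmcall(
    ("""
     define float @add(float %a, float %b) {
     entry:
       %s = fadd float %a, %b
       ret float %s
     }
     """, "add"),
    Float32, Tuple{Float32,Float32}, a, b)

@assert addf(1.0f0, 2.0f0) == 3.0f0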
