feat: lower stablehlo to llvm #1893
base: main
Conversation
avik-pal commented on Nov 20, 2025
; ModuleID = '__compute_module'
source_filename = "__compute_module"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
@0 = private unnamed_addr constant [4 x i8] zeroinitializer, align 4
!xla_cpu_memory_region_name = !{!0}
!0 = !{!"ir_emitter"}
; ModuleID = '__compute_module_wrapped_reduce_kernel_module'
source_filename = "__compute_module_wrapped_reduce_kernel_module"
%XLA_CPU_KernelCallFrame = type { ptr, ptr, i64, ptr }
%XLA_CPU_KernelArg = type { ptr, i64 }
%kernel_dim3 = type { i64, i64, i64 }
; Function Attrs: uwtable
define ptr @wrapped_reduce(ptr %0) #0 {
%2 = getelementptr inbounds %XLA_CPU_KernelCallFrame, ptr %0, i32 0, i32 3
%3 = load ptr, ptr %2, align 8, !invariant.load !2
%4 = getelementptr inbounds %XLA_CPU_KernelArg, ptr %3, i32 0, i32 0
%5 = load ptr, ptr %4, align 8, !invariant.load !2, !dereferenceable !3
%6 = getelementptr inbounds %XLA_CPU_KernelArg, ptr %3, i32 1, i32 0
%7 = load ptr, ptr %6, align 8, !invariant.load !2, !dereferenceable !4
%8 = getelementptr inbounds %XLA_CPU_KernelArg, ptr %3, i32 2, i32 0
%9 = load ptr, ptr %8, align 8, !invariant.load !2, !dereferenceable !4
%10 = getelementptr inbounds %XLA_CPU_KernelCallFrame, ptr %0, i32 0, i32 1
%11 = load ptr, ptr %10, align 8
%12 = getelementptr inbounds %kernel_dim3, ptr %11, i32 0, i32 0
%13 = load i64, ptr %12, align 4, !invariant.load !2
%14 = getelementptr inbounds %kernel_dim3, ptr %11, i32 0, i32 1
%15 = load i64, ptr %14, align 4, !invariant.load !2
%16 = getelementptr inbounds %kernel_dim3, ptr %11, i32 0, i32 2
%17 = load i64, ptr %16, align 4, !invariant.load !2
call void @wrapped_reduce_wrapped(ptr %5, ptr %7, ptr %9, i64 %13, i64 %15, i64 %17)
ret ptr null
}
; Function Attrs: alwaysinline
define internal void @wrapped_reduce_wrapped(ptr noalias align 64 dereferenceable(64) %0, ptr noalias align 64 dereferenceable(4) %1, ptr noalias align 64 dereferenceable(4) %2, i64 %3, i64 %4, i64 %5) #1 {
%7 = getelementptr inbounds [1 x float], ptr %1, i32 0, i32 0
%8 = load float, ptr %7, align 4, !invariant.load !2
br label %9
9: ; preds = %25, %6
%10 = phi i64 [ %26, %25 ], [ 0, %6 ]
%11 = phi float [ %17, %25 ], [ %8, %6 ]
%12 = icmp slt i64 %10, 4
br i1 %12, label %13, label %27
13: ; preds = %9
%14 = mul nsw i64 %10, 4
br label %15
15: ; preds = %19, %13
%16 = phi i64 [ %24, %19 ], [ 0, %13 ]
%17 = phi float [ %23, %19 ], [ %11, %13 ]
%18 = icmp slt i64 %16, 4
br i1 %18, label %19, label %25
19: ; preds = %15
%20 = add nsw i64 %14, %16
%21 = getelementptr inbounds [16 x float], ptr %0, i32 0, i64 %20
%22 = load float, ptr %21, align 4, !invariant.load !2
%23 = fadd reassoc float %17, %22
%24 = add i64 %16, 1
br label %15
25: ; preds = %15
%26 = add i64 %10, 1
br label %9, !llvm.loop !5
27: ; preds = %9
%28 = getelementptr inbounds [1 x float], ptr %2, i32 0, i32 0
store float %11, ptr %28, align 4
ret void
}
attributes #0 = { uwtable "frame-pointer"="all" "prefer-vector-width"="256" }
attributes #1 = { alwaysinline }
!llvm.module.flags = !{!0}
!xla_cpu_memory_region_name = !{!1}
!0 = !{i32 2, !"Debug Info Version", i32 3}
!1 = !{!"xla_cpu_emitter__loop_fusion_kernel_emitter__hlo_opcode__fusion"}
!2 = !{}
!3 = !{i64 64}
!4 = !{i64 4}
!5 = distinct !{!5, !6}
!6 = !{!"llvm.loop.unroll.disable"}
One caveat here is that the emitted LLVM is different from the native Julia LLVM. CUDA.jl has a similar problem, and there is an LLVM downgrader package for this purpose.
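For a concrete point of comparison, the "native Julia LLVM" for an equivalent reduction can be dumped with InteractiveUtils. A minimal sketch (the function and shapes below are illustrative, not part of this PR):

using InteractiveUtils

# Rough Julia-level equivalent of the fused reduce above:
# accumulate a 4x4 Float32 array starting from an initial value.
reduce_add(A::Matrix{Float32}, init::Float32) = init + sum(A)

# Print the IR generated by Julia's own compiler (and its bundled LLVM
# version) to compare against the XLA-emitted module above.
A = rand(Float32, 4, 4)
@code_llvm reduce_add(A, 0.0f0)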
This won't work for serialization, right? I think @ChrisRackauckas (correct me if I'm wrong here) wanted more of an AoT compile, where the final version (the user-facing generated file) doesn't depend on Reactant.
It should be workable with serialization, but a bit more setup is required. It would be equivalent to how CUDA.jl/Enzyme.jl perform external compilation and embed the result into Julia.
So long as the JIT'd code doesn't use the XLA runtime, there would be no issue.
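As a sketch of what that embedding could look like, assuming the module above has been compiled offline into a shared library (for example via llc plus clang -shared) named libreduce_kernel.so: the library name, the build step, and the names of the first and third call-frame fields are assumptions for illustration; only the struct layouts come from the IR dump.

# Julia mirrors of the structs in the IR dump (%XLA_CPU_KernelArg,
# %kernel_dim3, %XLA_CPU_KernelCallFrame). The IR only touches call-frame
# fields 1 (thread) and 3 (args); the other two field names are guesses.
struct KernelArg
    data::Ptr{Cvoid}
    size::Int64
end

struct KernelDim3
    x::Int64
    y::Int64
    z::Int64
end

struct KernelCallFrame
    thread_dims::Ptr{KernelDim3}
    thread::Ptr{KernelDim3}
    num_args::Int64
    args::Ptr{KernelArg}
end

# Drive the exported wrapped_reduce(ptr) entry point directly, with no XLA
# runtime involved. Note the kernel declares align 64 on its buffers, so a
# real implementation would need aligned allocations; plain Arrays are used
# here only to keep the sketch short.
function run_reduce(A::Matrix{Float32}, init::Float32)
    init_buf = Float32[init]
    out      = Float32[0.0f0]
    args = [KernelArg(pointer(A), sizeof(A)),
            KernelArg(pointer(init_buf), sizeof(init_buf)),
            KernelArg(pointer(out), sizeof(out))]
    dims   = Ref(KernelDim3(1, 1, 1))   # single "thread" launch
    thread = Ref(KernelDim3(0, 0, 0))
    GC.@preserve A init_buf out args dims thread begin
        frame = KernelCallFrame(
            Base.unsafe_convert(Ptr{KernelDim3}, dims),
            Base.unsafe_convert(Ptr{KernelDim3}, thread),
            length(args),
            pointer(args))
        # "libreduce_kernel" is a hypothetical library built from the IR above.
        @ccall "libreduce_kernel".wrapped_reduce(frame::Ref{KernelCallFrame})::Ptr{Cvoid}
    end
    return out[1]   # should equal init + sum(A)
end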
Note also that you're going to need all of the Julia calling-convention parsing work from Enzyme.jl (see https://github.com/EnzymeAD/Enzyme.jl/blob/main/src/compiler.jl and "PrimalErrorThunk", which is the "just call the original code" calling-convention fixup). I've now finished adapting this to Julia 1.12, in principle at least, so that's good.
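As a toy illustration of the general shape of that fixup (this is not Enzyme.jl's actual PrimalErrorThunk machinery): an ordinary Julia method follows the boxed jlcall convention, so handing it to externally compiled code means generating a thunk with a plain C ABI that just forwards to the original method, e.g. via @cfunction.

# Sketch only: a hand-rolled "just call the original code" thunk.
sum_buffer(x::Ptr{Float32}, n::Clong) = sum(unsafe_wrap(Array, x, n))

# C-callable pointer with the declared ABI, forwarding to the Julia method.
const sum_buffer_thunk = @cfunction(sum_buffer, Float32, (Ptr{Float32}, Clong))

# Exercise it the way an external caller (e.g. compiled kernel code) would.
buf = rand(Float32, 16)
total = GC.@preserve buf ccall(sum_buffer_thunk, Float32,
                               (Ptr{Float32}, Clong), pointer(buf), length(buf))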
The LLVM downgrader could be used for that. It's expected to be built against the LLVM version you ingest IR for, and you need to backport the bitcode writer for the target version you want to emit (although you can of course safely emit IR for an older version, which Julia's LLVM will auto-upgrade).
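Roughly, that flow could look like the sketch below; `downgrade-bc` is a placeholder for whatever binary the downgrader ships (not its real CLI), and `reduce.ll` is assumed to hold the textual IR dumped above.

# The downgrader is built against the (newer) LLVM version that produced the
# IR and uses a backported bitcode writer to emit bitcode for the older
# target version. `downgrade-bc` and the file names are placeholders.
run(`downgrade-bc reduce.ll -o reduce_downgraded.bc`)

# Julia's bundled (older) LLVM can then ingest the result, auto-upgrading the
# IR if what was emitted targets an even older version than it expects.
bitcode = read("reduce_downgraded.bc")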