|
1 | | -## |
2 | | -# Implements contextual dispatch through Cassette.jl |
3 | | -# Goals: |
4 | | -# - Rewrite common CPU functions to appropriate GPU intrinsics |
5 | | -# |
| 1 | +# contextual dispatch using Cassette.jl |
| 2 | +# |
6 | 3 | # TODO: |
7 | 4 | # - error (erf, ...) |
8 | 5 | # - pow |
|
15 | 12 |
|
16 | 13 | using Cassette |
17 | 14 |
|
18 | | -@inline function unknowably_false() |
19 | | - Base.llvmcall("ret i8 0", Bool, Tuple{}) |
20 | | -end |
| 15 | +@inline unknowably_false() = Base.llvmcall("ret i8 0", Bool, Tuple{}) |
| 16 | + |
| 17 | +function generate_transform(method_redefinitions) |
| 18 | + return function transform(ctx, ref) |
| 19 | + CI = ref.code_info |
21 | 20 |
|
22 | | -function transform(ctx, ref) |
23 | | - CI = ref.code_info |
24 | | - noinline = any(@nospecialize(x) -> |
25 | | - Core.Compiler.isexpr(x, :meta) && |
26 | | - x.args[1] == :noinline, |
27 | | - CI.code) |
28 | | - CI.inlineable = !noinline |
| 21 | + # inline everything |
| 22 | + noinline = any(@nospecialize(x) -> |
| 23 | + Core.Compiler.isexpr(x, :meta) && |
| 24 | + x.args[1] == :noinline, |
| 25 | + CI.code) |
| 26 | + CI.inlineable = !noinline |
29 | 27 |
|
30 | | - if isinteractive() |
31 | | - # 265 fix, insert a call to the original method |
32 | | - # that we later will remove with LLVM's DCE |
33 | | - # TODO: We also don't want to compile these functions |
34 | | - unknowably_false = GlobalRef(@__MODULE__, :unknowably_false) |
35 | | - Cassette.insert_statements!(CI.code, CI.codelocs, |
36 | | - (x, i) -> i == 1 ? 4 : nothing, |
37 | | - (x, i) -> i == 1 ? [ |
38 | | - Expr(:call, Expr(:nooverdub, unknowably_false)), |
39 | | - Expr(:gotoifnot, Core.SSAValue(i), i+3), |
40 | | - Expr(:call, Expr(:nooverdub, Core.SlotNumber(1)), (Core.SlotNumber(i) for i in 2:ref.method.nargs)...), |
41 | | - x] : nothing) |
| 28 | + if method_redefinitions |
| 29 | + # 265 fix, insert a call to the original method |
| 30 | + # that we later will remove with LLVM's DCE |
| 31 | + # TODO: We also don't want to compile these functions |
| 32 | + unknowably_false = GlobalRef(@__MODULE__, :unknowably_false) |
| 33 | + Cassette.insert_statements!(CI.code, CI.codelocs, |
| 34 | + (x, i) -> i == 1 ? 4 : nothing, |
| 35 | + (x, i) -> i == 1 ? [ |
| 36 | + Expr(:call, Expr(:nooverdub, unknowably_false)), |
| 37 | + Expr(:gotoifnot, Core.SSAValue(i), i+3), |
| 38 | + Expr(:call, Expr(:nooverdub, Core.SlotNumber(1)), (Core.SlotNumber(i) for i in 2:ref.method.nargs)...), |
| 39 | + x] : nothing) |
| 40 | + end |
| 41 | + CI.ssavaluetypes = length(CI.code) |
| 42 | + |
| 43 | + #Core.Compiler.validate_code(CI) |
| 44 | + return CI |
42 | 45 | end |
43 | | - CI.ssavaluetypes = length(CI.code) |
44 | | - # Core.Compiler.validate_code(CI) |
45 | | - return CI |
46 | 46 | end |
47 | 47 |
|
48 | | -const InlinePass = Cassette.@pass transform |
| 48 | +const StaticPass = Cassette.@pass generate_transform(false) |
| 49 | +const InteractivePass = Cassette.@pass generate_transform(true) |
49 | 50 |
|
50 | 51 | Cassette.@context CUDACtx |
51 | | -const cudactx = Cassette.disablehooks(CUDACtx(pass = InlinePass)) |
| 52 | +const StaticCtx = Cassette.disablehooks(CUDACtx(pass = StaticPass)) |
| 53 | +const InteractiveCtx = Cassette.disablehooks(CUDACtx(pass = InteractivePass)) |
| 54 | + |
| 55 | +@inline function contextualize(f::F, interactive) where F |
| 56 | + ctx = interactive ? InteractiveCtx : StaticCtx |
| 57 | + (args...) -> Cassette.overdub(ctx, f, args...) |
| 58 | +end |
52 | 59 |
|
53 | 60 | ### |
54 | 61 | # Cassette fixes |
@@ -88,6 +95,3 @@ for f in (:cos, :cospi, :sin, :sinpi, :tan, |
88 | 95 | return CUDAnative.$f(x) |
89 | 96 | end |
90 | 97 | end |
91 | | - |
92 | | -contextualize(f::F) where F = (args...) -> Cassette.overdub(cudactx, f, args...) |
93 | | - |
|
0 commit comments