Skip to content
This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit f014b84

Browse files
committed
Make it possible to control the 265-hack.
1 parent a1a1a88 commit f014b84

File tree

3 files changed

+48
-40
lines changed

3 files changed

+48
-40
lines changed

src/context.jl

Lines changed: 39 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,5 @@
1-
##
2-
# Implements contextual dispatch through Cassette.jl
3-
# Goals:
4-
# - Rewrite common CPU functions to appropriate GPU intrinsics
5-
#
1+
# contextual dispatch using Cassette.jl
2+
#
63
# TODO:
74
# - error (erf, ...)
85
# - pow
@@ -15,40 +12,50 @@
1512

1613
using Cassette
1714

18-
@inline function unknowably_false()
19-
Base.llvmcall("ret i8 0", Bool, Tuple{})
20-
end
15+
@inline unknowably_false() = Base.llvmcall("ret i8 0", Bool, Tuple{})
16+
17+
function generate_transform(method_redefinitions)
18+
return function transform(ctx, ref)
19+
CI = ref.code_info
2120

22-
function transform(ctx, ref)
23-
CI = ref.code_info
24-
noinline = any(@nospecialize(x) ->
25-
Core.Compiler.isexpr(x, :meta) &&
26-
x.args[1] == :noinline,
27-
CI.code)
28-
CI.inlineable = !noinline
21+
# inline everything
22+
noinline = any(@nospecialize(x) ->
23+
Core.Compiler.isexpr(x, :meta) &&
24+
x.args[1] == :noinline,
25+
CI.code)
26+
CI.inlineable = !noinline
2927

30-
if isinteractive()
31-
# 265 fix, insert a call to the original method
32-
# that we later will remove with LLVM's DCE
33-
# TODO: We also don't want to compile these functions
34-
unknowably_false = GlobalRef(@__MODULE__, :unknowably_false)
35-
Cassette.insert_statements!(CI.code, CI.codelocs,
36-
(x, i) -> i == 1 ? 4 : nothing,
37-
(x, i) -> i == 1 ? [
38-
Expr(:call, Expr(:nooverdub, unknowably_false)),
39-
Expr(:gotoifnot, Core.SSAValue(i), i+3),
40-
Expr(:call, Expr(:nooverdub, Core.SlotNumber(1)), (Core.SlotNumber(i) for i in 2:ref.method.nargs)...),
41-
x] : nothing)
28+
if method_redefinitions
29+
# 265 fix, insert a call to the original method
30+
# that we later will remove with LLVM's DCE
31+
# TODO: We also don't want to compile these functions
32+
unknowably_false = GlobalRef(@__MODULE__, :unknowably_false)
33+
Cassette.insert_statements!(CI.code, CI.codelocs,
34+
(x, i) -> i == 1 ? 4 : nothing,
35+
(x, i) -> i == 1 ? [
36+
Expr(:call, Expr(:nooverdub, unknowably_false)),
37+
Expr(:gotoifnot, Core.SSAValue(i), i+3),
38+
Expr(:call, Expr(:nooverdub, Core.SlotNumber(1)), (Core.SlotNumber(i) for i in 2:ref.method.nargs)...),
39+
x] : nothing)
40+
end
41+
CI.ssavaluetypes = length(CI.code)
42+
43+
#Core.Compiler.validate_code(CI)
44+
return CI
4245
end
43-
CI.ssavaluetypes = length(CI.code)
44-
# Core.Compiler.validate_code(CI)
45-
return CI
4646
end
4747

48-
const InlinePass = Cassette.@pass transform
48+
const StaticPass = Cassette.@pass generate_transform(false)
49+
const InteractivePass = Cassette.@pass generate_transform(true)
4950

5051
Cassette.@context CUDACtx
51-
const cudactx = Cassette.disablehooks(CUDACtx(pass = InlinePass))
52+
const StaticCtx = Cassette.disablehooks(CUDACtx(pass = StaticPass))
53+
const InteractiveCtx = Cassette.disablehooks(CUDACtx(pass = InteractivePass))
54+
55+
@inline function contextualize(f::F, interactive) where F
56+
ctx = interactive ? InteractiveCtx : StaticCtx
57+
(args...) -> Cassette.overdub(ctx, f, args...)
58+
end
5259

5360
###
5461
# Cassette fixes
@@ -88,6 +95,3 @@ for f in (:cos, :cospi, :sin, :sinpi, :tan,
8895
return CUDAnative.$f(x)
8996
end
9097
end
91-
92-
contextualize(f::F) where F = (args...) -> Cassette.overdub(cudactx, f, args...)
93-

src/execution.jl

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@ export @cuda, cudaconvert, cufunction, dynamic_cufunction, nearest_warpsize
88
# split keyword arguments to `@cuda` into ones affecting the macro itself, the compiler and
99
# the code it generates, or the execution
1010
function split_kwargs(kwargs)
11-
macro_kws = [:dynamic]
11+
macro_kws = [:dynamic, :interactive]
1212
compiler_kws = [:minthreads, :maxthreads, :blocks_per_sm, :maxregs, :name]
1313
call_kws = [:cooperative, :blocks, :threads, :config, :shmem, :stream]
1414
macro_kwargs = []
@@ -138,11 +138,15 @@ macro cuda(ex...)
138138

139139
# handle keyword arguments that influence the macro's behavior
140140
dynamic = false
141+
interactive = isinteractive()
141142
for kwarg in macro_kwargs
142143
key,val = kwarg.args
143144
if key == :dynamic
144145
isa(val, Bool) || throw(ArgumentError("`dynamic` keyword argument to @cuda should be a constant value"))
145146
dynamic = val::Bool
147+
elseif key == :interactive
148+
isa(val, Bool) || throw(ArgumentError("`interactive` keyword argument to @cuda should be a constant value"))
149+
interactive = val::Bool
146150
else
147151
throw(ArgumentError("Unsupported keyword argument '$key'"))
148152
end
@@ -159,7 +163,7 @@ macro cuda(ex...)
159163
quote
160164
# we're in kernel land already, so no need to cudaconvert arguments
161165
local kernel_tt = Tuple{$((:(Core.Typeof($var)) for var in var_exprs)...)}
162-
local kernel_f = contextualize($(esc(f)))
166+
local kernel_f = contextualize($(esc(f)), $interactive)
163167
local kernel = dynamic_cufunction(kernel_f, kernel_tt)
164168
kernel($(var_exprs...); $(map(esc, call_kwargs)...))
165169
end)
@@ -173,7 +177,7 @@ macro cuda(ex...)
173177
GC.@preserve $(vars...) begin
174178
local kernel_args = cudaconvert.(($(var_exprs...),))
175179
local kernel_tt = Tuple{Core.Typeof.(kernel_args)...}
176-
local kernel_f = contextualize($(esc(f)))
180+
local kernel_f = contextualize($(esc(f)), $interactive)
177181
local kernel = cufunction(kernel_f, kernel_tt;
178182
$(map(esc, compiler_kwargs)...))
179183
kernel(kernel_args...; $(map(esc, call_kwargs)...))

test/device/execution.jl

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -308,15 +308,15 @@ end
308308
return
309309
end
310310

311-
@cuda kernel(convert(CuPtr{Int}, arr.buf))
311+
@cuda interactive=true kernel(convert(CuPtr{Int}, arr.buf))
312312
@test Array(arr)[] == 1
313313

314314
function kernel(ptr)
315315
unsafe_store!(ptr, 2)
316316
return
317317
end
318318

319-
@cuda kernel(convert(CuPtr{Int}, arr.buf))
319+
@cuda interactive=true kernel(convert(CuPtr{Int}, arr.buf))
320320
@test Array(arr)[] == 2
321321
end
322322

0 commit comments

Comments
 (0)