Skip to content
This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit c95ad5e

Browse files
committed
Fix invocation of shfl_sync on wide or aggregate inputs.
Fixes #420
1 parent 82c123d commit c95ad5e

File tree

2 files changed

+41
-21
lines changed

2 files changed

+41
-21
lines changed

src/device/cuda/warp_shuffle.jl

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ const ws = Int32(32)
1212

1313
# "two packed values specifying a mask for logically splitting warps into sub-segments
1414
# and an upper bound for clamping the source lane index"
15-
@inline pack(width::UInt32, mask::UInt32)::UInt32 = (convert(UInt32, ws - width) << 8) | mask
15+
@inline pack(width, mask) = (convert(UInt32, ws - width) << 8) | convert(UInt32, mask)
1616

1717
# NOTE: CUDA C disagrees with PTX on how shuffles are called
1818
for (name, mode, mask, offset) in (("_up", :up, UInt32(0x00), src->src),
@@ -28,35 +28,32 @@ for (name, mode, mask, offset) in (("_up", :up, UInt32(0x00), src->src),
2828
intrinsic = "llvm.nvvm.shfl.sync.$mode.i32"
2929

3030
fname_sync = Symbol("$(fname)_sync")
31+
__fname_sync = Symbol("__$(fname)_sync")
3132
@eval begin
3233
export $fname_sync
3334

34-
@inline $fname_sync(mask::UInt32, val::UInt32, src::UInt32, width::UInt32=$ws) =
35+
# HACK: recurse_value_invocation and friends split the first argument of a call,
36+
# so swap mask and val for these tools to work.
37+
@inline $fname_sync(mask, val, src, width=$ws) =
38+
$__fname_sync(val, mask, src, width)
39+
@inline $__fname_sync(val::UInt32, mask, src, width) =
3540
ccall($intrinsic, llvmcall, UInt32,
3641
(UInt32, UInt32, UInt32, UInt32),
3742
mask, val, $(offset(:src)), pack(width, $mask))
3843

39-
# FIXME: replace this with a checked conversion once we have exceptions
40-
@inline $fname_sync(mask::UInt32, val::UInt32, src::Integer, width::Integer=$ws) =
41-
$fname_sync(mask, val, unsafe_trunc(UInt32, src), unsafe_trunc(UInt32, width))
42-
4344
# for backwards compatibility, have the non-synchronizing intrinsic dispatch
4445
# to the synchronizing one (with a full-lane default value for the mask)
45-
@inline $fname(val::UInt32, src::Integer, width::Integer=$ws, mask::UInt32=0xffffffff) =
46+
@inline $fname(val::UInt32, src, width=$ws, mask::UInt32=0xffffffff) =
4647
$fname_sync(mask, val, src, width)
4748
end
4849
else
4950
intrinsic = "llvm.nvvm.shfl.$mode.i32"
5051

5152
@eval begin
52-
@inline $fname(val::UInt32, src::UInt32, width::UInt32=$ws) =
53+
@inline $fname(val::UInt32, src, width=$ws) =
5354
ccall($intrinsic, llvmcall, UInt32,
5455
(UInt32, UInt32, UInt32),
5556
val, $(offset(:src)), pack(width, $mask))
56-
57-
# FIXME: replace this with a checked conversion once we have exceptions
58-
@inline $fname(val::UInt32, src::Integer, width::Integer=$ws) =
59-
$fname(val, unsafe_trunc(UInt32, src), unsafe_trunc(UInt32, width))
6057
end
6158
end
6259
end
@@ -68,8 +65,8 @@ for name in ["_up", "_down", "_xor", ""]
6865
fname = Symbol("shfl$name")
6966
@eval @inline $fname(src, args...) = recurse_value_invocation($fname, src, args...)
7067

71-
fname_sync = Symbol("$(fname)_sync")
72-
@eval @inline $fname_sync(src, args...) = recurse_value_invocation($fname, src, args...)
68+
fname_sync = Symbol("__$(fname)_sync")
69+
@eval @inline $fname_sync(src, args...) = recurse_value_invocation($fname_sync, src, args...)
7370
end
7471

7572

test/device/cuda.jl

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -552,15 +552,38 @@ end
552552

553553
n = 14
554554

555-
@testset for T in [Int32, Int64, Float32, Float64, AddableTuple]
556-
function kernel(d::CuDeviceArray{T}, n) where {T}
557-
t = threadIdx().x
558-
if t <= n
559-
d[t] += shfl_down(d[t], n÷2)
560-
end
561-
return
555+
function kernel1(d::CuDeviceArray{T}, n) where {T}
556+
t = threadIdx().x
557+
if t <= n
558+
d[t] += shfl_down(d[t], n÷2)
559+
end
560+
return
561+
end
562+
563+
function kernel2(d::CuDeviceArray{T}, n) where {T}
564+
t = threadIdx().x
565+
if t <= n
566+
d[t] += shfl_down(d[t], n÷2, 32, 0xffffffff)
562567
end
568+
return
569+
end
570+
571+
function kernel3(d::CuDeviceArray{T}, n) where {T}
572+
t = threadIdx().x
573+
if t <= n
574+
d[t] += shfl_down_sync(0xffffffff, d[t], n÷2, 32)
575+
end
576+
return
577+
end
578+
579+
kernels = try
580+
getfield(CUDAnative, :shfl_sync)
581+
(kernel1, kernel2, kernel3)
582+
catch
583+
(kernel1,)
584+
end
563585

586+
@testset for T in [Int32, Int64, Float32, Float64, AddableTuple], kernel in kernels
564587
a = T[T(i) for i in 1:n]
565588
d_a = CuArray(a)
566589

0 commit comments

Comments
 (0)