22
33# TODO : does not work on sub-word (e.g. Int16) or non-word divisible sized types
44
5- # TODO : should shfl_idx conform to 1-based indexing?
6-
75# TODO : these functions should dispatch based on the actual warp size
86const ws = Int32 (32 )
97
@@ -14,52 +12,48 @@ const ws = Int32(32)
1412
1513# "two packed values specifying a mask for logically splitting warps into sub-segments
1614# and an upper bound for clamping the source lane index"
17- @inline pack (width:: UInt32 , mask:: UInt32 ) :: UInt32 = (convert (UInt32, ws - width) << 8 ) | mask
15+ @inline pack (width, mask) = (convert (UInt32, ws - width) << 8 ) | convert (UInt32, mask)
1816
1917# NOTE: CUDA C disagrees with PTX on how shuffles are called
20- for (name, mode, mask) in ((" _up" , :up , UInt32 (0x00 )),
21- (" _down" , :down , UInt32 (0x1f )),
22- (" _xor" , :bfly , UInt32 (0x1f )),
23- (" " , :idx , UInt32 (0x1f )))
18+ for (name, mode, mask, offset ) in ((" _up" , :up , UInt32 (0x00 ), src -> src ),
19+ (" _down" , :down , UInt32 (0x1f ), src -> src ),
20+ (" _xor" , :bfly , UInt32 (0x1f ), src -> src ),
21+ (" " , :idx , UInt32 (0x1f ), src -> :( $ src - 1 )))
2422 fname = Symbol (" shfl$name " )
23+ @eval export $ fname
2524
2625 if cuda_driver_version >= v " 9.0" && v " 6.0" in ptx_support
27- instruction = Symbol (" shfl.sync.$mode .b32" )
28- fname_sync = Symbol (" $(fname) _sync" )
29-
30- # TODO : implement using LLVM intrinsics when we have D38090
26+ # newer hardware/CUDA versions use synchronizing intrinsics, which take an extra
27+ # mask argument indicating which threads in the warp should be synchronized
28+ intrinsic = " llvm.nvvm.shfl.sync.$mode .i32"
3129
30+ fname_sync = Symbol (" $(fname) _sync" )
31+ __fname_sync = Symbol (" __$(fname) _sync" )
3232 @eval begin
33- export $ fname_sync, $ fname
34-
35- @inline $ fname_sync (val:: UInt32 , src:: UInt32 , width:: UInt32 = $ ws,
36- threadmask:: UInt32 = 0xffffffff ) =
37- @asmcall ($ " $instruction \$ 0, \$ 1, \$ 2, \$ 3, \$ 4;" , " =r,r,r,r,r" , true ,
38- UInt32, NTuple{4 ,UInt32},
39- val, src, pack (width, $ mask), threadmask)
40-
41- # FIXME : replace this with a checked conversion once we have exceptions
42- @inline $ fname_sync (val:: UInt32 , src:: Integer , width:: Integer = $ ws,
43- threadmask:: UInt32 = 0xffffffff ) =
44- $ fname_sync (val, unsafe_trunc (UInt32, src), unsafe_trunc (UInt32, width),
45- threadmask)
46-
47- @inline $ fname (val:: UInt32 , src:: Integer , width:: Integer = $ ws) =
48- $ fname_sync (val, src, width)
33+ export $ fname_sync
34+
35+ # HACK: recurse_value_invocation and friends split the first argument of a call,
36+ # so swap mask and val for these tools to work.
37+ @inline $ fname_sync (mask, val, src, width= $ ws) =
38+ $ __fname_sync (val, mask, src, width)
39+ @inline $ __fname_sync (val:: UInt32 , mask, src, width) =
40+ ccall ($ intrinsic, llvmcall, UInt32,
41+ (UInt32, UInt32, UInt32, UInt32),
42+ mask, val, $ (offset (:src )), pack (width, $ mask))
43+
44+ # for backwards compatibility, have the non-synchronizing intrinsic dispatch
45+ # to the synchronizing one (with a full-warp default value for the mask)
46+ @inline $ fname (val:: UInt32 , src, width= $ ws, mask:: UInt32 = 0xffffffff ) =
47+ $ fname_sync (mask, val, src, width)
4948 end
5049 else
51- intrinsic = Symbol ( " llvm.nvvm.shfl.$mode .i32" )
50+ intrinsic = " llvm.nvvm.shfl.$mode .i32"
5251
5352 @eval begin
54- export $ fname
55- @inline $ fname (val:: UInt32 , src:: UInt32 , width:: UInt32 = $ ws) =
56- ccall ($ " $intrinsic " , llvmcall, UInt32,
53+ @inline $ fname (val:: UInt32 , src, width= $ ws) =
54+ ccall ($ intrinsic, llvmcall, UInt32,
5755 (UInt32, UInt32, UInt32),
58- val, src, pack (width, $ mask))
59-
60- # FIXME : replace this with a checked conversion once we have exceptions
61- @inline $ fname (val:: UInt32 , src:: Integer , width:: Integer = $ ws) =
62- $ fname (val, unsafe_trunc (UInt32, src), unsafe_trunc (UInt32, width))
56+ val, $ (offset (:src )), pack (width, $ mask))
6357 end
6458 end
6559end
@@ -71,62 +65,70 @@ for name in ["_up", "_down", "_xor", ""]
7165 fname = Symbol (" shfl$name " )
7266 @eval @inline $ fname (src, args... ) = recurse_value_invocation ($ fname, src, args... )
7367
74- fname_sync = Symbol (" $(fname) _sync" )
75- @eval @inline $ fname_sync (src, args... ) = recurse_value_invocation ($ fname , src, args... )
68+ fname_sync = Symbol (" __ $(fname) _sync" )
69+ @eval @inline $ fname_sync (src, args... ) = recurse_value_invocation ($ fname_sync , src, args... )
7670end
7771
7872
7973# documentation
8074
8175@doc """
82- shfl(val, lane::Integer, width::Integer=32)
76+ shfl(val, lane::Integer, width::Integer=32, threadmask::UInt32=0xffffffff )
8377
84- Shuffle a value from a directly indexed lane `lane`.
78+ Shuffle a value from a directly indexed (1-based) lane `lane`. The argument `threadmask` for selecting
79+ which threads to synchronize is only available on recent hardware, and defaults to all
80+ threads in the warp.
8581""" shfl
8682
8783@doc """
88- shfl_up(val, delta::Integer, width::Integer=32)
84+ shfl_up(val, delta::Integer, width::Integer=32, threadmask::UInt32=0xffffffff )
8985
90- Shuffle a value from a lane with lower ID relative to caller.
86+ Shuffle a value from a lane with lower ID relative to caller. The argument `threadmask` for
87+ selecting which threads to synchronize is only available on recent hardware, and defaults to
88+ all threads in the warp.
9189""" shfl_up
9290
9391@doc """
94- shfl_down(val, delta::Integer, width::Integer=32)
92+ shfl_down(val, delta::Integer, width::Integer=32, threadmask::UInt32=0xffffffff )
9593
96- Shuffle a value from a lane with higher ID relative to caller.
94+ Shuffle a value from a lane with higher ID relative to caller. The argument `threadmask` for
95+ selecting which threads to synchronize is only available on recent hardware, and defaults to
96+ all threads in the warp.
9797""" shfl_down
9898
9999@doc """
100- shfl_xor(val, mask ::Integer, width::Integer=32)
100+ shfl_xor(val, lanemask ::Integer, width::Integer=32, threadmask::UInt32=0xffffffff )
101101
102- Shuffle a value from a lane based on bitwise XOR of own lane ID with `mask`.
102+ Shuffle a value from a lane based on bitwise XOR of own lane ID with `lanemask`. The
103+ argument `threadmask` for selecting which threads to synchronize is only available on recent
104+ hardware, and defaults to all threads in the warp.
103105""" shfl_xor
104106
105107
106108@doc """
107- shfl_sync(val, lane::Integer, width::Integer=32, threadmask::UInt32=0xffffffff )
109+ shfl_sync(threadmask::UInt32, val, lane::Integer, width::Integer=32)
108110
109- Shuffle a value from a directly indexed lane `lane`. The default value for `threadmask`
110- performs the shuffle on all threads in the warp .
111+ Shuffle a value from a directly indexed (1-based) lane `lane`, and synchronize threads according to
112+ `threadmask` .
111113""" shfl_sync
112114
113115@doc """
114- shfl_up_sync(val, delta::Integer, width::Integer=32, threadmask::UInt32=0xffffffff )
116+ shfl_up_sync(threadmask::UInt32, val, delta::Integer, width::Integer=32)
115117
116- Shuffle a value from a lane with lower ID relative to caller. The default value for
117- `threadmask` performs the shuffle on all threads in the warp .
118+ Shuffle a value from a lane with lower ID relative to caller, and synchronize threads
119+ according to `threadmask`.
118120""" shfl_up_sync
119121
120122@doc """
121- shfl_down_sync(val, delta::Integer, width::Integer=32, threadmask::UInt32=0xffffffff )
123+ shfl_down_sync(threadmask::UInt32, val, delta::Integer, width::Integer=32)
122124
123- Shuffle a value from a lane with higher ID relative to caller. The default value for
124- `threadmask` performs the shuffle on all threads in the warp .
125+ Shuffle a value from a lane with higher ID relative to caller, and synchronize threads
126+ according to `threadmask`.
125127""" shfl_down_sync
126128
127129@doc """
128- shfl_xor_sync(val, mask::Integer, width::Integer=32, threadmask::UInt32=0xffffffff )
130+ shfl_xor_sync(threadmask::UInt32, val, lanemask::Integer, width::Integer=32)
129131
130- Shuffle a value from a lane based on bitwise XOR of own lane ID with `mask`. The default
131- value for `threadmask` performs the shuffle on all threads in the warp .
132+ Shuffle a value from a lane based on bitwise XOR of own lane ID with `lanemask`, and synchronize
133+ threads according to `threadmask`.
132134""" shfl_xor_sync
0 commit comments