@@ -12,7 +12,7 @@ const ws = Int32(32)
1212
1313# "two packed values specifying a mask for logically splitting warps into sub-segments
1414# and an upper bound for clamping the source lane index"
15- @inline pack (width:: UInt32 , mask:: UInt32 ) :: UInt32 = (convert (UInt32, ws - width) << 8 ) | mask
15+ @inline pack (width, mask) = (convert (UInt32, ws - width) << 8 ) | convert (UInt32, mask)
1616
1717# NOTE: CUDA C disagrees with PTX on how shuffles are called
1818for (name, mode, mask, offset) in ((" _up" , :up , UInt32 (0x00 ), src-> src),
@@ -28,35 +28,32 @@ for (name, mode, mask, offset) in (("_up", :up, UInt32(0x00), src->src),
2828 intrinsic = " llvm.nvvm.shfl.sync.$mode .i32"
2929
3030 fname_sync = Symbol (" $(fname) _sync" )
31+ __fname_sync = Symbol (" __$(fname) _sync" )
3132 @eval begin
3233 export $ fname_sync
3334
34- @inline $ fname_sync (mask:: UInt32 , val:: UInt32 , src:: UInt32 , width:: UInt32 = $ ws) =
35+ # HACK: recurse_value_invocation and friends split the first argument of a call,
36+ # so swap mask and val for these tools to work.
37+ @inline $ fname_sync (mask, val, src, width= $ ws) =
38+ $ __fname_sync (val, mask, src, width)
39+ @inline $ __fname_sync (val:: UInt32 , mask, src, width) =
3540 ccall ($ intrinsic, llvmcall, UInt32,
3641 (UInt32, UInt32, UInt32, UInt32),
3742 mask, val, $ (offset (:src )), pack (width, $ mask))
3843
39- # FIXME : replace this with a checked conversion once we have exceptions
40- @inline $ fname_sync (mask:: UInt32 , val:: UInt32 , src:: Integer , width:: Integer = $ ws) =
41- $ fname_sync (mask, val, unsafe_trunc (UInt32, src), unsafe_trunc (UInt32, width))
42-
4344 # for backwards compatibility, have the non-synchronizing intrinsic dispatch
4445 # to the synchronizing one (with a full-lane default value for the mask)
45- @inline $ fname (val:: UInt32 , src:: Integer , width:: Integer = $ ws, mask:: UInt32 = 0xffffffff ) =
46+ @inline $ fname (val:: UInt32 , src, width= $ ws, mask:: UInt32 = 0xffffffff ) =
4647 $ fname_sync (mask, val, src, width)
4748 end
4849 else
4950 intrinsic = " llvm.nvvm.shfl.$mode .i32"
5051
5152 @eval begin
52- @inline $ fname (val:: UInt32 , src:: UInt32 , width:: UInt32 = $ ws) =
53+ @inline $ fname (val:: UInt32 , src, width= $ ws) =
5354 ccall ($ intrinsic, llvmcall, UInt32,
5455 (UInt32, UInt32, UInt32),
5556 val, $ (offset (:src )), pack (width, $ mask))
56-
57- # FIXME : replace this with a checked conversion once we have exceptions
58- @inline $ fname (val:: UInt32 , src:: Integer , width:: Integer = $ ws) =
59- $ fname (val, unsafe_trunc (UInt32, src), unsafe_trunc (UInt32, width))
6057 end
6158 end
6259end
@@ -68,8 +65,8 @@ for name in ["_up", "_down", "_xor", ""]
6865 fname = Symbol (" shfl$name " )
6966 @eval @inline $ fname (src, args... ) = recurse_value_invocation ($ fname, src, args... )
7067
71- fname_sync = Symbol (" $(fname) _sync" )
72- @eval @inline $ fname_sync (src, args... ) = recurse_value_invocation ($ fname , src, args... )
68+ fname_sync = Symbol (" __ $(fname) _sync" )
69+ @eval @inline $ fname_sync (src, args... ) = recurse_value_invocation ($ fname_sync , src, args... )
7370end
7471
7572
0 commit comments