@@ -14,99 +14,6 @@ Base.similar(bc::Broadcasted{CuArrayStyle{N}}, ::Type{T}) where {N,T} =
1414Base. similar (bc:: Broadcasted{CuArrayStyle{N}} , :: Type{T} , dims) where {N,T} =
1515 CuArray {T} (undef, dims)
1616
17-
18- # # replace base functions with libdevice alternatives
19-
20- cufunc (f) = f
21- cufunc (:: Type{T} ) where T = (x... ) -> T (x... ) # broadcasting type ctors isn't GPU compatible
22-
23- Broadcast. broadcasted (:: CuArrayStyle{N} , f, args... ) where {N} =
24- Broadcasted {CuArrayStyle{N}} (cufunc (f), args, nothing )
25-
26- const device_intrinsics = :[
27- cos, cospi, sin, sinpi, tan, acos, asin, atan,
28- cosh, sinh, tanh, acosh, asinh, atanh, angle,
29- log, log10, log1p, log2, logb, ilogb,
30- exp, exp2, exp10, expm1, ldexp,
31- erf, erfinv, erfc, erfcinv, erfcx,
32- brev, clz, ffs, byte_perm, popc,
33- isfinite, isinf, isnan, nearbyint,
34- nextafter, signbit, copysign, abs,
35- sqrt, rsqrt, cbrt, rcbrt, pow,
36- ceil, floor, saturate,
37- lgamma, tgamma,
38- j0, j1, jn, y0, y1, yn,
39- normcdf, normcdfinv, hypot,
40- fma, sad, dim, mul24, mul64hi, hadd, rhadd, scalbn]. args
41-
42- for f in device_intrinsics
43- isdefined (Base, f) || continue
44- @eval cufunc (:: typeof (Base.$ f)) = $ f
45- end
46-
47- # broadcast ^
48-
49- culiteral_pow (:: typeof (^ ), x:: T , :: Val{0} ) where {T<: Real } = one (x)
50- culiteral_pow (:: typeof (^ ), x:: T , :: Val{1} ) where {T<: Real } = x
51- culiteral_pow (:: typeof (^ ), x:: T , :: Val{2} ) where {T<: Real } = x * x
52- culiteral_pow (:: typeof (^ ), x:: T , :: Val{3} ) where {T<: Real } = x * x * x
53- culiteral_pow (:: typeof (^ ), x:: T , :: Val{p} ) where {T<: Real ,p} = pow (x, Int32 (p))
54-
55- cufunc (:: typeof (Base. literal_pow)) = culiteral_pow
56- cufunc (:: typeof (Base.:(^ ))) = pow
57-
58- using MacroTools
59-
60- const _cufuncs = [copy (device_intrinsics); :^ ]
61- cufuncs () = (global _cufuncs; _cufuncs)
62-
63- _cuint (x:: Int ) = Int32 (x)
64- _cuint (x:: Expr ) = x. head == :call && x. args[1 ] == :Int32 && x. args[2 ] isa Int ? Int32 (x. args[2 ]) : x
65- _cuint (x) = x
66-
67- function _cupowliteral (x:: Expr )
68- if x. head == :call && x. args[1 ] == :(CUDA. cufunc (^ )) && x. args[3 ] isa Int32
69- num = x. args[3 ]
70- if 0 <= num <= 3
71- sym = gensym (:x )
72- new_x = Expr (:block , :($ sym = $ (x. args[2 ])))
73-
74- if iszero (num)
75- push! (new_x. args, :(one ($ sym)))
76- else
77- unroll = Expr (:call , :* )
78- for x = one (num): num
79- push! (unroll. args, sym)
80- end
81- push! (new_x. args, unroll)
82- end
83-
84- x = new_x
85- end
86- end
87- x
88- end
89- _cupowliteral (x) = x
90-
91- function replace_device (ex)
92- global _cufuncs
93- MacroTools. postwalk (ex) do x
94- x = x in _cufuncs ? :(CUDA. cufunc ($ x)) : x
95- x = _cuint (x)
96- x = _cupowliteral (x)
97- x
98- end
99- end
100-
101- macro cufunc (ex)
102- global _cufuncs
103- def = MacroTools. splitdef (ex)
104- f = def[:name ]
105- def[:name ] = Symbol (:cu , f)
106- def[:body ] = replace_device (def[:body ])
107- push! (_cufuncs, f)
108- quote
109- $ (esc (MacroTools. combinedef (def)))
110- CUDA. cufunc (:: typeof ($ (esc (f)))) = $ (esc (def[:name ]))
111- end
112- end
17+ # broadcasting type ctors isn't GPU compatible
18+ Broadcast. broadcasted (:: CuArrayStyle{N} , f:: Type{T} , args... ) where {N, T} =
19+ Broadcasted {CuArrayStyle{N}} ((x... ) -> T (x... ), args, nothing )
0 commit comments