@@ -52,8 +52,8 @@ function Random.rand!(rng::RNG, A::AnyCuArray)
5252
5353 # grid-stride loop
5454 threadId = threadIdx (). x
55- window = blockDim (). x * gridDim (). x
56- offset = (blockIdx (). x - 1 ) * blockDim (). x
55+ window = widemul ( blockDim (). x, gridDim (). x)
56+ offset = widemul (blockIdx (). x - 1 i32, blockDim (). x)
5757 while offset < length (A)
5858 i = threadId + offset
5959 if i <= length (A)
@@ -96,8 +96,8 @@ function Random.randn!(rng::RNG, A::AnyCuArray{<:Union{AbstractFloat,Complex{<:A
9696
9797 # grid-stride loop
9898 threadId = threadIdx (). x
99- window = (blockDim (). x - 1 ) * gridDim (). x
100- offset = (blockIdx (). x - 1 ) * blockDim (). x
99+ window = widemul (blockDim (). x, gridDim (). x)
100+ offset = widemul (blockIdx (). x - 1 i32, blockDim (). x)
101101 while offset < length (A)
102102 i = threadId + offset
103103 j = threadId + offset + window
@@ -129,8 +129,8 @@ function Random.randn!(rng::RNG, A::AnyCuArray{<:Union{AbstractFloat,Complex{<:A
129129
130130 # grid-stride loop
131131 threadId = threadIdx (). x
132- window = (blockDim (). x - 1 ) * gridDim (). x
133- offset = (blockIdx (). x - 1 ) * blockDim (). x
132+ window = widemul (blockDim (). x, gridDim (). x)
133+ offset = widemul (blockIdx (). x - 1 i32, blockDim (). x)
134134 while offset < length (A)
135135 i = threadId + offset
136136 if i <= length (A)
@@ -150,11 +150,11 @@ function Random.randn!(rng::RNG, A::AnyCuArray{<:Union{AbstractFloat,Complex{<:A
150150 return
151151 end
152152
153- kernel = @cuda launch = false name = " rand!" kernel (A, rng . seed, rng . counter)
154- config = launch_configuration (kernel . fun; max_threads = 64 )
155- threads = max ( 32 , min (config . threads, length (A)÷ 2 ) )
156- blocks = min (config . blocks, cld ( cld ( length (A), 2 ), threads))
157- kernel (A, rng. seed, rng. counter; threads, blocks )
153+ # see note in ` rand!` about the launch configuration
154+ threads = 32
155+ blocks = cld ( cld ( length (A), 2 ), threads )
156+
157+ @cuda threads = threads blocks = blocks name = " randn! " kernel (A, rng. seed, rng. counter)
158158
159159 new_counter = Int64 (rng. counter) + length (A)
160160 overflow, remainder = fldmod (new_counter, typemax (UInt32))
0 commit comments