@@ -25,32 +25,34 @@ Base.:(*)(p::ScaledPlan, x::DenseCuArray) = rmul!(p.p * x, p.scale)
2525
2626# N is the number of dimensions
2727
"""
    CuFFTPlan{T,S,K,inplace,N,R,B} <: Plan{S}

Julia-side wrapper around a low-level CUFFT plan handle.

Type parameters:
- `T`, `S`: `cufftNumber` element types; `S` is also the parameter of the `Plan{S}`
  supertype, and `T` parameterizes the cached `pinv::ScaledPlan{T}`.
- `K`: transform direction; the constructor requires `abs(K) == 1`.
- `inplace`: must be a `Bool`.
- `N`: length of the `input_size` / `output_size` tuples.
- `R`: length of the `region` tuple.
- `B`: type of the `buffer` field (`Nothing` when no buffer is attached).

The constructor captures the current `context()` and `stream()` and registers
`unsafe_free!` as a finalizer on the new plan.
"""
mutable struct CuFFTPlan{T<:cufftNumber,S<:cufftNumber,K,inplace,N,R,B} <: Plan{S}
    # Handle to the CUDA low-level plan. Note that this plan sometimes has lower
    # dimensions to handle more transform cases such as individual directions.
    handle::cufftHandle
    ctx::CuContext
    stream::CuStream
    input_size::NTuple{N,Int}   # Julia size of input array
    output_size::NTuple{N,Int}  # Julia size of output array
    region::NTuple{R,Int}
    # Buffer for out-of-place complex-to-real FFT, or `nothing` if not needed.
    buffer::B
    # Required by the AbstractFFTs API; will be defined by AbstractFFTs if needed.
    pinv::ScaledPlan{T}

    function CuFFTPlan{T,S,K,inplace,N,R,B}(
        handle::cufftHandle,
        input_size::NTuple{N,Int},
        output_size::NTuple{N,Int},
        region::NTuple{R,Int},
        buffer::B,
    ) where {T<:cufftNumber,S<:cufftNumber,K,inplace,N,R,B}
        if abs(K) != 1
            throw(ArgumentError(" FFT direction must be either -1 (forward) or +1 (inverse)"))
        end
        if !(inplace isa Bool)
            throw(ArgumentError(" FFT inplace argument must be a Bool"))
        end
        # `pinv` is intentionally left uninitialized by this incomplete `new` call.
        plan = new{T,S,K,inplace,N,R,B}(
            handle, context(), stream(), input_size, output_size, region, buffer
        )
        finalizer(unsafe_free!, plan)
        return plan
    end
end
4951
# Convenience constructor: accepts the input array `X` itself and forwards `size(X)`
# as the plan's input size; `sizey` becomes the output size.
function CuFFTPlan{T,S,K,inplace,N,R,B}(
    handle::cufftHandle,
    X::DenseCuArray{S,N},
    sizey::NTuple{N,Int},
    region::NTuple{R,Int},
    buffer::B,
) where {T<:cufftNumber,S<:cufftNumber,K,inplace,N,R,B}
    sizex = size(X)
    return CuFFTPlan{T,S,K,inplace,N,R,B}(handle, sizex, sizey, region, buffer)
end
5557
5658function CUDA. unsafe_free! (plan:: CuFFTPlan )
@@ -60,6 +62,9 @@ function CUDA.unsafe_free!(plan::CuFFTPlan)
6062 end
6163 plan. handle = C_NULL
6264 end
65+ if ! isnothing (plan. buffer)
66+ CUDA. unsafe_free! (plan. buffer)
67+ end
6368end
6469
6570function showfftdims (io, sz, T)
@@ -151,103 +156,116 @@ end
# In-place forward complex-to-complex plan for `X` over the dimensions in `region`.
function plan_fft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
    # Normalize `region` to an `NTuple{R,Int}`; `R` becomes a plan type parameter.
    R = length(region)
    region = NTuple{R,Int}(region)

    # Only the first `nlead` entries of `size(X)` are handed to `cufftGetPlan`.
    dims = size(X)
    nlead = plan_max_dims(region, dims)
    handle = cufftGetPlan(T, T, dims[1:nlead], region)

    # No auxiliary buffer is attached (`B == Nothing`).
    return CuFFTPlan{T,T,CUFFT_FORWARD,true,N,R,Nothing}(handle, X, dims, region, nothing)
end
162168
163-
# In-place inverse-direction complex-to-complex plan for `X` over the dimensions in
# `region`.
function plan_bfft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
    # Normalize `region` to an `NTuple{R,Int}`; `R` becomes a plan type parameter.
    R = length(region)
    region = NTuple{R,Int}(region)

    # Only the first `nlead` entries of `size(X)` are handed to `cufftGetPlan`.
    dims = size(X)
    nlead = plan_max_dims(region, dims)
    handle = cufftGetPlan(T, T, dims[1:nlead], region)

    # No auxiliary buffer is attached (`B == Nothing`).
    return CuFFTPlan{T,T,CUFFT_INVERSE,true,N,R,Nothing}(handle, X, dims, region, nothing)
end
175181
176182# out-of-place complex
# Out-of-place forward complex-to-complex plan for `X` over the dimensions in `region`.
function plan_fft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
    # Normalize `region` to an `NTuple{R,Int}`; `R` becomes a plan type parameter.
    R = length(region)
    region = NTuple{R,Int}(region)

    # Only the first `nlead` entries of `size(X)` are handed to `cufftGetPlan`.
    dims = size(X)
    nlead = plan_max_dims(region, dims)
    handle = cufftGetPlan(T, T, dims[1:nlead], region)

    # Input and output sizes coincide; no auxiliary buffer is attached.
    return CuFFTPlan{T,T,CUFFT_FORWARD,false,N,R,Nothing}(handle, X, dims, region, nothing)
end
188195
# Out-of-place inverse-direction complex-to-complex plan for `X` over the dimensions
# in `region`.
function plan_bfft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
    # Normalize `region` to an `NTuple{R,Int}`; `R` becomes a plan type parameter.
    R = length(region)
    region = NTuple{R,Int}(region)

    # Only the first `nlead` entries of `size(X)` are handed to `cufftGetPlan`.
    dims = size(X)
    nlead = plan_max_dims(region, dims)
    handle = cufftGetPlan(T, T, dims[1:nlead], region)

    # Input and output sizes coincide; no auxiliary buffer is attached.
    return CuFFTPlan{T,T,CUFFT_INVERSE,false,N,R,Nothing}(
        handle, dims, dims, region, nothing
    )
end
200208
201209# out-of-place real-to-complex
# Out-of-place real-to-complex plan for `X` over the dimensions in `region`.
#
# The output has the same size as `X` except along `region[1]`, where a length of `n`
# becomes `div(n, 2) + 1`. A `BoundsError` is thrown if `region` is empty or
# `region[1]` is not a valid dimension of `X`.
function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N}
    K = CUFFT_FORWARD
    inplace = false
    R = length(region)
    region = NTuple{R,Int}(region)

    xdims = size(X)
    md = plan_max_dims(region, xdims)
    sizex = xdims[1:md]

    handle = cufftGetPlan(complex(T), T, sizex, region)

    # Build the output size as a tuple (not a `Vector` that is splatted later), so
    # its length `N` stays known to the compiler and `typeof(buffer)` is inferable.
    halved = region[1]
    ydims = Base.setindex(xdims, div(xdims[halved], 2) + 1, halved)

    # The buffer is not needed for real-to-complex (`mul!`),
    # but it is required for complex-to-real (`ldiv!`).
    buffer = CuArray{complex(T)}(undef, ydims)
    B = typeof(buffer)

    return CuFFTPlan{complex(T),T,K,inplace,N,R,B}(handle, xdims, ydims, region, buffer)
end
218231
# Out-of-place complex-to-real plan for `X` over the dimensions in `region`.
#
# `d` is the length of the real output along `region[1]`; every other output
# dimension equals the corresponding dimension of `X`. A `BoundsError` is thrown if
# `region` is empty or `region[1]` is not a valid dimension of `X`.
function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region) where {T<:cufftComplexes,N}
    K = CUFFT_INVERSE
    inplace = false
    R = length(region)
    region = NTuple{R,Int}(region)

    # Build the output size as a tuple (not a `Vector` that is splatted later), so
    # its length `N` stays known to the compiler. `Int(d)` keeps it an `NTuple{N,Int}`.
    xdims = size(X)
    ydims = Base.setindex(xdims, Int(d), region[1])

    handle = cufftGetPlan(real(T), T, ydims, region)

    # Scratch space with the shape and element type of the complex input; `mul!` and
    # `*` copy the input here because the complex-to-real transform overwrites it.
    buffer = CuArray{T}(undef, xdims)
    B = typeof(buffer)

    return CuFFTPlan{real(T),T,K,inplace,N,R,B}(handle, xdims, ydims, region, buffer)
end
231249
232250
233251# FIXME : plan_inv methods allocate needlessly (to provide type parameters)
234252# Perhaps use FakeArray types to avoid this.
235253
# Build the scaled forward plan that inverts the inverse-direction plan `p`.
#
# The new plan swaps `p`'s input and output sizes and is wrapped in a `ScaledPlan`
# carrying the normalization factor computed from `p.output_size` and `p.region`.
function plan_inv(p::CuFFTPlan{T,S,CUFFT_INVERSE,inplace,N,R,B}
                  ) where {T<:cufftNumber,S<:cufftNumber,inplace,N,R,B}
    md_osz = plan_max_dims(p.region, p.output_size)
    sz_X = p.output_size[1:md_osz]
    handle = cufftGetPlan(S, T, sz_X, p.region)

    # Give the new plan its own buffer instead of sharing `p.buffer`: every plan
    # frees its buffer when finalized, so sharing would leave whichever plan
    # outlives the other with a freed buffer.
    buffer = isnothing(p.buffer) ? nothing : similar(p.buffer)

    fwd = CuFFTPlan{S,T,CUFFT_FORWARD,inplace,N,R,B}(
        handle, p.output_size, p.input_size, p.region, buffer
    )
    return ScaledPlan(fwd, normalization(real(T), p.output_size, p.region))
end
244262
# Build the scaled inverse-direction plan that inverts the forward plan `p`.
#
# The new plan swaps `p`'s input and output sizes and is wrapped in a `ScaledPlan`
# carrying the normalization factor computed from `p.input_size` and `p.region`.
function plan_inv(p::CuFFTPlan{T,S,CUFFT_FORWARD,inplace,N,R,B}
                  ) where {T<:cufftNumber,S<:cufftNumber,inplace,N,R,B}
    md_isz = plan_max_dims(p.region, p.input_size)
    sz_Y = p.input_size[1:md_isz]
    handle = cufftGetPlan(S, T, sz_Y, p.region)

    # Give the new plan its own buffer instead of sharing `p.buffer`: every plan
    # frees its buffer when finalized, so sharing would leave whichever plan
    # outlives the other with a freed buffer.
    buffer = isnothing(p.buffer) ? nothing : similar(p.buffer)

    bwd = CuFFTPlan{S,T,CUFFT_INVERSE,inplace,N,R,B}(
        handle, p.output_size, p.input_size, p.region, buffer
    )
    return ScaledPlan(bwd, normalization(real(S), p.input_size, p.region))
end
253271
@@ -309,10 +327,14 @@ function LinearAlgebra.mul!(y::DenseCuArray{T}, p::CuFFTPlan{T,S,K,inplace}, x::
309327 ) where {T,S,K,inplace}
310328 assert_applicable (p, x, y)
311329 if ! inplace && T<: Real
312- # Out-of-place complex-to-real FFT will always overwrite input buffer.
313- x = copy (x)
330+ # Out-of-place complex-to-real FFT will always overwrite input x.
331+ # We copy the input x in an auxiliary buffer.
332+ z = p. buffer
333+ copyto! (z, x)
334+ else
335+ z = x
314336 end
315- unsafe_execute_trailing! (p, x , y)
337+ unsafe_execute_trailing! (p, z , y)
316338 y
317339end
318340
@@ -323,13 +345,21 @@ function Base.:(*)(p::CuFFTPlan{T,S,K,true}, x::DenseCuArray{S}) where {T,S,K}
323345end
324346
# Apply the out-of-place plan `p` to `x` and return a newly allocated result of size
# `p.output_size`.
#
# Throws an `ArgumentError` when a complex-to-real plan is applied to an `x` whose
# size differs from `p.input_size`.
function Base.:(*)(p::CuFFTPlan{T,S,K,false}, x::DenseCuArray{S1,M}) where {T,S,K,S1,M}
    if T <: Real
        # Out-of-place complex-to-real FFT will always overwrite its input, so the
        # transform runs on a copy of `x` held in the plan's auxiliary buffer.
        #
        # The size of `x` has to be checked before the copy: the buffer itself always
        # has the plan's input size, so checking only the buffer afterwards would let
        # a wrong-size `x` through, partially filled or failing inside `copyto!`.
        size(x) == p.input_size || throw(ArgumentError(
            "CuFFT plan expects input of size $(p.input_size), got $(size(x))"))
        z = p.buffer
        copyto!(z, x)
    elseif S1 != S
        # Convert to the expected input type.
        z = copy1(S, x)
    else
        z = x
    end
    assert_applicable(p, z)
    y = CuArray{T,M}(undef, p.output_size)
    unsafe_execute_trailing!(p, z, y)
    return y
end
0 commit comments