@@ -262,23 +262,43 @@ def _iter_fftnd(
262262 axes = None ,
263263 out = None ,
264264 direction = + 1 ,
265- overwrite_x = False ,
266- scale_function = lambda n , ind : 1.0 ,
265+ scale_function = lambda ind : 1.0 ,
267266):
268267 a = np .asarray (a )
269268 s , axes = _init_nd_shape_and_axes (a , s , axes )
270- ovwr = overwrite_x
271- for ii in reversed (range (len (axes ))):
269+
270+ # Combine the two, but in reverse, to end with the first axis given.
271+ axes_and_s = list (zip (axes , s ))[::- 1 ]
272+ # We try to use in-place calculations where possible, which is
273+ # everywhere except when the size changes after the first FFT.
274+ size_changes = [axis for axis , n in axes_and_s [1 :] if a .shape [axis ] != n ]
275+
276+ # If there are any size changes, we cannot use out
277+ res = None if size_changes else out
278+ for ind , (axis , n ) in enumerate (axes_and_s ):
279+ if axis in size_changes :
280+ if axis == size_changes [- 1 ]:
281+ # Last size change, so any output should now be OK
282+ # (an error will be raised if not), and if no output is
283+ # required, we want a freshly allocated array of the right size.
284+ res = out
285+ elif res is not None and n < res .shape [axis ]:
286+ # For an intermediate step where we return fewer elements, we
287+ # can use a smaller view of the previous array.
288+ res = res [(slice (None ),) * axis + (slice (n ),)]
289+ else :
290+ # If we need more elements, we cannot use res.
291+ res = None
272292 a = _c2c_fft1d_impl (
273293 a ,
274- n = s [ii ],
275- axis = axes [ii ],
276- overwrite_x = ovwr ,
294+ n = n ,
295+ axis = axis ,
277296 direction = direction ,
278- fsc = scale_function (s [ ii ], ii ),
279- out = out ,
297+ fsc = scale_function (ind ),
298+ out = res ,
280299 )
281- ovwr = True
300+ # Default output for next iteration.
301+ res = a
282302 return a
283303
284304
@@ -360,7 +380,6 @@ def _c2c_fftnd_impl(
360380 x ,
361381 s = None ,
362382 axes = None ,
363- overwrite_x = False ,
364383 direction = + 1 ,
365384 fsc = 1.0 ,
366385 out = None ,
@@ -385,7 +404,6 @@ def _c2c_fftnd_impl(
385404 if _direct :
386405 return _direct_fftnd (
387406 x ,
388- overwrite_x = overwrite_x ,
389407 direction = direction ,
390408 fsc = fsc ,
391409 out = out ,
@@ -403,11 +421,7 @@ def _c2c_fftnd_impl(
403421 x ,
404422 axes ,
405423 _direct_fftnd ,
406- {
407- "overwrite_x" : overwrite_x ,
408- "direction" : direction ,
409- "fsc" : fsc ,
410- },
424+ {"direction" : direction , "fsc" : fsc },
411425 res ,
412426 )
413427 else :
@@ -418,97 +432,121 @@ def _c2c_fftnd_impl(
418432 axes = axes ,
419433 out = out ,
420434 direction = direction ,
421- overwrite_x = overwrite_x ,
422- scale_function = lambda n , i : fsc if i == 0 else 1.0 ,
435+ scale_function = lambda i : fsc if i == 0 else 1.0 ,
423436 )
424437
425438
426439def _r2c_fftnd_impl (x , s = None , axes = None , fsc = 1.0 , out = None ):
427440 a = np .asarray (x )
428441 no_trim = (s is None ) and (axes is None )
429442 s , axes = _cook_nd_args (a , s , axes )
443+ axes = [ax + a .ndim if ax < 0 else ax for ax in axes ]
430444 la = axes [- 1 ]
445+
431446 # trim array, so that rfft avoids doing unnecessary computations
432447 if not no_trim :
433448 a = _trim_array (a , s , axes )
449+
450+ # last axis is not included since we calculate FT sepaartely and it does not come in loop
451+ axes_and_s = list (zip (axes , s ))[- 2 ::- 1 ]
452+ size_changes = [axis for axis , n in axes_and_s if a .shape [axis ] != n ]
453+ res = None if size_changes else out
454+
434455 # r2c along last axis
435- a = _r2c_fft1d_impl (a , n = s [- 1 ], axis = la , fsc = fsc , out = out )
456+ a = _r2c_fft1d_impl (a , n = s [- 1 ], axis = la , fsc = fsc , out = res )
457+ res = a
436458 if len (s ) > 1 :
437- if not no_trim :
438- ss = list (s )
439- ss [- 1 ] = a .shape [la ]
440- a = _pad_array (a , tuple (ss ), axes )
459+
441460 len_axes = len (axes )
442461 if len (set (axes )) == len_axes and len_axes == a .ndim and len_axes > 2 :
462+ if not no_trim :
463+ ss = list (s )
464+ ss [- 1 ] = a .shape [la ]
465+ a = _pad_array (a , tuple (ss ), axes )
443466 # a series of ND c2c FFTs along last axis
444467 ss , aa = _remove_axis (s , axes , - 1 )
445- ind = [
446- slice (None , None , 1 ),
447- ] * len (s )
468+ ind = [slice (None , None , 1 )] * len (s )
448469 for ii in range (a .shape [la ]):
449470 ind [la ] = ii
450471 tind = tuple (ind )
451472 a_inp = a [tind ]
452- res = out [tind ] if out is not None else None
453- a_res = _c2c_fftnd_impl (
454- a_inp , s = ss , axes = aa , overwrite_x = True , direction = 1 , out = res
455- )
456- if a_res is not a_inp :
457- a [tind ] = a_res # copy in place
473+ res = out [tind ] if out is not None else a_inp
474+ _ = _c2c_fftnd_impl (a_inp , s = ss , axes = aa , direction = 1 , out = res )
475+ if out is not None :
476+ a = out
458477 else :
478+ # another size_changes check is needed if there are repeated axes
479+ # of last axis, since since FFT changes the shape along last axis
480+ size_changes = [
481+ axis for axis , n in axes_and_s if a .shape [axis ] != n
482+ ]
483+
459484 # a series of 1D c2c FFTs along all axes except last
460- for ii in range (len (axes ) - 2 , - 1 , - 1 ):
461- a = _c2c_fft1d_impl (a , s [ii ], axes [ii ], overwrite_x = True )
485+ for axis , n in axes_and_s :
486+ if axis in size_changes :
487+ if axis == size_changes [- 1 ]:
488+ res = out
489+ elif res is not None and n < res .shape [axis ]:
490+ res = res [(slice (None ),) * axis + (slice (n ),)]
491+ else :
492+ res = None
493+ a = _c2c_fft1d_impl (a , n , axis , out = res )
494+ res = a
462495 return a
463496
464497
465498def _c2r_fftnd_impl (x , s = None , axes = None , fsc = 1.0 , out = None ):
466499 a = np .asarray (x )
467500 no_trim = (s is None ) and (axes is None )
468501 s , axes = _cook_nd_args (a , s , axes , invreal = True )
502+ axes = [ax + a .ndim if ax < 0 else ax for ax in axes ]
469503 la = axes [- 1 ]
470504 if not no_trim :
471505 a = _trim_array (a , s , axes )
472506 if len (s ) > 1 :
473- if not no_trim :
474- a = _pad_array (a , s , axes )
475- ovr_x = True if _datacopied (a , x ) else False
476507 len_axes = len (axes )
477508 if len (set (axes )) == len_axes and len_axes == a .ndim and len_axes > 2 :
509+ if not no_trim :
510+ a = _pad_array (a , s , axes )
478511 # a series of ND c2c FFTs along last axis
479512 # due to need to write into a, we must copy
480- if not ovr_x :
481- a = a .copy ()
482- ovr_x = True
513+ a = a if _datacopied (a , x ) else a .copy ()
483514 if not np .issubdtype (a .dtype , np .complexfloating ):
484515 # complex output will be copied to input, copy is needed
485516 if a .dtype == np .float32 :
486517 a = a .astype (np .complex64 )
487518 else :
488519 a = a .astype (np .complex128 )
489- ovr_x = True
490520 ss , aa = _remove_axis (s , axes , - 1 )
491- ind = [
492- slice (None , None , 1 ),
493- ] * len (s )
521+ ind = [slice (None , None , 1 )] * len (s )
494522 for ii in range (a .shape [la ]):
495523 ind [la ] = ii
496524 tind = tuple (ind )
497525 a_inp = a [tind ]
498526 # out has real dtype and cannot be used in intermediate steps
499- a_res = _c2c_fftnd_impl (
500- a_inp , s = ss , axes = aa , overwrite_x = True , direction = - 1
527+ # ss and aa are reversed since np.irfftn uses forward order but
528+ # np.ifftn uses reverse order see numpy-gh-28950
529+ _ = _c2c_fftnd_impl (
530+ a_inp , s = ss [::- 1 ], axes = aa [::- 1 ], out = a_inp , direction = - 1
501531 )
502- if a_res is not a_inp :
503- a [tind ] = a_res # copy in place
504532 else :
505533 # a series of 1D c2c FFTs along all axes except last
506- for ii in range (len (axes ) - 1 ):
507- # out has real dtype and cannot be used in intermediate steps
508- a = _c2c_fft1d_impl (
509- a , s [ii ], axes [ii ], overwrite_x = ovr_x , direction = - 1
510- )
511- ovr_x = True
534+ # forward order, see numpy-gh-28950
535+ axes_and_s = list (zip (axes , s ))[:- 1 ]
536+ size_changes = [
537+ axis for axis , n in axes_and_s [1 :] if a .shape [axis ] != n
538+ ]
539+ # out has real dtype cannot be used for intermediate steps
540+ res = None
541+ for axis , n in axes_and_s :
542+ if axis in size_changes :
543+ if res is not None and n < res .shape [axis ]:
544+ # pylint: disable=unsubscriptable-object
545+ res = res [(slice (None ),) * axis + (slice (n ),)]
546+ else :
547+ res = None
548+ a = _c2c_fft1d_impl (a , n , axis , out = res , direction = - 1 )
549+ res = a
512550 # c2r along last axis
513551 a = _c2r_fft1d_impl (a , n = s [- 1 ], axis = la , fsc = fsc , out = out )
514552 return a
0 commit comments