|
74 | 74 | import warnings |
75 | 75 |
|
76 | 76 | import numpy as np |
77 | | -from numpy import array, conjugate, prod, sqrt, take |
78 | 77 |
|
79 | | -from . import _float_utils |
80 | 78 | from . import _pydfti as mkl_fft # pylint: disable=no-name-in-module |
| 79 | +from ._fft_utils import _check_norm, _compute_fwd_scale |
| 80 | +from ._float_utils import __downcast_float128_array |
81 | 81 |
|
82 | 82 |
|
83 | | -def _compute_fwd_scale(norm, n, shape): |
84 | | - _check_norm(norm) |
85 | | - if norm in (None, "backward"): |
86 | | - return 1.0 |
87 | | - |
88 | | - ss = n if n is not None else shape |
89 | | - nn = prod(ss) |
90 | | - fsc = 1 / nn if nn != 0 else 1 |
91 | | - if norm == "forward": |
92 | | - return fsc |
93 | | - else: # norm == "ortho" |
94 | | - return sqrt(fsc) |
95 | | - |
96 | | - |
97 | | -def _check_norm(norm): |
98 | | - if norm not in (None, "ortho", "forward", "backward"): |
99 | | - raise ValueError( |
100 | | - f"Invalid norm value {norm} should be None, 'ortho', 'forward', " |
101 | | - "or 'backward'." |
| 83 | +# copied with modifications from: |
| 84 | +# https://github.com/numpy/numpy/blob/main/numpy/fft/_pocketfft.py |
| 85 | +def _cook_nd_args(a, s=None, axes=None, invreal=False): |
| 86 | + if s is None: |
| 87 | + shapeless = True |
| 88 | + if axes is None: |
| 89 | + s = list(a.shape) |
| 90 | + else: |
| 91 | + s = np.take(a.shape, axes) |
| 92 | + else: |
| 93 | + shapeless = False |
| 94 | + s = list(s) |
| 95 | + if axes is None: |
| 96 | + if not shapeless and np.__version__ >= "2.0": |
| 97 | + msg = ( |
| 98 | + "`axes` should not be `None` if `s` is not `None` " |
| 99 | + "(Deprecated in NumPy 2.0). In a future version of NumPy, " |
| 100 | + "this will raise an error and `s[i]` will correspond to " |
| 101 | + "the size along the transformed axis specified by " |
| 102 | + "`axes[i]`. To retain current behaviour, pass a sequence " |
| 103 | + "[0, ..., k-1] to `axes` for an array of dimension k." |
| 104 | + ) |
| 105 | + warnings.warn(msg, DeprecationWarning, stacklevel=3) |
| 106 | + axes = list(range(-len(s), 0)) |
| 107 | + if len(s) != len(axes): |
| 108 | + raise ValueError("Shape and axes have different lengths.") |
| 109 | + if invreal and shapeless: |
| 110 | + s[-1] = (a.shape[axes[-1]] - 1) * 2 |
| 111 | + if None in s and np.__version__ >= "2.0": |
| 112 | + msg = ( |
| 113 | + "Passing an array containing `None` values to `s` is " |
| 114 | + "deprecated in NumPy 2.0 and will raise an error in " |
| 115 | + "a future version of NumPy. To use the default behaviour " |
| 116 | + "of the corresponding 1-D transform, pass the value matching " |
| 117 | + "the default for its `n` parameter. To use the default " |
| 118 | + "behaviour for every axis, the `s` argument can be omitted." |
102 | 119 | ) |
| 120 | + warnings.warn(msg, DeprecationWarning, stacklevel=3) |
| 121 | + # use the whole input array along axis `i` if `s[i] == -1 or None` |
| 122 | + s = [a.shape[_a] if _s in [-1, None] else _s for _s, _a in zip(s, axes)] |
| 123 | + |
| 124 | + return s, axes |
103 | 125 |
|
104 | 126 |
|
105 | 127 | def _swap_direction(norm): |
@@ -218,7 +240,7 @@ def fft(a, n=None, axis=-1, norm=None): |
218 | 240 |
|
219 | 241 | """ |
220 | 242 |
|
221 | | - x = _float_utils.__downcast_float128_array(a) |
| 243 | + x = __downcast_float128_array(a) |
222 | 244 | fsc = _compute_fwd_scale(norm, n, x.shape[axis]) |
223 | 245 |
|
224 | 246 | return trycall(mkl_fft.fft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc}) |
@@ -311,7 +333,7 @@ def ifft(a, n=None, axis=-1, norm=None): |
311 | 333 |
|
312 | 334 | """ |
313 | 335 |
|
314 | | - x = _float_utils.__downcast_float128_array(a) |
| 336 | + x = __downcast_float128_array(a) |
315 | 337 | fsc = _compute_fwd_scale(norm, n, x.shape[axis]) |
316 | 338 |
|
317 | 339 | return trycall(mkl_fft.ifft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc}) |
@@ -402,7 +424,7 @@ def rfft(a, n=None, axis=-1, norm=None): |
402 | 424 |
|
403 | 425 | """ |
404 | 426 |
|
405 | | - x = _float_utils.__downcast_float128_array(a) |
| 427 | + x = __downcast_float128_array(a) |
406 | 428 | fsc = _compute_fwd_scale(norm, n, x.shape[axis]) |
407 | 429 |
|
408 | 430 | return trycall(mkl_fft.rfft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc}) |
@@ -495,7 +517,7 @@ def irfft(a, n=None, axis=-1, norm=None): |
495 | 517 |
|
496 | 518 | """ |
497 | 519 |
|
498 | | - x = _float_utils.__downcast_float128_array(a) |
| 520 | + x = __downcast_float128_array(a) |
499 | 521 | fsc = _compute_fwd_scale(norm, n, 2 * (x.shape[axis] - 1)) |
500 | 522 |
|
501 | 523 | return trycall( |
@@ -581,9 +603,9 @@ def hfft(a, n=None, axis=-1, norm=None): |
581 | 603 | """ |
582 | 604 |
|
583 | 605 | norm = _swap_direction(norm) |
584 | | - x = _float_utils.__downcast_float128_array(a) |
585 | | - x = array(x, copy=True, dtype=complex) |
586 | | - conjugate(x, out=x) |
| 606 | + x = __downcast_float128_array(a) |
| 607 | + x = np.array(x, copy=True, dtype=complex) |
| 608 | + np.conjugate(x, out=x) |
587 | 609 | fsc = _compute_fwd_scale(norm, n, 2 * (x.shape[axis] - 1)) |
588 | 610 |
|
589 | 611 | return trycall( |
@@ -651,61 +673,18 @@ def ihfft(a, n=None, axis=-1, norm=None): |
651 | 673 |
|
652 | 674 | # The copy may be required for multithreading. |
653 | 675 | norm = _swap_direction(norm) |
654 | | - x = _float_utils.__downcast_float128_array(a) |
655 | | - x = array(x, copy=True, dtype=float) |
| 676 | + x = __downcast_float128_array(a) |
| 677 | + x = np.array(x, copy=True, dtype=float) |
656 | 678 | fsc = _compute_fwd_scale(norm, n, x.shape[axis]) |
657 | 679 |
|
658 | 680 | output = trycall( |
659 | 681 | mkl_fft.rfft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc} |
660 | 682 | ) |
661 | 683 |
|
662 | | - conjugate(output, out=output) |
| 684 | + np.conjugate(output, out=output) |
663 | 685 | return output |
664 | 686 |
|
665 | 687 |
|
666 | | -# copied from: https://github.com/numpy/numpy/blob/main/numpy/fft/_pocketfft.py |
667 | | -def _cook_nd_args(a, s=None, axes=None, invreal=False): |
668 | | - if s is None: |
669 | | - shapeless = True |
670 | | - if axes is None: |
671 | | - s = list(a.shape) |
672 | | - else: |
673 | | - s = take(a.shape, axes) |
674 | | - else: |
675 | | - shapeless = False |
676 | | - s = list(s) |
677 | | - if axes is None: |
678 | | - if not shapeless and np.__version__ >= "2.0": |
679 | | - msg = ( |
680 | | - "`axes` should not be `None` if `s` is not `None` " |
681 | | - "(Deprecated in NumPy 2.0). In a future version of NumPy, " |
682 | | - "this will raise an error and `s[i]` will correspond to " |
683 | | - "the size along the transformed axis specified by " |
684 | | - "`axes[i]`. To retain current behaviour, pass a sequence " |
685 | | - "[0, ..., k-1] to `axes` for an array of dimension k." |
686 | | - ) |
687 | | - warnings.warn(msg, DeprecationWarning, stacklevel=3) |
688 | | - axes = list(range(-len(s), 0)) |
689 | | - if len(s) != len(axes): |
690 | | - raise ValueError("Shape and axes have different lengths.") |
691 | | - if invreal and shapeless: |
692 | | - s[-1] = (a.shape[axes[-1]] - 1) * 2 |
693 | | - if None in s and np.__version__ >= "2.0": |
694 | | - msg = ( |
695 | | - "Passing an array containing `None` values to `s` is " |
696 | | - "deprecated in NumPy 2.0 and will raise an error in " |
697 | | - "a future version of NumPy. To use the default behaviour " |
698 | | - "of the corresponding 1-D transform, pass the value matching " |
699 | | - "the default for its `n` parameter. To use the default " |
700 | | - "behaviour for every axis, the `s` argument can be omitted." |
701 | | - ) |
702 | | - warnings.warn(msg, DeprecationWarning, stacklevel=3) |
703 | | - # use the whole input array along axis `i` if `s[i] == -1 or None` |
704 | | - s = [a.shape[_a] if _s in [-1, None] else _s for _s, _a in zip(s, axes)] |
705 | | - |
706 | | - return s, axes |
707 | | - |
708 | | - |
709 | 688 | def fftn(a, s=None, axes=None, norm=None): |
710 | 689 | """ |
711 | 690 | Compute the N-dimensional discrete Fourier Transform. |
@@ -806,7 +785,7 @@ def fftn(a, s=None, axes=None, norm=None): |
806 | 785 |
|
807 | 786 | """ |
808 | 787 |
|
809 | | - x = _float_utils.__downcast_float128_array(a) |
| 788 | + x = __downcast_float128_array(a) |
810 | 789 | s, axes = _cook_nd_args(x, s, axes) |
811 | 790 | fsc = _compute_fwd_scale(norm, s, x.shape) |
812 | 791 |
|
@@ -913,7 +892,7 @@ def ifftn(a, s=None, axes=None, norm=None): |
913 | 892 |
|
914 | 893 | """ |
915 | 894 |
|
916 | | - x = _float_utils.__downcast_float128_array(a) |
| 895 | + x = __downcast_float128_array(a) |
917 | 896 | s, axes = _cook_nd_args(x, s, axes) |
918 | 897 | fsc = _compute_fwd_scale(norm, s, x.shape) |
919 | 898 |
|
@@ -1201,7 +1180,7 @@ def rfftn(a, s=None, axes=None, norm=None): |
1201 | 1180 |
|
1202 | 1181 | """ |
1203 | 1182 |
|
1204 | | - x = _float_utils.__downcast_float128_array(a) |
| 1183 | + x = __downcast_float128_array(a) |
1205 | 1184 | s, axes = _cook_nd_args(x, s, axes) |
1206 | 1185 | fsc = _compute_fwd_scale(norm, s, x.shape) |
1207 | 1186 |
|
@@ -1345,7 +1324,7 @@ def irfftn(a, s=None, axes=None, norm=None): |
1345 | 1324 |
|
1346 | 1325 | """ |
1347 | 1326 |
|
1348 | | - x = _float_utils.__downcast_float128_array(a) |
| 1327 | + x = __downcast_float128_array(a) |
1349 | 1328 | s, axes = _cook_nd_args(x, s, axes, invreal=True) |
1350 | 1329 | fsc = _compute_fwd_scale(norm, s, x.shape) |
1351 | 1330 |
|
|
0 commit comments