|
74 | 74 | import warnings |
75 | 75 |
|
76 | 76 | import numpy as np |
77 | | -from numpy import array, asanyarray, conjugate, prod, sqrt, take |
| 77 | +from numpy import array, conjugate, prod, sqrt, take |
78 | 78 |
|
79 | 79 | from . import _float_utils |
80 | 80 | from . import _pydfti as mkl_fft # pylint: disable=no-name-in-module |
81 | 81 |
|
82 | 82 |
|
| 83 | +def _compute_fwd_scale(norm, n, shape): |
| 84 | + _check_norm(norm) |
| 85 | + if norm in (None, "backward"): |
| 86 | + return 1.0 |
| 87 | + |
| 88 | + ss = n if n is not None else shape |
| 89 | + nn = prod(ss) |
| 90 | + fsc = 1 / nn if nn != 0 else 1 |
| 91 | + if norm == "forward": |
| 92 | + return fsc |
| 93 | + else: # norm == "ortho" |
| 94 | + return sqrt(fsc) |
| 95 | + |
| 96 | + |
83 | 97 | def _check_norm(norm): |
84 | 98 | if norm not in (None, "ortho", "forward", "backward"): |
85 | 99 | raise ValueError( |
86 | | - ( |
87 | | - "Invalid norm value {} should be None, " |
88 | | - '"ortho", "forward", or "backward".' |
89 | | - ).format(norm) |
| 100 | + f"Invalid norm value {norm} should be None, 'ortho', 'forward', " |
| 101 | + "or 'backward'." |
90 | 102 | ) |
91 | 103 |
|
92 | 104 |
|
93 | | -def frwd_sc_1d(n, s): |
94 | | - nn = n if n is not None else s |
95 | | - return 1 / nn if nn != 0 else 1 |
96 | | - |
97 | | - |
98 | | -def frwd_sc_nd(s, x_shape): |
99 | | - ss = s if s is not None else x_shape |
100 | | - nn = prod(ss) |
101 | | - return 1 / nn if nn != 0 else 1 |
102 | | - |
| 105 | +def _swap_direction(norm): |
| 106 | + _check_norm(norm) |
| 107 | + _swap_direction_map = { |
| 108 | + "backward": "forward", |
| 109 | + None: "forward", |
| 110 | + "ortho": "ortho", |
| 111 | + "forward": "backward", |
| 112 | + } |
103 | 113 |
|
104 | | -def ortho_sc_1d(n, s): |
105 | | - return sqrt(frwd_sc_1d(n, s)) |
| 114 | + return _swap_direction_map[norm] |
106 | 115 |
|
107 | 116 |
|
108 | 117 | def trycall(func, args, kwrds): |
@@ -208,15 +217,9 @@ def fft(a, n=None, axis=-1, norm=None): |
208 | 217 | the `numpy.fft` documentation. |
209 | 218 |
|
210 | 219 | """ |
211 | | - _check_norm(norm) |
212 | | - x = _float_utils.__downcast_float128_array(a) |
213 | 220 |
|
214 | | - if norm in (None, "backward"): |
215 | | - fsc = 1.0 |
216 | | - elif norm == "forward": |
217 | | - fsc = frwd_sc_1d(n, x.shape[axis]) |
218 | | - else: |
219 | | - fsc = ortho_sc_1d(n, x.shape[axis]) |
| 221 | + x = _float_utils.__downcast_float128_array(a) |
| 222 | + fsc = _compute_fwd_scale(norm, n, x.shape[axis]) |
220 | 223 |
|
221 | 224 | return trycall(mkl_fft.fft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc}) |
222 | 225 |
|
@@ -307,15 +310,9 @@ def ifft(a, n=None, axis=-1, norm=None): |
307 | 310 | >>> plt.show() |
308 | 311 |
|
309 | 312 | """ |
310 | | - _check_norm(norm) |
311 | | - x = _float_utils.__downcast_float128_array(a) |
312 | 313 |
|
313 | | - if norm in (None, "backward"): |
314 | | - fsc = 1.0 |
315 | | - elif norm == "forward": |
316 | | - fsc = frwd_sc_1d(n, x.shape[axis]) |
317 | | - else: |
318 | | - fsc = ortho_sc_1d(n, x.shape[axis]) |
| 314 | + x = _float_utils.__downcast_float128_array(a) |
| 315 | + fsc = _compute_fwd_scale(norm, n, x.shape[axis]) |
319 | 316 |
|
320 | 317 | return trycall(mkl_fft.ifft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc}) |
321 | 318 |
|
@@ -404,15 +401,9 @@ def rfft(a, n=None, axis=-1, norm=None): |
404 | 401 | exploited to compute only the non-negative frequency terms. |
405 | 402 |
|
406 | 403 | """ |
407 | | - _check_norm(norm) |
408 | | - x = _float_utils.__downcast_float128_array(a) |
409 | 404 |
|
410 | | - if norm in (None, "backward"): |
411 | | - fsc = 1.0 |
412 | | - elif norm == "forward": |
413 | | - fsc = frwd_sc_1d(n, x.shape[axis]) |
414 | | - else: |
415 | | - fsc = ortho_sc_1d(n, x.shape[axis]) |
| 405 | + x = _float_utils.__downcast_float128_array(a) |
| 406 | + fsc = _compute_fwd_scale(norm, n, x.shape[axis]) |
416 | 407 |
|
417 | 408 | return trycall(mkl_fft.rfft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc}) |
418 | 409 |
|
@@ -503,16 +494,9 @@ def irfft(a, n=None, axis=-1, norm=None): |
503 | 494 | specified, and the output array is purely real. |
504 | 495 |
|
505 | 496 | """ |
506 | | - _check_norm(norm) |
507 | | - x = _float_utils.__downcast_float128_array(a) |
508 | 497 |
|
509 | | - nn = n if n else 2 * (x.shape[axis] - 1) |
510 | | - if norm in (None, "backward"): |
511 | | - fsc = 1.0 |
512 | | - elif norm == "forward": |
513 | | - fsc = frwd_sc_1d(nn, nn) |
514 | | - else: |
515 | | - fsc = ortho_sc_1d(nn, nn) |
| 498 | + x = _float_utils.__downcast_float128_array(a) |
| 499 | + fsc = _compute_fwd_scale(norm, n, 2 * (x.shape[axis] - 1)) |
516 | 500 |
|
517 | 501 | return trycall( |
518 | 502 | mkl_fft.irfft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc} |
@@ -595,18 +579,12 @@ def hfft(a, n=None, axis=-1, norm=None): |
595 | 579 | [ 2., -2.]]) |
596 | 580 |
|
597 | 581 | """ |
598 | | - _check_norm(norm) |
| 582 | + |
| 583 | + norm = _swap_direction(norm) |
599 | 584 | x = _float_utils.__downcast_float128_array(a) |
600 | 585 | x = array(x, copy=True, dtype=complex) |
601 | 586 | conjugate(x, out=x) |
602 | | - |
603 | | - nn = n if n else 2 * (x.shape[axis] - 1) |
604 | | - if norm in (None, "backward"): |
605 | | - fsc = frwd_sc_1d(nn, nn) |
606 | | - elif norm == "forward": |
607 | | - fsc = 1.0 |
608 | | - else: |
609 | | - fsc = ortho_sc_1d(nn, nn) |
| 587 | + fsc = _compute_fwd_scale(norm, n, 2 * (x.shape[axis] - 1)) |
610 | 588 |
|
611 | 589 | return trycall( |
612 | 590 | mkl_fft.irfft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc} |
@@ -670,17 +648,12 @@ def ihfft(a, n=None, axis=-1, norm=None): |
670 | 648 | array([ 1.-0.j, 2.-0.j, 3.-0.j, 4.-0.j]) |
671 | 649 |
|
672 | 650 | """ |
| 651 | + |
673 | 652 | # The copy may be required for multithreading. |
674 | | - _check_norm(norm) |
| 653 | + norm = _swap_direction(norm) |
675 | 654 | x = _float_utils.__downcast_float128_array(a) |
676 | 655 | x = array(x, copy=True, dtype=float) |
677 | | - |
678 | | - if norm in (None, "backward"): |
679 | | - fsc = frwd_sc_1d(n, x.shape[axis]) |
680 | | - elif norm == "forward": |
681 | | - fsc = 1.0 |
682 | | - else: |
683 | | - fsc = ortho_sc_1d(n, x.shape[axis]) |
| 656 | + fsc = _compute_fwd_scale(norm, n, x.shape[axis]) |
684 | 657 |
|
685 | 658 | output = trycall( |
686 | 659 | mkl_fft.rfft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc} |
@@ -832,16 +805,10 @@ def fftn(a, s=None, axes=None, norm=None): |
832 | 805 | >>> plt.show() |
833 | 806 |
|
834 | 807 | """ |
835 | | - _check_norm(norm) |
| 808 | + |
836 | 809 | x = _float_utils.__downcast_float128_array(a) |
837 | 810 | s, axes = _cook_nd_args(x, s, axes) |
838 | | - |
839 | | - if norm in (None, "backward"): |
840 | | - fsc = 1.0 |
841 | | - elif norm == "forward": |
842 | | - fsc = frwd_sc_nd(s, x.shape) |
843 | | - else: |
844 | | - fsc = sqrt(frwd_sc_nd(s, x.shape)) |
| 811 | + fsc = _compute_fwd_scale(norm, s, x.shape) |
845 | 812 |
|
846 | 813 | return trycall(mkl_fft.fftn, (x,), {"s": s, "axes": axes, "fwd_scale": fsc}) |
847 | 814 |
|
@@ -945,16 +912,10 @@ def ifftn(a, s=None, axes=None, norm=None): |
945 | 912 | >>> plt.show() |
946 | 913 |
|
947 | 914 | """ |
948 | | - _check_norm(norm) |
| 915 | + |
949 | 916 | x = _float_utils.__downcast_float128_array(a) |
950 | 917 | s, axes = _cook_nd_args(x, s, axes) |
951 | | - |
952 | | - if norm in (None, "backward"): |
953 | | - fsc = 1.0 |
954 | | - elif norm == "forward": |
955 | | - fsc = frwd_sc_nd(s, x.shape) |
956 | | - else: |
957 | | - fsc = sqrt(frwd_sc_nd(s, x.shape)) |
| 918 | + fsc = _compute_fwd_scale(norm, s, x.shape) |
958 | 919 |
|
959 | 920 | return trycall( |
960 | 921 | mkl_fft.ifftn, (x,), {"s": s, "axes": axes, "fwd_scale": fsc} |
@@ -1053,9 +1014,8 @@ def fft2(a, s=None, axes=(-2, -1), norm=None): |
1053 | 1014 | 0.0 +0.j , 0.0 +0.j ]]) |
1054 | 1015 |
|
1055 | 1016 | """ |
1056 | | - _check_norm(norm) |
1057 | | - x = _float_utils.__downcast_float128_array(a) |
1058 | | - return fftn(x, s=s, axes=axes, norm=norm) |
| 1017 | + |
| 1018 | + return fftn(a, s=s, axes=axes, norm=norm) |
1059 | 1019 |
|
1060 | 1020 |
|
1061 | 1021 | def ifft2(a, s=None, axes=(-2, -1), norm=None): |
@@ -1147,9 +1107,8 @@ def ifft2(a, s=None, axes=(-2, -1), norm=None): |
1147 | 1107 | [ 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j]]) |
1148 | 1108 |
|
1149 | 1109 | """ |
1150 | | - _check_norm(norm) |
1151 | | - x = _float_utils.__downcast_float128_array(a) |
1152 | | - return ifftn(x, s=s, axes=axes, norm=norm) |
| 1110 | + |
| 1111 | + return ifftn(a, s=s, axes=axes, norm=norm) |
1153 | 1112 |
|
1154 | 1113 |
|
1155 | 1114 | def rfftn(a, s=None, axes=None, norm=None): |
@@ -1241,18 +1200,10 @@ def rfftn(a, s=None, axes=None, norm=None): |
1241 | 1200 | [ 0.+0.j, 0.+0.j]]]) |
1242 | 1201 |
|
1243 | 1202 | """ |
1244 | | - _check_norm(norm) |
| 1203 | + |
1245 | 1204 | x = _float_utils.__downcast_float128_array(a) |
1246 | 1205 | s, axes = _cook_nd_args(x, s, axes) |
1247 | | - |
1248 | | - if norm in (None, "backward"): |
1249 | | - fsc = 1.0 |
1250 | | - elif norm == "forward": |
1251 | | - x = asanyarray(x) |
1252 | | - fsc = frwd_sc_nd(s, x.shape) |
1253 | | - else: |
1254 | | - x = asanyarray(x) |
1255 | | - fsc = sqrt(frwd_sc_nd(s, x.shape)) |
| 1206 | + fsc = _compute_fwd_scale(norm, s, x.shape) |
1256 | 1207 |
|
1257 | 1208 | return trycall( |
1258 | 1209 | mkl_fft.rfftn, (x,), {"s": s, "axes": axes, "fwd_scale": fsc} |
@@ -1298,9 +1249,8 @@ def rfft2(a, s=None, axes=(-2, -1), norm=None): |
1298 | 1249 | For more details see `rfftn`. |
1299 | 1250 |
|
1300 | 1251 | """ |
1301 | | - _check_norm(norm) |
1302 | | - x = _float_utils.__downcast_float128_array(a) |
1303 | | - return rfftn(x, s, axes, norm) |
| 1252 | + |
| 1253 | + return rfftn(a, s, axes, norm) |
1304 | 1254 |
|
1305 | 1255 |
|
1306 | 1256 | def irfftn(a, s=None, axes=None, norm=None): |
@@ -1394,18 +1344,10 @@ def irfftn(a, s=None, axes=None, norm=None): |
1394 | 1344 | [ 1., 1.]]]) |
1395 | 1345 |
|
1396 | 1346 | """ |
1397 | | - _check_norm(norm) |
| 1347 | + |
1398 | 1348 | x = _float_utils.__downcast_float128_array(a) |
1399 | 1349 | s, axes = _cook_nd_args(x, s, axes, invreal=True) |
1400 | | - |
1401 | | - if norm in (None, "backward"): |
1402 | | - fsc = 1.0 |
1403 | | - elif norm == "forward": |
1404 | | - x = asanyarray(x) |
1405 | | - fsc = frwd_sc_nd(s, x.shape) |
1406 | | - else: |
1407 | | - x = asanyarray(x) |
1408 | | - fsc = sqrt(frwd_sc_nd(s, x.shape)) |
| 1350 | + fsc = _compute_fwd_scale(norm, s, x.shape) |
1409 | 1351 |
|
1410 | 1352 | return trycall( |
1411 | 1353 | mkl_fft.irfftn, (x,), {"s": s, "axes": axes, "fwd_scale": fsc} |
@@ -1451,6 +1393,5 @@ def irfft2(a, s=None, axes=(-2, -1), norm=None): |
1451 | 1393 | For more details see `irfftn`. |
1452 | 1394 |
|
1453 | 1395 | """ |
1454 | | - _check_norm(norm) |
1455 | | - x = _float_utils.__downcast_float128_array(a) |
1456 | | - return irfftn(x, s, axes, norm) |
| 1396 | + |
| 1397 | + return irfftn(a, s, axes, norm) |
0 commit comments