@@ -29,13 +29,18 @@
29 | 29 | one_hot, |
30 | 30 | pad, |
31 | 31 | partition, |
| 32 | + searchsorted, |
32 | 33 | setdiff1d, |
33 | 34 | sinc, |
34 | 35 | ) |
35 | 36 | from array_api_extra._lib._backends import NUMPY_VERSION, Backend |
36 | 37 | from array_api_extra._lib._testing import xfail, xp_assert_close, xp_assert_equal |
| 38 | +from array_api_extra._lib._utils._compat import ( |
| 39 | + array_namespace, |
| 40 | + is_jax_namespace, |
| 41 | + is_torch_namespace, |
| 42 | +) |
37 | 43 | from array_api_extra._lib._utils._compat import device as get_device |
38 | | -from array_api_extra._lib._utils._compat import is_jax_namespace |
39 | 44 | from array_api_extra._lib._utils._helpers import eager_shape, ndindex |
40 | 45 | from array_api_extra._lib._utils._typing import Array, Device |
41 | 46 | from array_api_extra.testing import lazy_xp_function |
@@ -52,6 +57,7 @@
52 | 57 | lazy_xp_function(pad) |
53 | 58 | # FIXME calls in1d which calls xp.unique_values without size |
54 | 59 | lazy_xp_function(setdiff1d, jax_jit=False) |
| 60 | +lazy_xp_function(searchsorted) |
55 | 61 | lazy_xp_function(sinc) |
56 | 62 | |
57 | 63 | NestedFloatList = list[float] | list["NestedFloatList"] |
@@ -1637,3 +1643,157 @@ def test_kind(self, xp: ModuleType, library: Backend): |
1637 | 1643 | expected = xp.asarray([False, True, False, True]) |
1638 | 1644 | res = isin(a, b, kind="sort") |
1639 | 1645 | xp_assert_equal(res, expected) |
| 1646 | + |
| 1647 | + |
| 1648 | +def _apply_over_batch(*argdefs: tuple[str, int]): |
| 1649 | + """ |
| 1650 | + Factory for a decorator that applies a function over batched arguments. |
| 1651 | + |
| 1652 | + Array arguments may have any number of core dimensions (typically 0, |
| 1653 | + 1, or 2) and any broadcastable batch shapes. There may be any number |
| 1654 | + of array outputs, each with any number of dimensions. The current |
| 1655 | + assumptions, satisfied by all functions of interest in `linalg`, are that |
| 1656 | + all array inputs are consecutive positional or keyword arguments and that |
| 1657 | + the wrapped function returns either a single array or a tuple of arrays. |
| 1658 | + The decorator is only as general as it currently needs to be; it can be extended. |
| 1659 | + |
| 1660 | + Parameters |
| 1661 | + ---------- |
| 1662 | + *argdefs : tuple of (str, int) |
| 1663 | + Definitions of array arguments: the keyword name of the argument, and |
| 1664 | + the number of core dimensions. |
| 1665 | + |
| 1666 | + Examples |
| 1667 | + -------- |
| 1668 | + `linalg.eig` accepts two matrices as its first two arguments `a` and `b`, |
| 1669 | + where `b` is optional, and returns a single array or a tuple of arrays, |
| 1670 | + depending on the values of other positional or keyword arguments. To generate |
| 1671 | + a wrapper that applies the function over batches of `a` and, optionally, `b`: |
| 1672 | + |
| 1673 | + >>> _apply_over_batch(("a", 2), ("b", 2)) |
| 1674 | + """ |
| 1675 | + names, ndims = list(zip(*argdefs, strict=True)) |
| 1676 | + n_arrays = len(names) |
| 1677 | + |
| 1678 | + def decorator(f): |
| 1679 | + def wrapper(*args_tuple, **kwargs): |
| 1680 | + args = list(args_tuple) |
| 1681 | + |
| 1682 | + # Ensure all arrays in `arrays`, other arguments in `other_args`/`kwargs` |
| 1683 | + arrays, other_args = args[:n_arrays], args[n_arrays:] |
| 1684 | + for i, name in enumerate(names): |
| 1685 | + if name in kwargs: |
| 1686 | + if i < len(args): |
| 1687 | + message = ( |
| 1688 | + f"{f.__name__}() got multiple values for argument `{name}`." |
| 1689 | + ) |
| 1690 | + raise ValueError(message) |
| 1691 | + arrays.append(kwargs.pop(name)) |
| 1692 | + |
| 1693 | + xp = array_namespace(*arrays) |
| 1694 | + |
| 1695 | + # Determine core and batch shapes |
| 1696 | + batch_shapes = [] |
| 1697 | + core_shapes = [] |
| 1698 | + for i, (array, ndim) in enumerate(zip(arrays, ndims, strict=True)): |
| 1699 | + array = None if array is None else xp.asarray(array) # noqa: PLW2901 |
| 1700 | + shape = () if array is None else array.shape |
| 1701 | + arrays[i] = array |
| 1702 | + batch_shapes.append(shape[:-ndim] if ndim > 0 else shape) |
| 1703 | + core_shapes.append(shape[-ndim:] if ndim > 0 else ()) |
| 1704 | + |
| 1705 | + # Early exit if call is not batched |
| 1706 | + if not any(batch_shapes): |
| 1707 | + return f(*arrays, *other_args, **kwargs) |
| 1708 | + |
| 1709 | + # Determine broadcasted batch shape |
| 1710 | + batch_shape = np.broadcast_shapes(*batch_shapes) # Gives OK error message |
| 1711 | + |
| 1712 | + # Broadcast arrays to appropriate shape |
| 1713 | + for i, (array, core_shape) in enumerate( |
| 1714 | + zip(arrays, core_shapes, strict=True) |
| 1715 | + ): |
| 1716 | + if array is None: |
| 1717 | + continue |
| 1718 | + arrays[i] = xp.broadcast_to(array, batch_shape + core_shape) |
| 1719 | + |
| 1720 | + # Main loop |
| 1721 | + results = [] |
| 1722 | + for index in np.ndindex(batch_shape): |
| 1723 | + result = f( |
| 1724 | + *( |
| 1725 | + (array[index] if array is not None else None) |
| 1726 | + for array in arrays |
| 1727 | + ), |
| 1728 | + *other_args, |
| 1729 | + **kwargs, |
| 1730 | + ) |
| 1731 | + # Assume `result` is either a tuple or a single array. This is easily |
| 1732 | + # generalized by allowing the contributor to pass an `unpack_result` |
| 1733 | + # callable to the decorator factory. |
| 1734 | + result = (result,) if not isinstance(result, tuple) else result |
| 1735 | + results.append(result) |
| 1736 | + results = list(zip(*results, strict=True)) |
| 1737 | + |
| 1738 | + # Reshape results |
| 1739 | + for i, result in enumerate(results): |
| 1740 | + result = xp.stack(result) # noqa: PLW2901 |
| 1741 | + core_shape = result.shape[1:] |
| 1742 | + results[i] = xp.reshape(result, batch_shape + core_shape) |
| 1743 | + |
| 1744 | + # Assume the result should be a single array if there is only one element |
| 1745 | + # or a `tuple` otherwise. This is easily generalized by allowing the |
| 1746 | + # contributor to pass a `pack_result` callable to the decorator factory. |
| 1747 | + return results[0] if len(results) == 1 else results |
| 1748 | + |
| 1749 | + return wrapper |
| 1750 | + |
| 1751 | + return decorator |
| 1752 | + |
| 1753 | + |
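| | +# A minimal self-check of `_apply_over_batch` itself: a sketch assuming the |
| | +# NumPy backend, where `_head` is a hypothetical example function, not part |
| | +# of the API under test. One core dimension is consumed per inner call and |
| | +# the wrapper restores the leading batch shape on the output. |
| | +@_apply_over_batch(("a", 1)) |
| | +def _head(a, xp): |
| | + return xp.asarray(a)[0] |
| | + |
| | + |
| | +def test_apply_over_batch_sketch(): |
| | + res = _head(np.ones((3, 4, 5)), xp=np) |
| | + assert res.shape == (3, 4) |
| | + np.testing.assert_array_equal(res, np.ones((3, 4))) |
| | + |
| | + |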
| 1754 | +@_apply_over_batch(("a", 1), ("v", 1)) |
| 1755 | +def xp_searchsorted(a, v, side, xp): |
| 1756 | + return xp.searchsorted(xp.asarray(a), xp.asarray(v), side=side) |
| 1757 | + |
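| | +# For instance (an illustrative sketch, assuming NumPy): with a batch of two |
| | +# sorted rows, each row of `v` is searched against the matching row of `a`: |
| | +# |
| | +# a = np.asarray([[0.0, 1.0, 2.0], [10.0, 20.0, 30.0]]) |
| | +# v = np.asarray([[1.5], [5.0]]) |
| | +# xp_searchsorted(a, v, side="left", xp=np) # -> [[2], [0]] |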
| 1758 | + |
| 1759 | +@pytest.mark.skip_xp_backend(Backend.DASK, reason="no take_along_axis") |
| 1760 | +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no searchsorted") |
| 1761 | +class TestSearchsorted: |
| 1762 | + @pytest.mark.parametrize("side", ["left", "right"]) |
| 1763 | + @pytest.mark.parametrize("ties", [False, True]) |
| 1764 | + @pytest.mark.parametrize( |
| 1765 | + "shape", [0, 1, 2, 10, 11, 1000, 10001, (2, 0), (0, 2), (2, 10), (2, 3, 11)] |
| 1766 | + ) |
| 1767 | + @pytest.mark.parametrize("nans_x", [False, True]) |
| 1768 | + @pytest.mark.parametrize("infs_x", [False, True]) |
| 1769 | + def test_nd(self, side, ties, shape, nans_x, infs_x, xp): |
| 1770 | + if nans_x and is_torch_namespace(xp): |
| 1771 | + pytest.skip("torch sorts NaNs differently") |
| 1772 | + rng = np.random.default_rng(945298725498274853) |
| 1773 | + x = rng.integers(5, size=shape) if ties else rng.random(shape) |
| 1774 | + # float32 is to accommodate JAX, whose default 32-bit floats would presumably round away float64 `nextafter` steps |
| 1775 | + x = np.asarray(x, dtype=np.float32) |
| 1776 | + xr = np.nextafter(x, np.inf) |
| 1777 | + xl = np.nextafter(x, -np.inf) |
| 1778 | + x_ = np.asarray([-np.inf, np.inf, np.nan]) |
| 1779 | + x_ = np.broadcast_to(x_, (*x.shape[:-1], 3)) |
| 1780 | + y = rng.permuted(np.concatenate((xl, x, xr, x_), axis=-1), axis=-1) |
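| | + # `y` now holds every element of `x`, its nearest float32 neighbors on both |
| | + # sides, and -inf/+inf/nan, so `side` tie-breaking is exercised at every boundary. |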
| 1781 | + if nans_x: |
| 1782 | + mask = rng.random(shape) < 0.1 |
| 1783 | + x[mask] = np.nan |
| 1784 | + if infs_x: |
| 1785 | + mask = rng.random(shape) < 0.1 |
| 1786 | + x[mask] = -np.inf |
| 1787 | + mask = rng.random(shape) > 0.9 |
| 1788 | + x[mask] = np.inf |
| 1789 | + x = np.sort(x, stable=True, axis=-1) |
| 1790 | + x, y = np.asarray(x, dtype=np.float64), np.asarray(y, dtype=np.float64) |
| 1791 | + xp_default_int = xp.asarray(1).dtype |
| 1792 | + if x.size == 0 and x.ndim > 0 and x.shape[-1] != 0: |
| 1793 | + ref = xp.empty((*x.shape[:-1], y.shape[-1]), dtype=xp_default_int) |
| 1794 | + else: |
| 1795 | + ref = xp_searchsorted(x, y, side=side, xp=np) |
| 1796 | + ref = xp.asarray(ref, dtype=xp_default_int) |
| 1797 | + x, y = xp.asarray(x.copy()), xp.asarray(y.copy()) |
| 1798 | + res = searchsorted(x, y, side=side, xp=xp) |
| 1799 | + xp_assert_equal(res, ref) |