Commit 9c811a2

Add vectorized searchsorted
1 parent 5520c43 commit 9c811a2

3 files changed: +259 −2 lines changed

src/array_api_extra/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,7 @@
     default_dtype,
     kron,
     nunique,
+    searchsorted,
 )
 from ._lib._lazy import lazy_apply

@@ -48,6 +49,7 @@
     "one_hot",
     "pad",
     "partition",
+    "searchsorted",
     "setdiff1d",
     "sinc",
 ]

src/array_api_extra/_lib/_funcs.py

Lines changed: 96 additions & 1 deletion
@@ -8,7 +8,12 @@

 from ._at import at
 from ._utils import _compat, _helpers
-from ._utils._compat import array_namespace, is_dask_namespace, is_jax_array
+from ._utils._compat import (
+    array_namespace,
+    is_dask_namespace,
+    is_jax_array,
+    is_torch_namespace,
+)
 from ._utils._helpers import (
     asarrays,
     capabilities,

@@ -28,6 +33,7 @@
     "kron",
     "nunique",
     "pad",
+    "searchsorted",
     "setdiff1d",
     "sinc",
 ]

@@ -665,6 +671,95 @@ def pad(
     return at(padded, tuple(slices)).set(x)


+def searchsorted(
+    x1: Array,
+    x2: Array,
+    /,
+    *,
+    side: Literal["left", "right"] = "left",
+    xp: ModuleType | None = None,
+) -> Array:
+    """
+    Find indices where elements should be inserted to maintain order.
+
+    Find the indices into a sorted array ``x1`` such that, if the elements in ``x2``
+    were inserted before the indices, the resulting array would remain sorted.
+
+    Parameters
+    ----------
+    x1 : Array
+        Input array. Should have a real-valued data type. Must be sorted in ascending
+        order along the last axis.
+    x2 : Array
+        Array containing search values. Should have a real-valued data type. Must have
+        the same shape as ``x1`` except along the last axis.
+    side : {'left', 'right'}, optional
+        Argument controlling which index is returned if an element of ``x2`` is equal
+        to one or more elements of ``x1``: ``'left'`` returns the index of the first of
+        these elements; ``'right'`` returns the next index after the last of these
+        elements. Default: ``'left'``.
+    xp : array_namespace, optional
+        The standard-compatible namespace for the array arguments. Default: infer.
+
+    Returns
+    -------
+    Array
+        An integer array of indices with the same shape as ``x2``.
+
+    Examples
+    --------
+    >>> import array_api_strict as xp
+    >>> import array_api_extra as xpx
+    >>> x = xp.asarray([11, 12, 13, 13, 14, 15])
+    >>> xpx.searchsorted(x, xp.asarray([10, 11.5, 14.5, 16]), xp=xp)
+    Array([0, 1, 5, 6], dtype=array_api_strict.int64)
+    >>> xpx.searchsorted(x, xp.asarray(13), xp=xp)
+    Array(2, dtype=array_api_strict.int64)
+    >>> xpx.searchsorted(x, xp.asarray(13), side='right', xp=xp)
+    Array(4, dtype=array_api_strict.int64)
+
+    `searchsorted` is vectorized along the last axis.
+
+    >>> x1 = xp.asarray([[1., 2., 3., 4.], [5., 6., 7., 8.]])
+    >>> x2 = xp.asarray([[1.1, 3.3], [6.6, 8.8]])
+    >>> xpx.searchsorted(x1, x2, xp=xp)
+    Array([[1, 3],
+           [2, 4]], dtype=array_api_strict.int64)
+    """
+    xp = array_namespace(x1, x2) if xp is None else xp
+    xp_default_int = xp.asarray(1).dtype
+    y_0d = xp.asarray(x2).ndim == 0
+    x_1d = x1.ndim <= 1
+
+    if x_1d or is_torch_namespace(xp):
+        x2 = xp.reshape(x2, ()) if (y_0d and x_1d) else x2
+        out = xp.searchsorted(x1, x2, side=side)
+        return xp.astype(out, xp_default_int, copy=False)
+
+    a = xp.full(x2.shape, 0, device=_compat.device(x1))
+
+    if x1.shape[-1] == 0:
+        return a
+
+    n = xp.count_nonzero(~xp.isnan(x1), axis=-1, keepdims=True)
+    b = xp.broadcast_to(n, x2.shape)
+
+    compare = xp.less_equal if side == "left" else xp.less
+
+    # while xp.any(b - a > 1):
+    # refactored to a for loop with ~log2(n) iterations for JAX JIT
+    for _ in range(int(math.log2(x1.shape[-1])) + 1):  # type: ignore[arg-type]
+        c = (a + b) // 2
+        x0 = xp.take_along_axis(x1, c, axis=-1)
+        j = compare(x2, x0)
+        b = xp.where(j, c, b)
+        a = xp.where(j, a, c)
+
+    out = xp.where(compare(x2, xp.min(x1, axis=-1, keepdims=True)), 0, b)
+    out = xp.where(xp.isnan(x2), x1.shape[-1], out) if side == "right" else out
+    return xp.astype(out, xp_default_int, copy=False)
+
+
 def setdiff1d(
     x1: Array | complex,
     x2: Array | complex,
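
The core of the addition above is a branchless, fixed-iteration binary search, which JAX can JIT-compile where a data-dependent `while` loop could not be traced. The standalone NumPy sketch below mirrors that loop for the simple case of a sorted, NaN-free, non-empty 1-D `x1`; the name `binary_searchsorted` and the example values are illustrative only and are not part of the commit.

```python
import math

import numpy as np


def binary_searchsorted(x1, x2, side="left"):
    """Fixed-iteration binary search over a sorted, NaN-free, non-empty 1-D x1."""
    a = np.zeros(x2.shape, dtype=np.int64)   # lower end of the bracket
    b = np.full(x2.shape, x1.shape[-1])      # upper end of the bracket
    compare = np.less_equal if side == "left" else np.less
    # ~ceil(log2(n)) halvings shrink the bracket (a, b] to a single index.
    for _ in range(int(math.log2(x1.shape[-1])) + 1):
        c = (a + b) // 2
        keep_upper = compare(x2, x1[c])      # does the answer lie at or below c?
        b = np.where(keep_upper, c, b)
        a = np.where(keep_upper, a, c)
    # Values that sort before every element of x1 belong at index 0.
    return np.where(compare(x2, x1.min()), 0, b)


x = np.asarray([11, 12, 13, 13, 14, 15])
v = np.asarray([10, 11.5, 13, 14.5, 16])
print(binary_searchsorted(x, v))   # [0 1 2 5 6]
print(np.searchsorted(x, v))       # [0 1 2 5 6]
```

The library function in the diff additionally clips the upper bound to the count of non-NaN entries, handles NaN search values for `side='right'`, and routes 1-D and PyTorch inputs to the backend's own `searchsorted`.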

tests/test_funcs.py

Lines changed: 161 additions & 1 deletion
@@ -29,13 +29,18 @@
     one_hot,
     pad,
     partition,
+    searchsorted,
     setdiff1d,
     sinc,
 )
 from array_api_extra._lib._backends import NUMPY_VERSION, Backend
 from array_api_extra._lib._testing import xfail, xp_assert_close, xp_assert_equal
+from array_api_extra._lib._utils._compat import (
+    array_namespace,
+    is_jax_namespace,
+    is_torch_namespace,
+)
 from array_api_extra._lib._utils._compat import device as get_device
-from array_api_extra._lib._utils._compat import is_jax_namespace
 from array_api_extra._lib._utils._helpers import eager_shape, ndindex
 from array_api_extra._lib._utils._typing import Array, Device
 from array_api_extra.testing import lazy_xp_function

@@ -52,6 +57,7 @@
 lazy_xp_function(pad)
 # FIXME calls in1d which calls xp.unique_values without size
 lazy_xp_function(setdiff1d, jax_jit=False)
+lazy_xp_function(searchsorted)
 lazy_xp_function(sinc)

 NestedFloatList = list[float] | list["NestedFloatList"]

@@ -1637,3 +1643,157 @@ def test_kind(self, xp: ModuleType, library: Backend):
         expected = xp.asarray([False, True, False, True])
         res = isin(a, b, kind="sort")
         xp_assert_equal(res, expected)
+
+
+def _apply_over_batch(*argdefs: tuple[str, int]):
+    """
+    Factory for a decorator that applies a function over batched arguments.
+
+    Array arguments may have any number of core dimensions (typically 0, 1, or 2)
+    and any broadcastable batch shapes. There may be any number of array outputs
+    of any number of dimensions. The current assumptions - satisfied by all
+    functions of interest in `linalg` - are that all array inputs are consecutive
+    keyword or positional arguments, and that the wrapped function returns either
+    a single array or a tuple of arrays. It is only as general as it needs to be
+    right now; it can be extended.
+
+    Parameters
+    ----------
+    *argdefs : tuple of (str, int)
+        Definitions of the array arguments: the keyword name of each argument and
+        its number of core dimensions.
+
+    Examples
+    --------
+    `linalg.eig` accepts two matrices as the first two arguments `a` and `b`, where
+    `b` is optional, and returns one array or a tuple of arrays, depending on the
+    values of other positional or keyword arguments. To generate a wrapper that
+    applies the function over batches of `a` and optionally `b`:
+
+    >>> _apply_over_batch(('a', 2), ('b', 2))
+    """
+    names, ndims = list(zip(*argdefs, strict=True))
+    n_arrays = len(names)
+
+    def decorator(f):
+        def wrapper(*args_tuple, **kwargs):
+            args = list(args_tuple)
+
+            # Ensure all arrays are in `arrays`, other arguments in `other_args`/`kwargs`
+            arrays, other_args = args[:n_arrays], args[n_arrays:]
+            for i, name in enumerate(names):
+                if name in kwargs:
+                    if i + 1 <= len(args):
+                        message = (
+                            f"{f.__name__}() got multiple values for argument `{name}`."
+                        )
+                        raise ValueError(message)
+                    arrays.append(kwargs.pop(name))
+
+            xp = array_namespace(*arrays)
+
+            # Determine core and batch shapes
+            batch_shapes = []
+            core_shapes = []
+            for i, (array, ndim) in enumerate(zip(arrays, ndims, strict=True)):
+                array = None if array is None else xp.asarray(array)  # noqa: PLW2901
+                shape = () if array is None else array.shape
+                arrays[i] = array
+                batch_shapes.append(shape[:-ndim] if ndim > 0 else shape)
+                core_shapes.append(shape[-ndim:] if ndim > 0 else ())
+
+            # Early exit if call is not batched
+            if not any(batch_shapes):
+                return f(*arrays, *other_args, **kwargs)
+
+            # Determine broadcasted batch shape
+            batch_shape = np.broadcast_shapes(*batch_shapes)  # Gives OK error message
+
+            # Broadcast arrays to appropriate shape
+            for i, (array, core_shape) in enumerate(
+                zip(arrays, core_shapes, strict=True)
+            ):
+                if array is None:
+                    continue
+                arrays[i] = xp.broadcast_to(array, batch_shape + core_shape)
+
+            # Main loop
+            results = []
+            for index in np.ndindex(batch_shape):
+                result = f(
+                    *(
+                        (array[index] if array is not None else None)
+                        for array in arrays
+                    ),
+                    *other_args,
+                    **kwargs,
+                )
+                # Assume `result` is either a tuple or a single array. This is easily
+                # generalized by allowing the contributor to pass an `unpack_result`
+                # callable to the decorator factory.
+                result = (result,) if not isinstance(result, tuple) else result
+                results.append(result)
+            results = list(zip(*results, strict=True))
+
+            # Reshape results
+            for i, result in enumerate(results):
+                result = xp.stack(result)  # noqa: PLW2901
+                core_shape = result.shape[1:]
+                results[i] = xp.reshape(result, batch_shape + core_shape)
+
+            # Assume the result should be a single array if there is only one element
+            # or a `tuple` otherwise. This is easily generalized by allowing the
+            # contributor to pass a `pack_result` callable to the decorator factory.
+            return results[0] if len(results) == 1 else results
+
+        return wrapper
+
+    return decorator
+
+
+@_apply_over_batch(("a", 1), ("v", 1))
+def xp_searchsorted(a, v, side, xp):
+    return xp.searchsorted(xp.asarray(a), xp.asarray(v), side=side)
+
+
+@pytest.mark.skip_xp_backend(Backend.DASK, reason="no take_along_axis")
+@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no searchsorted")
+class TestSearchsorted:
+    @pytest.mark.parametrize("side", ["left", "right"])
+    @pytest.mark.parametrize("ties", [False, True])
+    @pytest.mark.parametrize(
+        "shape", [0, 1, 2, 10, 11, 1000, 10001, (2, 0), (0, 2), (2, 10), (2, 3, 11)]
+    )
+    @pytest.mark.parametrize("nans_x", [False, True])
+    @pytest.mark.parametrize("infs_x", [False, True])
+    def test_nd(self, side, ties, shape, nans_x, infs_x, xp):
+        if nans_x and is_torch_namespace(xp):
+            pytest.skip("torch sorts NaNs differently")
+        rng = np.random.default_rng(945298725498274853)
+        x = rng.integers(5, size=shape) if ties else rng.random(shape)
+        # float32 is to accommodate JAX - nextafter with `float64` is too small?
+        x = np.asarray(x, dtype=np.float32)
+        xr = np.nextafter(x, np.inf)
+        xl = np.nextafter(x, -np.inf)
+        x_ = np.asarray([-np.inf, np.inf, np.nan])
+        x_ = np.broadcast_to(x_, (*x.shape[:-1], 3))
+        y = rng.permuted(np.concatenate((xl, x, xr, x_), axis=-1), axis=-1)
+        if nans_x:
+            mask = rng.random(shape) < 0.1
+            x[mask] = np.nan
+        if infs_x:
+            mask = rng.random(shape) < 0.1
+            x[mask] = -np.inf
+            mask = rng.random(shape) > 0.9
+            x[mask] = np.inf
+        x = np.sort(x, stable=True, axis=-1)
+        x, y = np.asarray(x, dtype=np.float64), np.asarray(y, dtype=np.float64)
+        xp_default_int = xp.asarray(1).dtype
+        if x.size == 0 and x.ndim > 0 and x.shape[-1] != 0:
+            ref = xp.empty((*x.shape[:-1], y.shape[-1]), dtype=xp_default_int)
+        else:
+            ref = xp_searchsorted(x, y, side=side, xp=np)
+            ref = xp.asarray(ref, dtype=xp_default_int)
+        x, y = xp.asarray(x.copy()), xp.asarray(y.copy())
+        res = searchsorted(x, y, side=side, xp=xp)
+        xp_assert_equal(res, ref)
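
The test reference above is produced by looping `np.searchsorted` over the broadcast batch shape via `_apply_over_batch`. The following standalone NumPy sketch illustrates what that reference computes; the helper name `batched_searchsorted_ref` is made up for illustration and is not part of the commit.

```python
import numpy as np


def batched_searchsorted_ref(x, v, side="left"):
    """Apply np.searchsorted row by row over the leading (batch) axes."""
    batch_shape = np.broadcast_shapes(x.shape[:-1], v.shape[:-1])
    x = np.broadcast_to(x, batch_shape + x.shape[-1:])
    v = np.broadcast_to(v, batch_shape + v.shape[-1:])
    out = np.empty(v.shape, dtype=np.int64)
    for idx in np.ndindex(batch_shape):  # one 1-D searchsorted per batch element
        out[idx] = np.searchsorted(x[idx], v[idx], side=side)
    return out


x1 = np.asarray([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]])
x2 = np.asarray([[1.1, 3.3], [6.6, 8.8]])
print(batched_searchsorted_ref(x1, x2))
# [[1 3]
#  [2 4]]
```

This matches the vectorized example in the `searchsorted` docstring above.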
