Skip to content

Commit 37dcc4b

Browse files
authored
ENH: new function union1d (#495)
1 parent a9dadb9 commit 37dcc4b

File tree

5 files changed

+80
-0
lines changed

5 files changed

+80
-0
lines changed

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,5 @@
2525
partition
2626
setdiff1d
2727
sinc
28+
union1d
2829
```

src/array_api_extra/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
partition,
1515
setdiff1d,
1616
sinc,
17+
union1d,
1718
)
1819
from ._lib._at import at
1920
from ._lib._funcs import (
@@ -50,4 +51,5 @@
5051
"partition",
5152
"setdiff1d",
5253
"sinc",
54+
"union1d",
5355
]

src/array_api_extra/_delegation.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,3 +1026,37 @@ def isin(
10261026
return xp.isin(a, b, assume_unique=assume_unique, invert=invert)
10271027

10281028
return _funcs.isin(a, b, assume_unique=assume_unique, invert=invert, xp=xp)
1029+
1030+
1031+
def union1d(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
1032+
"""
1033+
Find the union of two arrays.
1034+
1035+
Return the unique, sorted array of values that are in either of the two
1036+
input arrays.
1037+
1038+
Parameters
1039+
----------
1040+
a, b : Array
1041+
Input arrays. They are flattened internally if they are not already 1D.
1042+
1043+
xp : array_namespace, optional
1044+
The standard-compatible namespace for `a` and `b`. Default: infer.
1045+
1046+
Returns
1047+
-------
1048+
Array
1049+
Unique, sorted union of the input arrays.
1050+
"""
1051+
if xp is None:
1052+
xp = array_namespace(a, b)
1053+
1054+
if (
1055+
is_numpy_namespace(xp)
1056+
or is_cupy_namespace(xp)
1057+
or is_dask_namespace(xp)
1058+
or is_jax_namespace(xp)
1059+
):
1060+
return xp.union1d(a, b)
1061+
1062+
return _funcs.union1d(a, b, xp=xp)

src/array_api_extra/_lib/_funcs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -742,3 +742,12 @@ def isin( # numpydoc ignore=PR01,RT01
742742
_helpers.in1d(a, b, assume_unique=assume_unique, invert=invert, xp=xp),
743743
original_a_shape,
744744
)
745+
746+
747+
def union1d(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
748+
# numpydoc ignore=PR01,RT01
749+
"""See docstring in `array_api_extra._delegation.py`."""
750+
a = xp.reshape(a, (-1,))
751+
b = xp.reshape(b, (-1,))
752+
# XXX: `sparse` returns NumPy arrays from `unique_values`
753+
return xp.asarray(xp.unique_values(xp.concat([a, b])))

tests/test_funcs.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
partition,
3232
setdiff1d,
3333
sinc,
34+
union1d,
3435
)
3536
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
3637
from array_api_extra._lib._testing import xfail, xp_assert_close, xp_assert_equal
@@ -1637,3 +1638,36 @@ def test_kind(self, xp: ModuleType, library: Backend):
16371638
expected = xp.asarray([False, True, False, True])
16381639
res = isin(a, b, kind="sort")
16391640
xp_assert_equal(res, expected)
1641+
1642+
1643+
@pytest.mark.skip_xp_backend(
1644+
Backend.ARRAY_API_STRICTEST,
1645+
reason="data_dependent_shapes flag for unique_values is disabled",
1646+
)
1647+
class TestUnion1d:
1648+
def test_simple(self, xp: ModuleType):
1649+
a = xp.asarray([-1, 1, 0])
1650+
b = xp.asarray([2, -2, 0])
1651+
expected = xp.asarray([-2, -1, 0, 1, 2])
1652+
res = union1d(a, b)
1653+
xp_assert_equal(res, expected)
1654+
1655+
def test_2d(self, xp: ModuleType):
1656+
a = xp.asarray([[-1, 1, 0], [1, 2, 0]])
1657+
b = xp.asarray([[1, 0, 1], [-2, -1, 0]])
1658+
expected = xp.asarray([-2, -1, 0, 1, 2])
1659+
res = union1d(a, b)
1660+
xp_assert_equal(res, expected)
1661+
1662+
def test_3d(self, xp: ModuleType):
1663+
a = xp.asarray([[[-1, 0], [1, 2]], [[-1, 0], [1, 2]]])
1664+
b = xp.asarray([[[0, 1], [-1, 2]], [[1, -2], [0, 2]]])
1665+
expected = xp.asarray([-2, -1, 0, 1, 2])
1666+
res = union1d(a, b)
1667+
xp_assert_equal(res, expected)
1668+
1669+
@pytest.mark.skip_xp_backend(Backend.TORCH, reason="materialize 'meta' device")
1670+
def test_device(self, xp: ModuleType, device: Device):
1671+
a = xp.asarray([-1, 1, 0], device=device)
1672+
b = xp.asarray([2, -2, 0], device=device)
1673+
assert get_device(union1d(a, b)) == device

0 commit comments

Comments
 (0)