Skip to content

Commit 22ed172

Browse files
committed
add xfail with request fixture
1 parent 04f2a38 commit 22ed172

File tree

2 files changed

+57
-53
lines changed

2 files changed

+57
-53
lines changed

src/array_api_extra/_delegation.py

Lines changed: 50 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from ._lib._utils._typing import Array, DType
2020

2121
__all__ = [
22+
"atleast_nd",
2223
"cov",
2324
"expand_dims",
2425
"isclose",
@@ -29,6 +30,55 @@
2930
]
3031

3132

33+
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
34+
"""
35+
Recursively expand the dimension of an array to at least `ndim`.
36+
37+
Parameters
38+
----------
39+
x : array
40+
Input array.
41+
ndim : int
42+
The minimum number of dimensions for the result.
43+
xp : array_namespace, optional
44+
The standard-compatible namespace for `x`. Default: infer.
45+
46+
Returns
47+
-------
48+
array
49+
An array with ``res.ndim`` >= `ndim`.
50+
If ``x.ndim`` >= `ndim`, `x` is returned.
51+
If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes
52+
until ``res.ndim`` equals `ndim`.
53+
54+
Examples
55+
--------
56+
>>> import array_api_strict as xp
57+
>>> import array_api_extra as xpx
58+
>>> x = xp.asarray([1])
59+
>>> xpx.atleast_nd(x, ndim=3, xp=xp)
60+
Array([[[1]]], dtype=array_api_strict.int64)
61+
62+
>>> x = xp.asarray([[[1, 2],
63+
... [3, 4]]])
64+
>>> xpx.atleast_nd(x, ndim=1, xp=xp) is x
65+
True
66+
"""
67+
if xp is None:
68+
xp = array_namespace(x)
69+
70+
if 1 <= ndim <= 3 and (
71+
is_numpy_namespace(xp)
72+
or is_jax_namespace(xp)
73+
or is_dask_namespace(xp)
74+
or is_cupy_namespace(xp)
75+
or is_torch_namespace(xp)
76+
):
77+
return getattr(xp, f"atleast_{ndim}d")(x)
78+
79+
return _funcs.atleast_nd(x, ndim=ndim, xp=xp)
80+
81+
3282
def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
3383
"""
3484
Estimate a covariance matrix.
@@ -197,55 +247,6 @@ def expand_dims(
197247
return _funcs.expand_dims(a, axis=axis, xp=xp)
198248

199249

200-
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
201-
"""
202-
Recursively expand the dimension of an array to at least `ndim`.
203-
204-
Parameters
205-
----------
206-
x : array
207-
Input array.
208-
ndim : int
209-
The minimum number of dimensions for the result.
210-
xp : array_namespace, optional
211-
The standard-compatible namespace for `x`. Default: infer.
212-
213-
Returns
214-
-------
215-
array
216-
An array with ``res.ndim`` >= `ndim`.
217-
If ``x.ndim`` >= `ndim`, `x` is returned.
218-
If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes
219-
until ``res.ndim`` equals `ndim`.
220-
221-
Examples
222-
--------
223-
>>> import array_api_strict as xp
224-
>>> import array_api_extra as xpx
225-
>>> x = xp.asarray([1])
226-
>>> xpx.atleast_nd(x, ndim=3, xp=xp)
227-
Array([[[1]]], dtype=array_api_strict.int64)
228-
229-
>>> x = xp.asarray([[[1, 2],
230-
... [3, 4]]])
231-
>>> xpx.atleast_nd(x, ndim=1, xp=xp) is x
232-
True
233-
"""
234-
if xp is None:
235-
xp = array_namespace(x)
236-
237-
if 1 <= ndim <= 3 and (
238-
is_numpy_namespace(xp)
239-
or is_jax_namespace(xp)
240-
or is_dask_namespace(xp)
241-
or is_cupy_namespace(xp)
242-
or is_torch_namespace(xp)
243-
):
244-
return getattr(xp, f"atleast_{ndim}d")(x)
245-
246-
return _funcs.atleast_nd(x, ndim=ndim, xp=xp)
247-
248-
249250
def isclose(
250251
a: Array | complex,
251252
b: Array | complex,

tests/test_funcs.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
sinc,
3434
)
3535
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
36-
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
36+
from array_api_extra._lib._testing import xfail, xp_assert_close, xp_assert_equal
3737
from array_api_extra._lib._utils._compat import device as get_device
3838
from array_api_extra._lib._utils._compat import is_jax_namespace
3939
from array_api_extra._lib._utils._helpers import eager_shape, ndindex
@@ -1263,6 +1263,7 @@ def test_assume_unique(self, xp: ModuleType):
12631263
@pytest.mark.parametrize("shape2", [(), (1,), (1, 1)])
12641264
def test_shapes(
12651265
self,
1266+
request: pytest.FixtureRequest,
12661267
assume_unique: bool,
12671268
shape1: tuple[int, ...],
12681269
shape2: tuple[int, ...],
@@ -1272,22 +1273,24 @@ def test_shapes(
12721273
x2 = xp.zeros(shape2)
12731274

12741275
if is_jax_namespace(xp) and assume_unique and shape1 != (1,):
1275-
pytest.xfail(reason="jax#32335 fixed with jax>=0.8.0")
1276+
xfail(request=request, reason="jax#32335 fixed with jax>=0.8.0")
12761277

12771278
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
12781279
xp_assert_equal(actual, xp.empty((0,)))
12791280

12801281
@assume_unique
12811282
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
1282-
def test_python_scalar(self, xp: ModuleType, assume_unique: bool):
1283+
def test_python_scalar(
1284+
self, request: pytest.FixtureRequest, xp: ModuleType, assume_unique: bool
1285+
):
12831286
# Test no dtype promotion to xp.asarray(x2); use x1.dtype
12841287
x1 = xp.asarray([3, 1, 2], dtype=xp.int16)
12851288
x2 = 3
12861289
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
12871290
xp_assert_equal(actual, xp.asarray([1, 2], dtype=xp.int16))
12881291

12891292
if is_jax_namespace(xp) and assume_unique:
1290-
pytest.xfail(reason="jax#32335 fixed with jax>=0.8.0")
1293+
xfail(request=request, reason="jax#32335 fixed with jax>=0.8.0")
12911294

12921295
actual = setdiff1d(x2, x1, assume_unique=assume_unique)
12931296
xp_assert_equal(actual, xp.asarray([], dtype=xp.int16))

0 commit comments

Comments
 (0)