Skip to content

Commit c395b84

Browse files
committed
more coverage
1 parent 3226659 commit c395b84

File tree

2 files changed

+48
-11
lines changed

2 files changed

+48
-11
lines changed

src/array_api_extra/_delegation.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,9 +1070,6 @@ def quantile(
10701070
if nan_policy not in nan_policies:
10711071
msg = f"`nan_policy` must be one of {nan_policies}"
10721072
raise ValueError(msg)
1073-
if keepdims not in {True, False}:
1074-
msg = "If specified, `keepdims` must be True or False."
1075-
raise ValueError(msg)
10761073

10771074
a = xp.asarray(a)
10781075
if not xp.isdtype(a.dtype, ("integral", "real floating")):
@@ -1090,7 +1087,7 @@ def quantile(
10901087
raise ValueError(msg)
10911088
if weights is None:
10921089
if nan_policy != "propagate":
1093-
msg = ""
1090+
msg = "When `weights` aren't provided, `nan_policy` must be 'propagate'"
10941091
raise ValueError(msg)
10951092
else:
10961093
if ndim > 2:
@@ -1107,7 +1104,7 @@ def quantile(
11071104
)
11081105
raise ValueError(msg)
11091106
if axis is None and ndim == 2:
1110-
msg = "When weights are provided, axis must be specified when `a` is 2d"
1107+
msg = "Axis must be specified when `a` and ̀ weights` are 2d."
11111108
raise ValueError(msg)
11121109

11131110
# Align result dtype with what numpy does:

tests/test_funcs.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1543,6 +1543,12 @@ def test_basic(self, xp: ModuleType):
15431543
expect = xp.asarray(3.0, dtype=xp.float64)
15441544
xp_assert_close(actual, expect)
15451545

1546+
def test_xp(self, xp: ModuleType):
1547+
x = xp.asarray([1, 2, 3, 4, 5])
1548+
actual = quantile(x, 0.5, xp=xp)
1549+
expect = xp.asarray(3.0, dtype=xp.float64)
1550+
xp_assert_close(actual, expect)
1551+
15461552
def test_multiple_quantiles(self, xp: ModuleType):
15471553
x = xp.asarray([1, 2, 3, 4, 5])
15481554
actual = quantile(x, xp.asarray([0.25, 0.5, 0.75]))
@@ -1729,15 +1735,49 @@ def test_invalid_q(self, xp: ModuleType):
17291735
):
17301736
_ = quantile(x, -0.5)
17311737

1738+
def test_invalid_shape(self, xp: ModuleType):
1739+
with pytest.raises(TypeError, match="at least 1-dimensional"):
1740+
_ = quantile(xp.asarray(3.0), 0.5)
1741+
with pytest.raises(ValueError, match="not compatible with the dimension"):
1742+
_ = quantile(xp.asarray([3.0]), 0.5, axis=1)
1743+
# with weights:
1744+
method = "inverted_cdf"
1745+
shape = (2, 3, 4)
1746+
with pytest.raises(ValueError, match="dimension of `a` must be 1 or 2"):
1747+
_ = quantile(
1748+
xp.ones(shape), 0.5, axis=1, weights=xp.ones(shape), method=method
1749+
)
1750+
with pytest.raises(TypeError, match="Axis must be specified"):
1751+
_ = quantile(xp.ones((2, 3)), 0.5, weights=xp.ones(3), method=method)
1752+
with pytest.raises(ValueError, match="Shape of weights must be consistent"):
1753+
_ = quantile(
1754+
xp.ones((2, 3)), 0.5, axis=0, weights=xp.ones(3), method=method
1755+
)
1756+
with pytest.raises(ValueError, match="Axis must be specified"):
1757+
_ = quantile(xp.ones((2, 3)), 0.5, weights=xp.ones((2, 3)), method=method)
1758+
1759+
def test_invalid_dtype(self, xp: ModuleType):
1760+
with pytest.raises(ValueError, match="`a` must have real dtype"):
1761+
_ = quantile(xp.ones(5, dtype=xp.bool), 0.5)
1762+
1763+
with pytest.raises(ValueError, match="`q` must have real floating dtype"):
1764+
_ = quantile(xp.ones(5), xp.asarray([0, 1]))
1765+
1766+
def test_invalid_method(self, xp: ModuleType):
1767+
with pytest.raises(ValueError, match="`method` must be one of"):
1768+
_ = quantile(xp.ones(5), 0.5, method="invalid")
1769+
# TODO: with weights?
1770+
1771+
def test_invalid_nan_policy(self, xp: ModuleType):
1772+
with pytest.raises(ValueError, match="`nan_policy` must be one of"):
1773+
_ = quantile(xp.ones(5), 0.5, nan_policy="invalid")
1774+
1775+
with pytest.raises(ValueError, match="must be 'propagate'"):
1776+
_ = quantile(xp.ones(5), 0.5, nan_policy="omit")
1777+
17321778
def test_device(self, xp: ModuleType, device: Device):
17331779
if hasattr(device, "type") and device.type == "meta": # pyright: ignore[reportAttributeAccessIssue]
17341780
pytest.xfail("No Tensor.item() on meta device")
17351781
x = xp.asarray([1, 2, 3, 4, 5], device=device)
17361782
actual = quantile(x, 0.5)
17371783
assert get_device(actual) == device
1738-
1739-
def test_xp(self, xp: ModuleType):
1740-
x = xp.asarray([1, 2, 3, 4, 5])
1741-
actual = quantile(x, 0.5, xp=xp)
1742-
expect = xp.asarray(3.0, dtype=xp.float64)
1743-
xp_assert_close(actual, expect)

0 commit comments

Comments
 (0)