Skip to content

Commit 8ab7d62

Browse files
committed
more validation
1 parent e319529 commit 8ab7d62

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

src/array_api_extra/_delegation.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1078,6 +1078,8 @@ def quantile(
10781078
if not xp.isdtype(xp.asarray(q).dtype, "real floating"):
10791079
msg = "`q` must have real floating dtype."
10801080
raise ValueError(msg)
1081+
weights = None if weights is None else xp.asarray(weights)
1082+
10811083
ndim = a.ndim
10821084
if ndim < 1:
10831085
msg = "`a` must be at least 1-dimensional."
@@ -1087,9 +1089,15 @@ def quantile(
10871089
raise ValueError(msg)
10881090
if weights is None:
10891091
if nan_policy != "propagate":
1090-
msg = "When `weights` aren't provided, `nan_policy` must be 'propagate'"
1092+
msg = "When `weights` aren't provided, `nan_policy` must be 'propagate'."
10911093
raise ValueError(msg)
10921094
else:
1095+
if method not in {"inverted_cdf", "averaged_inverted_cdf"}:
1096+
msg = f"`method` '{method}' not supported with weights."
1097+
raise ValueError(msg)
1098+
if not xp.isdtype(weights.dtype, ("integral", "real floating")):
1099+
msg = "`weights` must have real dtype."
1100+
raise ValueError(msg)
10931101
if ndim > 2:
10941102
msg = "When weights are provided, dimension of `a` must be 1 or 2."
10951103
raise ValueError(msg)

tests/test_funcs.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1623,7 +1623,7 @@ def test_against_median(
16231623
a_np[rng.random(n) < rng.random(n) * 0.5] = np.nan
16241624
if w_np is not None:
16251625
# ensure at least one NaN on non-null weight:
1626-
nz_weights_idx, = np.where(w_np > 0)
1626+
(nz_weights_idx,) = np.where(w_np > 0)
16271627
a_np[nz_weights_idx[0]] = np.nan
16281628
m = "averaged_inverted_cdf"
16291629

@@ -1747,31 +1747,42 @@ def test_invalid_shape(self, xp: ModuleType):
17471747
_ = quantile(xp.asarray([3.0]), 0.5, axis=1)
17481748
# with weights:
17491749
method = "inverted_cdf"
1750+
17501751
shape = (2, 3, 4)
17511752
with pytest.raises(ValueError, match="dimension of `a` must be 1 or 2"):
17521753
_ = quantile(
17531754
xp.ones(shape), 0.5, axis=1, weights=xp.ones(shape), method=method
17541755
)
1756+
17551757
with pytest.raises(TypeError, match="Axis must be specified"):
17561758
_ = quantile(xp.ones((2, 3)), 0.5, weights=xp.ones(3), method=method)
1759+
17571760
with pytest.raises(ValueError, match="Shape of weights must be consistent"):
17581761
_ = quantile(
17591762
xp.ones((2, 3)), 0.5, axis=0, weights=xp.ones(3), method=method
17601763
)
1764+
17611765
with pytest.raises(ValueError, match="Axis must be specified"):
17621766
_ = quantile(xp.ones((2, 3)), 0.5, weights=xp.ones((2, 3)), method=method)
17631767

17641768
def test_invalid_dtype(self, xp: ModuleType):
17651769
with pytest.raises(ValueError, match="`a` must have real dtype"):
17661770
_ = quantile(xp.ones(5, dtype=xp.bool), 0.5)
17671771

1772+
a = xp.ones(5)
17681773
with pytest.raises(ValueError, match="`q` must have real floating dtype"):
1769-
_ = quantile(xp.ones(5), xp.asarray([0, 1]))
1774+
_ = quantile(a, xp.asarray([0, 1]))
1775+
1776+
weights = xp.ones(5, dtype=xp.bool)
1777+
with pytest.raises(ValueError, match="`weights` must have real dtype"):
1778+
_ = quantile(a, 0.5, weights=weights, method="inverted_cdf")
17701779

17711780
def test_invalid_method(self, xp: ModuleType):
17721781
with pytest.raises(ValueError, match="`method` must be one of"):
17731782
_ = quantile(xp.ones(5), 0.5, method="invalid")
1774-
# TODO: with weights?
1783+
1784+
with pytest.raises(ValueError, match="not supported with weights"):
1785+
_ = quantile(xp.ones(5), 0.5, method="linear", weights=xp.ones(5))
17751786

17761787
def test_invalid_nan_policy(self, xp: ModuleType):
17771788
with pytest.raises(ValueError, match="`nan_policy` must be one of"):

0 commit comments

Comments
 (0)