Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions scipy/stats/tests/test_quantile.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ def _get_weights_x_rep(self, x, axis, rng):
def test_against_numpy(self, method, shape_x, shape_p, axis, weights, xp):
if weights and method.startswith('_'):
pytest.skip('`weights=True` not supported by private (legacy) methods.')

if weights and is_numpy(xp) and xp.__version__ < "2.0":
pytest.skip('`weights` not supported by NumPy < 2.0.')
dtype = xp_default_dtype(xp)
rng = np.random.default_rng(23458924568734956)
x = rng.random(size=shape_x)
Expand All @@ -139,8 +140,8 @@ def test_against_numpy(self, method, shape_x, shape_p, axis, weights, xp):
ref = np.quantile(x_rep, p, axis=axis,
method=method[1:] if method.startswith('_') else method)

x, p = xp.asarray(x), xp.asarray(p)
weights = weights if weights is None else xp.asarray(weights)
x, p = xp.asarray(x, dtype=dtype), xp.asarray(p, dtype=dtype)
weights = weights if weights is None else xp.asarray(weights, dtype=dtype)
res = stats.quantile(x, p, method=method, weights=weights, axis=axis)

xp_assert_close(res, xp.asarray(ref, dtype=dtype))
Expand Down
Loading