Skip to content

Commit 3611708

Browse files
committed
linting & cleanup
1 parent 26804fe commit 3611708

File tree

2 files changed

+25
-15
lines changed

2 files changed

+25
-15
lines changed

src/array_api_extra/_delegation.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -938,13 +938,16 @@ def quantile(
938938
9. 'normal_unbiased'
939939
940940
The first three methods are discontinuous.
941-
Only 'linear' is implemented for now.
941+
Only 'linear', 'inverted_cdf' and 'averaged_inverted_cdf' are implemented.
942942
943943
keepdims : bool, optional
944944
If this is set to True, the axes which are reduced are left in
945945
the result as dimensions with size one. With this option, the
946946
result will broadcast correctly against the original array `a`.
947947
948+
nan_policy : str, optional
949+
'propagate' (default) or 'omit'.
950+
948951
weights : array_like, optional
949952
An array of weights associated with the values in `a`. Each value in
950953
`a` contributes to the quantile according to its associated weight.
@@ -1121,20 +1124,24 @@ def quantile(
11211124
msg = "`q` values must be in the range [0, 1]"
11221125
raise ValueError(msg)
11231126

1124-
# Delegate where possible.
1127+
# Delegate when possible.
11251128
if is_numpy_namespace(xp) and nan_policy == "propagate":
1129+
# TODO: call nanquantile for nan_policy == "omit" once
1130+
# https://github.com/numpy/numpy/issues/29709 is fixed
11261131
return xp.quantile(
11271132
a, q_arr, axis=axis, method=method, keepdims=keepdims, weights=weights
11281133
)
1129-
# No delegation for dask: I couldn't make it work
1130-
basic_case = method == "linear" and weights is None and nan_policy == "propagate"
1131-
if (basic_case and is_jax_namespace(xp)) or is_cupy_namespace(xp):
1134+
# No delegation for dask: I couldn't make it work.
1135+
basic_case = method == "linear" and weights is None
1136+
jax_or_cupy = is_jax_namespace(xp) or is_cupy_namespace(xp)
1137+
if basic_case and nan_policy == "propagate" and jax_or_cupy:
11321138
return xp.quantile(a, q_arr, axis=axis, method=method, keepdims=keepdims)
11331139
if basic_case and is_torch_namespace(xp):
1134-
return xp.quantile(a, q_arr, dim=axis, interpolation=method, keepdim=keepdims)
1140+
quantile = xp.quantile if nan_policy == "propagate" else xp.nanquantile
1141+
return quantile(a, q_arr, dim=axis, interpolation=method, keepdim=keepdims)
11351142

1136-
# XXX: I'm not sure we want to support dask, it seems uterly slow...
11371143
# Otherwise call our implementation (will sort data)
1144+
# XXX: I'm not sure we want to support dask, it seems uterly slow...
11381145
return _quantile.quantile(
11391146
a,
11401147
q_arr,

src/array_api_extra/_lib/_quantile.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,10 @@ def quantile( # numpydoc ignore=PR01,RT01
7373
return res[0, ...] if q_scalar else res
7474

7575

76-
def _quantile( # numpydoc ignore=GL08
76+
def _quantile( # numpydoc ignore=PR01,RT01
7777
a: Array, q: Array, n: int, axis: int, method: str, xp: ModuleType
7878
) -> Array:
79+
"""Compute quantile by sorting `a`."""
7980
a = xp.sort(a, axis=axis, stable=False)
8081
mask_nan = xp.any(xp.isnan(a), axis=axis, keepdims=True)
8182
if xp.any(mask_nan):
@@ -114,7 +115,7 @@ def _quantile( # numpydoc ignore=GL08
114115
)
115116

116117

117-
def _weighted_quantile(
118+
def _weighted_quantile( # numpydoc ignore=PR01,RT01
118119
a: Array,
119120
q: Array,
120121
weights: Array,
@@ -126,7 +127,9 @@ def _weighted_quantile(
126127
device: Device,
127128
) -> Array:
128129
"""
129-
a is expected to be 1d or 2d.
130+
Compute weighted quantile using searchsorted on CDF.
131+
132+
`a` is expected to be 1d or 2d.
130133
"""
131134
a = xp.moveaxis(a, axis, -1)
132135
if weights.ndim > 1:
@@ -151,7 +154,7 @@ def _weighted_quantile(
151154
return xp.stack(res, axis=1)
152155

153156

154-
def _weighted_quantile_sorted_1d(
157+
def _weighted_quantile_sorted_1d( # numpydoc ignore=GL08
155158
x: Array,
156159
q: Array,
157160
w: Array,
@@ -165,10 +168,10 @@ def _weighted_quantile_sorted_1d(
165168
w = xp.where(xp.isnan(x), 0.0, w)
166169
elif xp.any(xp.isnan(x)):
167170
return xp.full(q.shape, xp.nan, dtype=x.dtype, device=device)
168-
cw = xp.cumulative_sum(w)
169-
t = cw[-1] * q
170-
i = xp.searchsorted(cw, t, side="left")
171-
j = xp.searchsorted(cw, t, side="right")
171+
cdf = xp.cumulative_sum(w)
172+
t = cdf[-1] * q
173+
i = xp.searchsorted(cdf, t, side="left")
174+
j = xp.searchsorted(cdf, t, side="right")
172175
i = xp.clip(i, 0, n - 1)
173176
j = xp.clip(j, 0, n - 1)
174177

0 commit comments

Comments
 (0)