Skip to content

Commit ce55335

Browse files
committed
Avoid materializing a sorted copy of `a`; reorder only the weights by `sorter` and map the searched indices back through `sorter`
1 parent 1b48267 commit ce55335

File tree

2 files changed

+28
-20
lines changed

2 files changed

+28
-20
lines changed

src/array_api_extra/_lib/_quantile.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -130,27 +130,28 @@ def _weighted_quantile( # numpydoc ignore=PR01,RT01
130130
sorter = xp.argsort(a, axis=-1, stable=False)
131131

132132
if a.ndim == 1:
133-
x = xp.take(a, sorter)
134-
w = xp.take(weights, sorter)
135-
return _weighted_quantile_sorted_1d(x, q, w, n, average, nan_policy, xp, device)
133+
return _weighted_quantile_sorted_1d(
134+
a, weights, sorter, q, n, average, nan_policy, xp, device
135+
)
136136

137137
(d,) = eager_shape(a, axis=0)
138138
res = []
139139
for idx in range(d):
140140
w = weights if weights.ndim == 1 else weights[idx, ...]
141-
w = xp.take(w, sorter[idx, ...])
142-
x = xp.take(a[idx, ...], sorter[idx, ...])
143141
res.append(
144-
_weighted_quantile_sorted_1d(x, q, w, n, average, nan_policy, xp, device)
142+
_weighted_quantile_sorted_1d(
143+
a[idx, ...], w, sorter[idx, ...], q, n, average, nan_policy, xp, device
144+
)
145145
)
146146

147147
return xp.stack(res, axis=1)
148148

149149

150150
def _weighted_quantile_sorted_1d( # numpydoc ignore=GL08
151151
x: Array,
152-
q: Array,
153152
w: Array,
153+
sorter: Array,
154+
q: Array,
154155
n: int,
155156
average: bool,
156157
nan_policy: str,
@@ -161,18 +162,25 @@ def _weighted_quantile_sorted_1d( # numpydoc ignore=GL08
161162
w = xp.where(xp.isnan(x), 0.0, w)
162163
elif xp.any(xp.isnan(x)):
163164
return xp.full(q.shape, xp.nan, dtype=x.dtype, device=device)
164-
cdf = xp.cumulative_sum(w)
165+
166+
cdf = xp.cumulative_sum(xp.take(w, sorter))
165167
t = cdf[-1] * q
168+
166169
i = xp.searchsorted(cdf, t, side="left")
167-
j = xp.searchsorted(cdf, t, side="right")
168170
i = xp.clip(i, 0, n - 1)
169-
j = xp.clip(j, 0, n - 1)
170-
171-
# Ignore leading `weights=0` observations when `q=0`
172-
# see https://github.com/scikit-learn/scikit-learn/pull/20528
173-
i = xp.where(q == 0.0, j, i)
174-
if average:
175-
# Ignore trailing `weights=0` observations when `q=1`
176-
j = xp.where(q == 1.0, i, j)
177-
return (xp.take(x, i) + xp.take(x, j)) / 2
171+
i = xp.take(sorter, i)
172+
173+
q0 = q == 0.0
174+
if average or xp.any(q0):
175+
j = xp.searchsorted(cdf, t, side="right")
176+
j = xp.clip(j, 0, n - 1)
177+
j = xp.take(sorter, j)
178+
# Ignore leading `weights=0` observations when `q=0`
179+
i = xp.where(q0, j, i)
180+
181+
if average:
182+
# Ignore trailing `weights=0` observations when `q=1`
183+
j = xp.where(q == 1.0, i, j)
184+
return (xp.take(x, i) + xp.take(x, j)) / 2
185+
178186
return xp.take(x, i)

tests/test_funcs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1639,12 +1639,12 @@ def test_against_median_min_max(
16391639
for method in ["inverted_cdf", "averaged_inverted_cdf"]:
16401640
np_min = np.nanmin if nan_policy == "omit" else np.min
16411641
expected = np_min(a_np_med)
1642-
actual = quantile(a, 0., method=method, nan_policy=nan_policy, weights=w)
1642+
actual = quantile(a, 0.0, method=method, nan_policy=nan_policy, weights=w)
16431643
xp_assert_close(actual, xp.asarray(expected))
16441644

16451645
np_max = np.nanmax if nan_policy == "omit" else np.max
16461646
expected = np_max(a_np_med)
1647-
actual = quantile(a, 1., method=method, nan_policy=nan_policy, weights=w)
1647+
actual = quantile(a, 1.0, method=method, nan_policy=nan_policy, weights=w)
16481648
xp_assert_close(actual, xp.asarray(expected))
16491649

16501650
@pytest.mark.parametrize("keepdims", [True, False])

0 commit comments

Comments
 (0)