@@ -130,27 +130,28 @@ def _weighted_quantile( # numpydoc ignore=PR01,RT01
130130 sorter = xp .argsort (a , axis = - 1 , stable = False )
131131
132132 if a .ndim == 1 :
133- x = xp . take ( a , sorter )
134- w = xp . take ( weights , sorter )
135- return _weighted_quantile_sorted_1d ( x , q , w , n , average , nan_policy , xp , device )
133+ return _weighted_quantile_sorted_1d (
134+ a , weights , sorter , q , n , average , nan_policy , xp , device
135+ )
136136
137137 (d ,) = eager_shape (a , axis = 0 )
138138 res = []
139139 for idx in range (d ):
140140 w = weights if weights .ndim == 1 else weights [idx , ...]
141- w = xp .take (w , sorter [idx , ...])
142- x = xp .take (a [idx , ...], sorter [idx , ...])
143141 res .append (
144- _weighted_quantile_sorted_1d (x , q , w , n , average , nan_policy , xp , device )
142+ _weighted_quantile_sorted_1d (
143+ a [idx , ...], w , sorter [idx , ...], q , n , average , nan_policy , xp , device
144+ )
145145 )
146146
147147 return xp .stack (res , axis = 1 )
148148
149149
150150def _weighted_quantile_sorted_1d ( # numpydoc ignore=GL08
151151 x : Array ,
152- q : Array ,
153152 w : Array ,
153+ sorter : Array ,
154+ q : Array ,
154155 n : int ,
155156 average : bool ,
156157 nan_policy : str ,
@@ -161,18 +162,25 @@ def _weighted_quantile_sorted_1d( # numpydoc ignore=GL08
161162 w = xp .where (xp .isnan (x ), 0.0 , w )
162163 elif xp .any (xp .isnan (x )):
163164 return xp .full (q .shape , xp .nan , dtype = x .dtype , device = device )
164- cdf = xp .cumulative_sum (w )
165+
166+ cdf = xp .cumulative_sum (xp .take (w , sorter ))
165167 t = cdf [- 1 ] * q
168+
166169 i = xp .searchsorted (cdf , t , side = "left" )
167- j = xp .searchsorted (cdf , t , side = "right" )
168170 i = xp .clip (i , 0 , n - 1 )
169- j = xp .clip (j , 0 , n - 1 )
170-
171- # Ignore leading `weights=0` observations when `q=0`
172- # see https://github.com/scikit-learn/scikit-learn/pull/20528
173- i = xp .where (q == 0.0 , j , i )
174- if average :
175- # Ignore trailing `weights=0` observations when `q=1`
176- j = xp .where (q == 1.0 , i , j )
177- return (xp .take (x , i ) + xp .take (x , j )) / 2
171+ i = xp .take (sorter , i )
172+
173+ q0 = q == 0.0
174+ if average or xp .any (q0 ):
175+ j = xp .searchsorted (cdf , t , side = "right" )
176+ j = xp .clip (j , 0 , n - 1 )
177+ j = xp .take (sorter , j )
178+ # Ignore leading `weights=0` observations when `q=0`
179+ i = xp .where (q0 , j , i )
180+
181+ if average :
182+ # Ignore trailing `weights=0` observations when `q=1`
183+ j = xp .where (q == 1.0 , i , j )
184+ return (xp .take (x , i ) + xp .take (x , j )) / 2
185+
178186 return xp .take (x , i )
0 commit comments