@@ -243,37 +243,51 @@ def test_where(shapes, dtypes, data):
243243@pytest .mark .min_version ("2023.12" )
244244@given (data = st .data ())
245245def test_searchsorted (data ):
246- # TODO: test side="right"
247246 # TODO: Allow different dtypes for x1 and x2
247+ x1_dtype = data .draw (st .sampled_from (dh .real_dtypes ))
248248 _x1 = data .draw (
249- st .lists (xps .from_dtype (dh .default_float ), min_size = 1 , unique = True ),
249+ st .lists (
250+ xps .from_dtype (x1_dtype , allow_nan = False , allow_infinity = False ),
251+ min_size = 1 ,
252+ unique = True
253+ ),
250254 label = "_x1" ,
251255 )
252- x1 = xp .asarray (_x1 , dtype = dh . default_float )
256+ x1 = xp .asarray (_x1 , dtype = x1_dtype )
253257 if data .draw (st .booleans (), label = "use sorter?" ):
254258 sorter = xp .argsort (x1 )
255259 else :
256260 sorter = None
257261 x1 = xp .sort (x1 )
258262 note (f"{ x1 = } " )
263+
259264 x2 = data .draw (
260265 st .lists (st .sampled_from (_x1 ), unique = True , min_size = 1 ).map (
261- lambda o : xp .asarray (o , dtype = dh . default_float )
266+ lambda o : xp .asarray (o , dtype = x1_dtype )
262267 ),
263268 label = "x2" ,
264269 )
270+ # make x2.ndim > 1, if it makes sense
271+ factors = hh ._factorize (x2 .shape [0 ])
272+ if len (factors ) > 1 :
273+ x2 = xp .reshape (x2 , tuple (factors ))
265274
266- repro_snippet = ph .format_snippet (f"xp.searchsorted({ x1 !r} , { x2 !r} , sorter={ sorter !r} )" )
275+ kw = data .draw (hh .kwargs (side = st .sampled_from (["left" , "right" ])))
276+
277+ repro_snippet = ph .format_snippet (
278+ f"xp.searchsorted({ x1 !r} , { x2 !r} , sorter={ sorter !r} , **kw) with { kw = } "
279+ )
267280 try :
268- out = xp .searchsorted (x1 , x2 , sorter = sorter )
281+ out = xp .searchsorted (x1 , x2 , sorter = sorter , ** kw )
269282
270283 ph .assert_dtype (
271284 "searchsorted" ,
272285 in_dtype = [x1 .dtype , x2 .dtype ],
273286 out_dtype = out .dtype ,
274287 expected = xp .__array_namespace_info__ ().default_dtypes ()["indexing" ],
275288 )
276- # TODO: shapes and values testing
289+ # TODO: values testing
290+ ph .assert_shape ("searchsorted" , out_shape = out .shape , expected = x2 .shape )
277291 except Exception as exc :
278292 ph .add_note (exc , repro_snippet )
279293 raise
0 commit comments