Skip to content

Commit ee3d9b7

Browse files
authored
Merge pull request #393 from ev-br/searchsorted_todos
expand testing of `searchsorted`
2 parents 0f04191 + 5bd524b commit ee3d9b7

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

array_api_tests/test_searching_functions.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -243,37 +243,51 @@ def test_where(shapes, dtypes, data):
243243
@pytest.mark.min_version("2023.12")
244244
@given(data=st.data())
245245
def test_searchsorted(data):
246-
# TODO: test side="right"
247246
# TODO: Allow different dtypes for x1 and x2
247+
x1_dtype = data.draw(st.sampled_from(dh.real_dtypes))
248248
_x1 = data.draw(
249-
st.lists(xps.from_dtype(dh.default_float), min_size=1, unique=True),
249+
st.lists(
250+
xps.from_dtype(x1_dtype, allow_nan=False, allow_infinity=False),
251+
min_size=1,
252+
unique=True
253+
),
250254
label="_x1",
251255
)
252-
x1 = xp.asarray(_x1, dtype=dh.default_float)
256+
x1 = xp.asarray(_x1, dtype=x1_dtype)
253257
if data.draw(st.booleans(), label="use sorter?"):
254258
sorter = xp.argsort(x1)
255259
else:
256260
sorter = None
257261
x1 = xp.sort(x1)
258262
note(f"{x1=}")
263+
259264
x2 = data.draw(
260265
st.lists(st.sampled_from(_x1), unique=True, min_size=1).map(
261-
lambda o: xp.asarray(o, dtype=dh.default_float)
266+
lambda o: xp.asarray(o, dtype=x1_dtype)
262267
),
263268
label="x2",
264269
)
270+
# make x2.ndim > 1, if it makes sense
271+
factors = hh._factorize(x2.shape[0])
272+
if len(factors) > 1:
273+
x2 = xp.reshape(x2, tuple(factors))
265274

266-
repro_snippet = ph.format_snippet(f"xp.searchsorted({x1!r}, {x2!r}, sorter={sorter!r})")
275+
kw = data.draw(hh.kwargs(side=st.sampled_from(["left", "right"])))
276+
277+
repro_snippet = ph.format_snippet(
278+
f"xp.searchsorted({x1!r}, {x2!r}, sorter={sorter!r}, **kw) with {kw=}"
279+
)
267280
try:
268-
out = xp.searchsorted(x1, x2, sorter=sorter)
281+
out = xp.searchsorted(x1, x2, sorter=sorter, **kw)
269282

270283
ph.assert_dtype(
271284
"searchsorted",
272285
in_dtype=[x1.dtype, x2.dtype],
273286
out_dtype=out.dtype,
274287
expected=xp.__array_namespace_info__().default_dtypes()["indexing"],
275288
)
276-
# TODO: shapes and values testing
289+
# TODO: values testing
290+
ph.assert_shape("searchsorted", out_shape=out.shape, expected=x2.shape)
277291
except Exception as exc:
278292
ph.add_note(exc, repro_snippet)
279293
raise

0 commit comments

Comments
 (0)