Skip to content

Commit 5089588

Browse files
committed
ENH: add "repro_snippets" to test_indexing_functions.py
1 parent cf3bf26 commit 5089588

File tree

1 file changed

+64
-53
lines changed

1 file changed

+64
-53
lines changed

array_api_tests/test_indexing_functions.py

Lines changed: 64 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -35,39 +35,43 @@ def test_take(x, data):
3535
indices = xp.asarray(_indices, dtype=dh.default_int)
3636
note(f"{indices=}")
3737

38-
out = xp.take(x, indices, **kw)
39-
40-
ph.assert_dtype("take", in_dtype=x.dtype, out_dtype=out.dtype)
41-
ph.assert_shape(
42-
"take",
43-
out_shape=out.shape,
44-
expected=x.shape[:n_axis] + (len(_indices),) + x.shape[n_axis + 1:],
45-
kw=dict(
46-
x=x,
47-
indices=indices,
48-
axis=axis,
49-
),
50-
)
51-
out_indices = sh.ndindex(out.shape)
52-
axis_indices = list(sh.axis_ndindex(x.shape, n_axis))
53-
for axis_idx in axis_indices:
54-
f_axis_idx = sh.fmt_idx("x", axis_idx)
55-
for i in _indices:
56-
f_take_idx = sh.fmt_idx(f_axis_idx, i)
57-
indexed_x = x[axis_idx][i, ...]
58-
for at_idx in sh.ndindex(indexed_x.shape):
59-
out_idx = next(out_indices)
60-
ph.assert_0d_equals(
61-
"take",
62-
x_repr=sh.fmt_idx(f_take_idx, at_idx),
63-
x_val=indexed_x[at_idx],
64-
out_repr=sh.fmt_idx("out", out_idx),
65-
out_val=out[out_idx],
66-
)
67-
# sanity check
68-
with pytest.raises(StopIteration):
69-
next(out_indices)
38+
repro_snippet = ph.format_snippet(f"xp.take({x!r}, {indices!r}, **kw) with {kw = }")
39+
try:
40+
out = xp.take(x, indices, **kw)
7041

42+
ph.assert_dtype("take", in_dtype=x.dtype, out_dtype=out.dtype)
43+
ph.assert_shape(
44+
"take",
45+
out_shape=out.shape,
46+
expected=x.shape[:n_axis] + (len(_indices),) + x.shape[n_axis + 1:],
47+
kw=dict(
48+
x=x,
49+
indices=indices,
50+
axis=axis,
51+
),
52+
)
53+
out_indices = sh.ndindex(out.shape)
54+
axis_indices = list(sh.axis_ndindex(x.shape, n_axis))
55+
for axis_idx in axis_indices:
56+
f_axis_idx = sh.fmt_idx("x", axis_idx)
57+
for i in _indices:
58+
f_take_idx = sh.fmt_idx(f_axis_idx, i)
59+
indexed_x = x[axis_idx][i, ...]
60+
for at_idx in sh.ndindex(indexed_x.shape):
61+
out_idx = next(out_indices)
62+
ph.assert_0d_equals(
63+
"take",
64+
x_repr=sh.fmt_idx(f_take_idx, at_idx),
65+
x_val=indexed_x[at_idx],
66+
out_repr=sh.fmt_idx("out", out_idx),
67+
out_val=out[out_idx],
68+
)
69+
# sanity check
70+
with pytest.raises(StopIteration):
71+
next(out_indices)
72+
except Exception as exc:
73+
exc.add_note(repro_snippet)
74+
raise
7175

7276
@pytest.mark.unvectorized
7377
@pytest.mark.min_version("2024.12")
@@ -103,26 +107,33 @@ def test_take_along_axis(x, data):
103107
)
104108
note(f"{indices=} {idx_shape=}")
105109

106-
out = xp.take_along_axis(x, indices, **axis_kw)
107-
108-
ph.assert_dtype("take_along_axis", in_dtype=x.dtype, out_dtype=out.dtype)
109-
ph.assert_shape(
110-
"take_along_axis",
111-
out_shape=out.shape,
112-
expected=x.shape[:n_axis] + (new_len,) + x.shape[n_axis+1:],
113-
kw=dict(
114-
x=x,
115-
indices=indices,
116-
axis=axis,
117-
),
110+
repro_snippet = ph.format_snippet(
111+
f"xp.take_along_axis({x!r}, {indices!r}, **axis_kw) with {axis_kw = }"
118112
)
113+
try:
114+
out = xp.take_along_axis(x, indices, **axis_kw)
115+
116+
ph.assert_dtype("take_along_axis", in_dtype=x.dtype, out_dtype=out.dtype)
117+
ph.assert_shape(
118+
"take_along_axis",
119+
out_shape=out.shape,
120+
expected=x.shape[:n_axis] + (new_len,) + x.shape[n_axis+1:],
121+
kw=dict(
122+
x=x,
123+
indices=indices,
124+
axis=axis,
125+
),
126+
)
119127

120-
# value test: notation is from `np.take_along_axis` docstring
121-
Ni, Nk = x.shape[:n_axis], x.shape[n_axis+1:]
122-
for ii in sh.ndindex(Ni):
123-
for kk in sh.ndindex(Nk):
124-
a_1d = x[ii + (slice(None),) + kk]
125-
i_1d = indices[ii + (slice(None),) + kk]
126-
o_1d = out[ii + (slice(None),) + kk]
127-
for j in range(new_len):
128-
assert o_1d[j] == a_1d[i_1d[j]], f'{ii=}, {kk=}, {j=}'
128+
# value test: notation is from `np.take_along_axis` docstring
129+
Ni, Nk = x.shape[:n_axis], x.shape[n_axis+1:]
130+
for ii in sh.ndindex(Ni):
131+
for kk in sh.ndindex(Nk):
132+
a_1d = x[ii + (slice(None),) + kk]
133+
i_1d = indices[ii + (slice(None),) + kk]
134+
o_1d = out[ii + (slice(None),) + kk]
135+
for j in range(new_len):
136+
assert o_1d[j] == a_1d[i_1d[j]], f'{ii=}, {kk=}, {j=}'
137+
except Exception as exc:
138+
exc.add_note(repro_snippet)
139+
raise

0 commit comments

Comments
 (0)