Skip to content

Commit 0f04191

Browse files
authored
Merge pull request #396 from ev-br/repro_snippets_3
More "repro snippets"
2 parents b4038ce + 5acac26 commit 0f04191

File tree

4 files changed

+240
-180
lines changed

4 files changed

+240
-180
lines changed

.git-blame-ignore-revs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,8 @@ e807ffe526c7330691e8f39d31347dc2b3106de3
77
bd42e84d2e5aae26ade8d70384e74effd1de89cb
88
f7e822883b7e24b5aa540e2413759a85128b42ef
99
a37f348ba27b6818e92fda8aee2406c653c671ea
10+
# gh-396
11+
ec5a3b4e185c262b0a5f5b1631b84a09f766d80e
12+
9058908b58ce627467ac34e768098a25f5863d31
13+
c80e1823c2e738381ca02f27cea1e2b89dde0ac5
14+

array_api_tests/test_array_object.py

Lines changed: 99 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ def test_getitem(shape, dtype, data):
8686
key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key")
8787

8888
repro_snippet = ph.format_snippet(f"{x!r}[{key!r}]")
89-
9089
try:
9190
out = x[key]
9291

@@ -109,6 +108,7 @@ def test_getitem(shape, dtype, data):
109108
ph.add_note(exc, repro_snippet)
110109
raise
111110

111+
112112
@pytest.mark.unvectorized
113113
@given(
114114
shape=hh.shapes(),
@@ -133,28 +133,34 @@ def test_setitem(shape, dtypes, data):
133133
value = data.draw(value_strat, label="value")
134134

135135
res = xp.asarray(x, copy=True)
136-
res[key] = value
137-
138-
ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype")
139-
ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.shape")
140-
f_res = sh.fmt_idx("x", key)
141-
if isinstance(value, get_args(Scalar)):
142-
msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]"
143-
if cmath.isnan(value):
144-
assert xp.isnan(res[key]), msg
136+
137+
repro_snippet = ph.format_snippet(f"{res!r}[{key!r}] = {value!r}")
138+
try:
139+
res[key] = value
140+
141+
ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype")
142+
ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.shape")
143+
f_res = sh.fmt_idx("x", key)
144+
if isinstance(value, get_args(Scalar)):
145+
msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]"
146+
if cmath.isnan(value):
147+
assert xp.isnan(res[key]), msg
148+
else:
149+
assert res[key] == value, msg
145150
else:
146-
assert res[key] == value, msg
147-
else:
148-
ph.assert_array_elements("__setitem__", out=res[key], expected=value, out_repr=f_res)
149-
unaffected_indices = set(sh.ndindex(res.shape)) - set(product(*axes_indices))
150-
for idx in unaffected_indices:
151-
ph.assert_0d_equals(
152-
"__setitem__",
153-
x_repr=f"old {f_res}",
154-
x_val=x[idx],
155-
out_repr=f"modified {f_res}",
156-
out_val=res[idx],
157-
)
151+
ph.assert_array_elements("__setitem__", out=res[key], expected=value, out_repr=f_res)
152+
unaffected_indices = set(sh.ndindex(res.shape)) - set(product(*axes_indices))
153+
for idx in unaffected_indices:
154+
ph.assert_0d_equals(
155+
"__setitem__",
156+
x_repr=f"old {f_res}",
157+
x_val=x[idx],
158+
out_repr=f"modified {f_res}",
159+
out_val=res[idx],
160+
)
161+
except Exception as exc:
162+
ph.add_note(exc, repro_snippet)
163+
raise
158164

159165

160166
@pytest.mark.unvectorized
@@ -178,29 +184,34 @@ def test_getitem_masking(shape, data):
178184
x[key]
179185
return
180186

181-
out = x[key]
187+
repro_snippet = ph.format_snippet(f"out = {x!r}[{key!r}]")
188+
try:
189+
out = x[key]
182190

183-
ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype)
184-
if key.ndim == 0:
185-
expected_shape = (1,) if key else (0,)
186-
expected_shape += x.shape
187-
else:
188-
size = int(xp.sum(xp.astype(key, xp.uint8)))
189-
expected_shape = (size,) + x.shape[key.ndim :]
190-
ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape)
191-
if not any(s == 0 for s in key.shape):
192-
assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios
193-
out_indices = sh.ndindex(out.shape)
194-
for x_idx in sh.ndindex(x.shape):
195-
if key[x_idx]:
196-
out_idx = next(out_indices)
197-
ph.assert_0d_equals(
198-
"__getitem__",
199-
x_repr=f"x[{x_idx}]",
200-
x_val=x[x_idx],
201-
out_repr=f"out[{out_idx}]",
202-
out_val=out[out_idx],
203-
)
191+
ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype)
192+
if key.ndim == 0:
193+
expected_shape = (1,) if key else (0,)
194+
expected_shape += x.shape
195+
else:
196+
size = int(xp.sum(xp.astype(key, xp.uint8)))
197+
expected_shape = (size,) + x.shape[key.ndim :]
198+
ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape)
199+
if not any(s == 0 for s in key.shape):
200+
assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios
201+
out_indices = sh.ndindex(out.shape)
202+
for x_idx in sh.ndindex(x.shape):
203+
if key[x_idx]:
204+
out_idx = next(out_indices)
205+
ph.assert_0d_equals(
206+
"__getitem__",
207+
x_repr=f"x[{x_idx}]",
208+
x_val=x[x_idx],
209+
out_repr=f"out[{out_idx}]",
210+
out_val=out[out_idx],
211+
)
212+
except Exception as exc:
213+
ph.add_note(exc, repro_snippet)
214+
raise
204215

205216

206217
@pytest.mark.unvectorized
@@ -213,38 +224,44 @@ def test_setitem_masking(shape, data):
213224
)
214225

215226
res = xp.asarray(x, copy=True)
216-
res[key] = value
217-
218-
ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype")
219-
ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.dtype")
220-
scalar_type = dh.get_scalar_type(x.dtype)
221-
for idx in sh.ndindex(x.shape):
222-
if key[idx]:
223-
if isinstance(value, get_args(Scalar)):
224-
ph.assert_scalar_equals(
225-
"__setitem__",
226-
type_=scalar_type,
227-
idx=idx,
228-
out=scalar_type(res[idx]),
229-
expected=value,
230-
repr_name="modified x",
231-
)
227+
228+
repro_snippet = ph.format_snippet(f"{res}[{key!r}] = {value!r}")
229+
try:
230+
res[key] = value
231+
232+
ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype")
233+
ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.dtype")
234+
scalar_type = dh.get_scalar_type(x.dtype)
235+
for idx in sh.ndindex(x.shape):
236+
if key[idx]:
237+
if isinstance(value, get_args(Scalar)):
238+
ph.assert_scalar_equals(
239+
"__setitem__",
240+
type_=scalar_type,
241+
idx=idx,
242+
out=scalar_type(res[idx]),
243+
expected=value,
244+
repr_name="modified x",
245+
)
246+
else:
247+
ph.assert_0d_equals(
248+
"__setitem__",
249+
x_repr="value",
250+
x_val=value,
251+
out_repr=f"modified x[{idx}]",
252+
out_val=res[idx]
253+
)
232254
else:
233255
ph.assert_0d_equals(
234256
"__setitem__",
235-
x_repr="value",
236-
x_val=value,
257+
x_repr=f"old x[{idx}]",
258+
x_val=x[idx],
237259
out_repr=f"modified x[{idx}]",
238260
out_val=res[idx]
239261
)
240-
else:
241-
ph.assert_0d_equals(
242-
"__setitem__",
243-
x_repr=f"old x[{idx}]",
244-
x_val=x[idx],
245-
out_repr=f"modified x[{idx}]",
246-
out_val=res[idx]
247-
)
262+
except Exception as exc:
263+
ph.add_note(exc, repro_snippet)
264+
raise
248265

249266

250267
# ### Fancy indexing ###
@@ -309,15 +326,20 @@ def _test_getitem_arrays_and_ints(shape, data, idx_max_dims):
309326
key.append(data.draw(st.integers(-shape[i], shape[i]-1)))
310327

311328
key = tuple(key)
312-
out = x[key]
329+
repro_snippet = ph.format_snippet(f"out = {x!r}[{key!r}]")
330+
try:
331+
out = x[key]
313332

314-
arrays = [xp.asarray(k) for k in key]
315-
bcast_shape = sh.broadcast_shapes(*[arr.shape for arr in arrays])
316-
bcast_key = [xp.broadcast_to(arr, bcast_shape) for arr in arrays]
333+
arrays = [xp.asarray(k) for k in key]
334+
bcast_shape = sh.broadcast_shapes(*[arr.shape for arr in arrays])
335+
bcast_key = [xp.broadcast_to(arr, bcast_shape) for arr in arrays]
317336

318-
for idx in sh.ndindex(bcast_shape):
319-
tpl = tuple(k[idx] for k in bcast_key)
320-
assert out[idx] == x[tpl], f"failing at {idx = } w/ {key = }"
337+
for idx in sh.ndindex(bcast_shape):
338+
tpl = tuple(k[idx] for k in bcast_key)
339+
assert out[idx] == x[tpl], f"failing at {idx = } w/ {key = }"
340+
except Exception as exc:
341+
ph.add_note(exc, repro_snippet)
342+
raise
321343

322344

323345
def make_scalar_casting_param(

array_api_tests/test_indexing_functions.py

Lines changed: 64 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -35,39 +35,43 @@ def test_take(x, data):
3535
indices = xp.asarray(_indices, dtype=dh.default_int)
3636
note(f"{indices=}")
3737

38-
out = xp.take(x, indices, **kw)
39-
40-
ph.assert_dtype("take", in_dtype=x.dtype, out_dtype=out.dtype)
41-
ph.assert_shape(
42-
"take",
43-
out_shape=out.shape,
44-
expected=x.shape[:n_axis] + (len(_indices),) + x.shape[n_axis + 1:],
45-
kw=dict(
46-
x=x,
47-
indices=indices,
48-
axis=axis,
49-
),
50-
)
51-
out_indices = sh.ndindex(out.shape)
52-
axis_indices = list(sh.axis_ndindex(x.shape, n_axis))
53-
for axis_idx in axis_indices:
54-
f_axis_idx = sh.fmt_idx("x", axis_idx)
55-
for i in _indices:
56-
f_take_idx = sh.fmt_idx(f_axis_idx, i)
57-
indexed_x = x[axis_idx][i, ...]
58-
for at_idx in sh.ndindex(indexed_x.shape):
59-
out_idx = next(out_indices)
60-
ph.assert_0d_equals(
61-
"take",
62-
x_repr=sh.fmt_idx(f_take_idx, at_idx),
63-
x_val=indexed_x[at_idx],
64-
out_repr=sh.fmt_idx("out", out_idx),
65-
out_val=out[out_idx],
66-
)
67-
# sanity check
68-
with pytest.raises(StopIteration):
69-
next(out_indices)
38+
repro_snippet = ph.format_snippet(f"xp.take({x!r}, {indices!r}, **kw) with {kw = }")
39+
try:
40+
out = xp.take(x, indices, **kw)
7041

42+
ph.assert_dtype("take", in_dtype=x.dtype, out_dtype=out.dtype)
43+
ph.assert_shape(
44+
"take",
45+
out_shape=out.shape,
46+
expected=x.shape[:n_axis] + (len(_indices),) + x.shape[n_axis + 1:],
47+
kw=dict(
48+
x=x,
49+
indices=indices,
50+
axis=axis,
51+
),
52+
)
53+
out_indices = sh.ndindex(out.shape)
54+
axis_indices = list(sh.axis_ndindex(x.shape, n_axis))
55+
for axis_idx in axis_indices:
56+
f_axis_idx = sh.fmt_idx("x", axis_idx)
57+
for i in _indices:
58+
f_take_idx = sh.fmt_idx(f_axis_idx, i)
59+
indexed_x = x[axis_idx][i, ...]
60+
for at_idx in sh.ndindex(indexed_x.shape):
61+
out_idx = next(out_indices)
62+
ph.assert_0d_equals(
63+
"take",
64+
x_repr=sh.fmt_idx(f_take_idx, at_idx),
65+
x_val=indexed_x[at_idx],
66+
out_repr=sh.fmt_idx("out", out_idx),
67+
out_val=out[out_idx],
68+
)
69+
# sanity check
70+
with pytest.raises(StopIteration):
71+
next(out_indices)
72+
except Exception as exc:
73+
ph.add_note(exc, repro_snippet)
74+
raise
7175

7276
@pytest.mark.unvectorized
7377
@pytest.mark.min_version("2024.12")
@@ -103,26 +107,33 @@ def test_take_along_axis(x, data):
103107
)
104108
note(f"{indices=} {idx_shape=}")
105109

106-
out = xp.take_along_axis(x, indices, **axis_kw)
107-
108-
ph.assert_dtype("take_along_axis", in_dtype=x.dtype, out_dtype=out.dtype)
109-
ph.assert_shape(
110-
"take_along_axis",
111-
out_shape=out.shape,
112-
expected=x.shape[:n_axis] + (new_len,) + x.shape[n_axis+1:],
113-
kw=dict(
114-
x=x,
115-
indices=indices,
116-
axis=axis,
117-
),
110+
repro_snippet = ph.format_snippet(
111+
f"xp.take_along_axis({x!r}, {indices!r}, **axis_kw) with {axis_kw = }"
118112
)
113+
try:
114+
out = xp.take_along_axis(x, indices, **axis_kw)
115+
116+
ph.assert_dtype("take_along_axis", in_dtype=x.dtype, out_dtype=out.dtype)
117+
ph.assert_shape(
118+
"take_along_axis",
119+
out_shape=out.shape,
120+
expected=x.shape[:n_axis] + (new_len,) + x.shape[n_axis+1:],
121+
kw=dict(
122+
x=x,
123+
indices=indices,
124+
axis=axis,
125+
),
126+
)
119127

120-
# value test: notation is from `np.take_along_axis` docstring
121-
Ni, Nk = x.shape[:n_axis], x.shape[n_axis+1:]
122-
for ii in sh.ndindex(Ni):
123-
for kk in sh.ndindex(Nk):
124-
a_1d = x[ii + (slice(None),) + kk]
125-
i_1d = indices[ii + (slice(None),) + kk]
126-
o_1d = out[ii + (slice(None),) + kk]
127-
for j in range(new_len):
128-
assert o_1d[j] == a_1d[i_1d[j]], f'{ii=}, {kk=}, {j=}'
128+
# value test: notation is from `np.take_along_axis` docstring
129+
Ni, Nk = x.shape[:n_axis], x.shape[n_axis+1:]
130+
for ii in sh.ndindex(Ni):
131+
for kk in sh.ndindex(Nk):
132+
a_1d = x[ii + (slice(None),) + kk]
133+
i_1d = indices[ii + (slice(None),) + kk]
134+
o_1d = out[ii + (slice(None),) + kk]
135+
for j in range(new_len):
136+
assert o_1d[j] == a_1d[i_1d[j]], f'{ii=}, {kk=}, {j=}'
137+
except Exception as exc:
138+
ph.add_note(exc, repro_snippet)
139+
raise

0 commit comments

Comments
 (0)