Skip to content

Commit cf3bf26

Browse files
committed
ENH: add "repro_snippets" to test_utility_functions.py
1 parent b4038ce commit cf3bf26

File tree

1 file changed

+72
-50
lines changed

1 file changed

+72
-50
lines changed

array_api_tests/test_utility_functions.py

Lines changed: 72 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,28 @@ def test_all(x, data):
1818
kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw")
1919
keepdims = kw.get("keepdims", False)
2020

21-
out = xp.all(x, **kw)
22-
23-
ph.assert_dtype("all", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool)
24-
_axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
25-
ph.assert_keepdimable_shape(
26-
"all", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw
27-
)
28-
scalar_type = dh.get_scalar_type(x.dtype)
29-
for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
30-
result = bool(out[out_idx])
31-
elements = []
32-
for idx in indices:
33-
s = scalar_type(x[idx])
34-
elements.append(s)
35-
expected = all(elements)
36-
ph.assert_scalar_equals("all", type_=scalar_type, idx=out_idx,
37-
out=result, expected=expected, kw=kw)
21+
repro_snippet = ph.format_snippet(f"xp.all({x!r}, **kw) with {kw = }")
22+
try:
23+
out = xp.all(x, **kw)
24+
25+
ph.assert_dtype("all", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool)
26+
_axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
27+
ph.assert_keepdimable_shape(
28+
"all", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw
29+
)
30+
scalar_type = dh.get_scalar_type(x.dtype)
31+
for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
32+
result = bool(out[out_idx])
33+
elements = []
34+
for idx in indices:
35+
s = scalar_type(x[idx])
36+
elements.append(s)
37+
expected = all(elements)
38+
ph.assert_scalar_equals("all", type_=scalar_type, idx=out_idx,
39+
out=result, expected=expected, kw=kw)
40+
except Exception as exc:
41+
exc.add_note(repro_snippet)
42+
raise
3843

3944

4045
@pytest.mark.unvectorized
@@ -46,23 +51,28 @@ def test_any(x, data):
4651
kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw")
4752
keepdims = kw.get("keepdims", False)
4853

49-
out = xp.any(x, **kw)
50-
51-
ph.assert_dtype("any", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool)
52-
_axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
53-
ph.assert_keepdimable_shape(
54-
"any", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw,
55-
)
56-
scalar_type = dh.get_scalar_type(x.dtype)
57-
for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
58-
result = bool(out[out_idx])
59-
elements = []
60-
for idx in indices:
61-
s = scalar_type(x[idx])
62-
elements.append(s)
63-
expected = any(elements)
64-
ph.assert_scalar_equals("any", type_=scalar_type, idx=out_idx,
65-
out=result, expected=expected, kw=kw)
54+
repro_snippet = ph.format_snippet(f"xp.any({x!r}, **kw) with {kw = }")
55+
try:
56+
out = xp.any(x, **kw)
57+
58+
ph.assert_dtype("any", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool)
59+
_axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
60+
ph.assert_keepdimable_shape(
61+
"any", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw,
62+
)
63+
scalar_type = dh.get_scalar_type(x.dtype)
64+
for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
65+
result = bool(out[out_idx])
66+
elements = []
67+
for idx in indices:
68+
s = scalar_type(x[idx])
69+
elements.append(s)
70+
expected = any(elements)
71+
ph.assert_scalar_equals("any", type_=scalar_type, idx=out_idx,
72+
out=result, expected=expected, kw=kw)
73+
except Exception as exc:
74+
exc.add_note(repro_snippet)
75+
raise
6676

6777

6878
@pytest.mark.unvectorized
@@ -85,19 +95,24 @@ def test_diff(x, data):
8595

8696
n = data.draw(st.integers(1, min(x.shape[n_axis], 3)))
8797

88-
out = xp.diff(x, **axis_kw, n=n)
98+
repro_snippet = ph.format_snippet(f"xp.diff({x!r}, **axis_kw, n={n!r}) with {axis_kw = }")
99+
try:
100+
out = xp.diff(x, **axis_kw, n=n)
89101

90-
expected_shape = list(x.shape)
91-
expected_shape[n_axis] -= n
102+
expected_shape = list(x.shape)
103+
expected_shape[n_axis] -= n
92104

93-
assert out.shape == tuple(expected_shape)
105+
assert out.shape == tuple(expected_shape)
94106

95-
# value test
96-
if n == 1:
97-
for idx in sh.ndindex(out.shape):
98-
l = list(idx)
99-
l[n_axis] += 1
100-
assert out[idx] == x[tuple(l)] - x[idx], f"diff failed with {idx = }"
107+
# value test
108+
if n == 1:
109+
for idx in sh.ndindex(out.shape):
110+
l = list(idx)
111+
l[n_axis] += 1
112+
assert out[idx] == x[tuple(l)] - x[idx], f"diff failed with {idx = }"
113+
except Exception as exc:
114+
exc.add_note(repro_snippet)
115+
raise
101116

102117

103118
@pytest.mark.min_version("2024.12")
@@ -130,12 +145,19 @@ def test_diff_append_prepend(x, data):
130145
prepend_shape[n_axis] = prepend_axis_len
131146
prepend = data.draw(hh.arrays(dtype=x.dtype, shape=tuple(prepend_shape)), label="prepend")
132147

133-
out = xp.diff(x, **axis_kw, n=n, append=append, prepend=prepend)
148+
repro_snippet = ph.format_snippet(
149+
f"xp.diff({x!r}, **axis_kw, n={n!r}, append={append!r}, prepend={prepend!r}) with {axis_kw = }"
150+
)
151+
try:
152+
out = xp.diff(x, **axis_kw, n=n, append=append, prepend=prepend)
134153

135-
in_1 = xp.concat((prepend, x, append), **axis_kw)
136-
out_1 = xp.diff(in_1, **axis_kw, n=n)
154+
in_1 = xp.concat((prepend, x, append), **axis_kw)
155+
out_1 = xp.diff(in_1, **axis_kw, n=n)
137156

138-
assert out.shape == out_1.shape
139-
for idx in sh.ndindex(out.shape):
140-
assert out[idx] == out_1[idx], f"{idx = }"
157+
assert out.shape == out_1.shape
158+
for idx in sh.ndindex(out.shape):
159+
assert out[idx] == out_1[idx], f"{idx = }"
160+
except Exception as exc:
161+
exc.add_note(repro_snippet)
162+
raise
141163

0 commit comments

Comments
 (0)