Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .git-blame-ignore-revs
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,8 @@ e807ffe526c7330691e8f39d31347dc2b3106de3
bd42e84d2e5aae26ade8d70384e74effd1de89cb
f7e822883b7e24b5aa540e2413759a85128b42ef
a37f348ba27b6818e92fda8aee2406c653c671ea
# gh-396
ec5a3b4e185c262b0a5f5b1631b84a09f766d80e
9058908b58ce627467ac34e768098a25f5863d31
c80e1823c2e738381ca02f27cea1e2b89dde0ac5

176 changes: 99 additions & 77 deletions array_api_tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ def test_getitem(shape, dtype, data):
key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key")

repro_snippet = ph.format_snippet(f"{x!r}[{key!r}]")

try:
out = x[key]

Expand All @@ -109,6 +108,7 @@ def test_getitem(shape, dtype, data):
ph.add_note(exc, repro_snippet)
raise


@pytest.mark.unvectorized
@given(
shape=hh.shapes(),
Expand All @@ -133,28 +133,34 @@ def test_setitem(shape, dtypes, data):
value = data.draw(value_strat, label="value")

res = xp.asarray(x, copy=True)
res[key] = value

ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype")
ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.shape")
f_res = sh.fmt_idx("x", key)
if isinstance(value, get_args(Scalar)):
msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]"
if cmath.isnan(value):
assert xp.isnan(res[key]), msg

repro_snippet = ph.format_snippet(f"{res!r}[{key!r}] = {value!r}")
try:
res[key] = value

ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype")
ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.shape")
f_res = sh.fmt_idx("x", key)
if isinstance(value, get_args(Scalar)):
msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]"
if cmath.isnan(value):
assert xp.isnan(res[key]), msg
else:
assert res[key] == value, msg
else:
assert res[key] == value, msg
else:
ph.assert_array_elements("__setitem__", out=res[key], expected=value, out_repr=f_res)
unaffected_indices = set(sh.ndindex(res.shape)) - set(product(*axes_indices))
for idx in unaffected_indices:
ph.assert_0d_equals(
"__setitem__",
x_repr=f"old {f_res}",
x_val=x[idx],
out_repr=f"modified {f_res}",
out_val=res[idx],
)
ph.assert_array_elements("__setitem__", out=res[key], expected=value, out_repr=f_res)
unaffected_indices = set(sh.ndindex(res.shape)) - set(product(*axes_indices))
for idx in unaffected_indices:
ph.assert_0d_equals(
"__setitem__",
x_repr=f"old {f_res}",
x_val=x[idx],
out_repr=f"modified {f_res}",
out_val=res[idx],
)
except Exception as exc:
ph.add_note(exc, repro_snippet)
raise


@pytest.mark.unvectorized
Expand All @@ -178,29 +184,34 @@ def test_getitem_masking(shape, data):
x[key]
return

out = x[key]
repro_snippet = ph.format_snippet(f"out = {x!r}[{key!r}]")
try:
out = x[key]

ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype)
if key.ndim == 0:
expected_shape = (1,) if key else (0,)
expected_shape += x.shape
else:
size = int(xp.sum(xp.astype(key, xp.uint8)))
expected_shape = (size,) + x.shape[key.ndim :]
ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape)
if not any(s == 0 for s in key.shape):
assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios
out_indices = sh.ndindex(out.shape)
for x_idx in sh.ndindex(x.shape):
if key[x_idx]:
out_idx = next(out_indices)
ph.assert_0d_equals(
"__getitem__",
x_repr=f"x[{x_idx}]",
x_val=x[x_idx],
out_repr=f"out[{out_idx}]",
out_val=out[out_idx],
)
ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype)
if key.ndim == 0:
expected_shape = (1,) if key else (0,)
expected_shape += x.shape
else:
size = int(xp.sum(xp.astype(key, xp.uint8)))
expected_shape = (size,) + x.shape[key.ndim :]
ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape)
if not any(s == 0 for s in key.shape):
assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios
out_indices = sh.ndindex(out.shape)
for x_idx in sh.ndindex(x.shape):
if key[x_idx]:
out_idx = next(out_indices)
ph.assert_0d_equals(
"__getitem__",
x_repr=f"x[{x_idx}]",
x_val=x[x_idx],
out_repr=f"out[{out_idx}]",
out_val=out[out_idx],
)
except Exception as exc:
ph.add_note(exc, repro_snippet)
raise


@pytest.mark.unvectorized
Expand All @@ -213,38 +224,44 @@ def test_setitem_masking(shape, data):
)

res = xp.asarray(x, copy=True)
res[key] = value

ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype")
ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.dtype")
scalar_type = dh.get_scalar_type(x.dtype)
for idx in sh.ndindex(x.shape):
if key[idx]:
if isinstance(value, get_args(Scalar)):
ph.assert_scalar_equals(
"__setitem__",
type_=scalar_type,
idx=idx,
out=scalar_type(res[idx]),
expected=value,
repr_name="modified x",
)

repro_snippet = ph.format_snippet(f"{res}[{key!r}] = {value!r}")
try:
res[key] = value

ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype")
ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.dtype")
scalar_type = dh.get_scalar_type(x.dtype)
for idx in sh.ndindex(x.shape):
if key[idx]:
if isinstance(value, get_args(Scalar)):
ph.assert_scalar_equals(
"__setitem__",
type_=scalar_type,
idx=idx,
out=scalar_type(res[idx]),
expected=value,
repr_name="modified x",
)
else:
ph.assert_0d_equals(
"__setitem__",
x_repr="value",
x_val=value,
out_repr=f"modified x[{idx}]",
out_val=res[idx]
)
else:
ph.assert_0d_equals(
"__setitem__",
x_repr="value",
x_val=value,
x_repr=f"old x[{idx}]",
x_val=x[idx],
out_repr=f"modified x[{idx}]",
out_val=res[idx]
)
else:
ph.assert_0d_equals(
"__setitem__",
x_repr=f"old x[{idx}]",
x_val=x[idx],
out_repr=f"modified x[{idx}]",
out_val=res[idx]
)
except Exception as exc:
ph.add_note(exc, repro_snippet)
raise


# ### Fancy indexing ###
Expand Down Expand Up @@ -309,15 +326,20 @@ def _test_getitem_arrays_and_ints(shape, data, idx_max_dims):
key.append(data.draw(st.integers(-shape[i], shape[i]-1)))

key = tuple(key)
out = x[key]
repro_snippet = ph.format_snippet(f"out = {x!r}[{key!r}]")
try:
out = x[key]

arrays = [xp.asarray(k) for k in key]
bcast_shape = sh.broadcast_shapes(*[arr.shape for arr in arrays])
bcast_key = [xp.broadcast_to(arr, bcast_shape) for arr in arrays]
arrays = [xp.asarray(k) for k in key]
bcast_shape = sh.broadcast_shapes(*[arr.shape for arr in arrays])
bcast_key = [xp.broadcast_to(arr, bcast_shape) for arr in arrays]

for idx in sh.ndindex(bcast_shape):
tpl = tuple(k[idx] for k in bcast_key)
assert out[idx] == x[tpl], f"failing at {idx = } w/ {key = }"
for idx in sh.ndindex(bcast_shape):
tpl = tuple(k[idx] for k in bcast_key)
assert out[idx] == x[tpl], f"failing at {idx = } w/ {key = }"
except Exception as exc:
ph.add_note(exc, repro_snippet)
raise


def make_scalar_casting_param(
Expand Down
117 changes: 64 additions & 53 deletions array_api_tests/test_indexing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,39 +35,43 @@ def test_take(x, data):
indices = xp.asarray(_indices, dtype=dh.default_int)
note(f"{indices=}")

out = xp.take(x, indices, **kw)

ph.assert_dtype("take", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_shape(
"take",
out_shape=out.shape,
expected=x.shape[:n_axis] + (len(_indices),) + x.shape[n_axis + 1:],
kw=dict(
x=x,
indices=indices,
axis=axis,
),
)
out_indices = sh.ndindex(out.shape)
axis_indices = list(sh.axis_ndindex(x.shape, n_axis))
for axis_idx in axis_indices:
f_axis_idx = sh.fmt_idx("x", axis_idx)
for i in _indices:
f_take_idx = sh.fmt_idx(f_axis_idx, i)
indexed_x = x[axis_idx][i, ...]
for at_idx in sh.ndindex(indexed_x.shape):
out_idx = next(out_indices)
ph.assert_0d_equals(
"take",
x_repr=sh.fmt_idx(f_take_idx, at_idx),
x_val=indexed_x[at_idx],
out_repr=sh.fmt_idx("out", out_idx),
out_val=out[out_idx],
)
# sanity check
with pytest.raises(StopIteration):
next(out_indices)
repro_snippet = ph.format_snippet(f"xp.take({x!r}, {indices!r}, **kw) with {kw = }")
try:
out = xp.take(x, indices, **kw)

ph.assert_dtype("take", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_shape(
"take",
out_shape=out.shape,
expected=x.shape[:n_axis] + (len(_indices),) + x.shape[n_axis + 1:],
kw=dict(
x=x,
indices=indices,
axis=axis,
),
)
out_indices = sh.ndindex(out.shape)
axis_indices = list(sh.axis_ndindex(x.shape, n_axis))
for axis_idx in axis_indices:
f_axis_idx = sh.fmt_idx("x", axis_idx)
for i in _indices:
f_take_idx = sh.fmt_idx(f_axis_idx, i)
indexed_x = x[axis_idx][i, ...]
for at_idx in sh.ndindex(indexed_x.shape):
out_idx = next(out_indices)
ph.assert_0d_equals(
"take",
x_repr=sh.fmt_idx(f_take_idx, at_idx),
x_val=indexed_x[at_idx],
out_repr=sh.fmt_idx("out", out_idx),
out_val=out[out_idx],
)
# sanity check
with pytest.raises(StopIteration):
next(out_indices)
except Exception as exc:
ph.add_note(exc, repro_snippet)
raise

@pytest.mark.unvectorized
@pytest.mark.min_version("2024.12")
Expand Down Expand Up @@ -103,26 +107,33 @@ def test_take_along_axis(x, data):
)
note(f"{indices=} {idx_shape=}")

out = xp.take_along_axis(x, indices, **axis_kw)

ph.assert_dtype("take_along_axis", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_shape(
"take_along_axis",
out_shape=out.shape,
expected=x.shape[:n_axis] + (new_len,) + x.shape[n_axis+1:],
kw=dict(
x=x,
indices=indices,
axis=axis,
),
repro_snippet = ph.format_snippet(
f"xp.take_along_axis({x!r}, {indices!r}, **axis_kw) with {axis_kw = }"
)
try:
out = xp.take_along_axis(x, indices, **axis_kw)

ph.assert_dtype("take_along_axis", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_shape(
"take_along_axis",
out_shape=out.shape,
expected=x.shape[:n_axis] + (new_len,) + x.shape[n_axis+1:],
kw=dict(
x=x,
indices=indices,
axis=axis,
),
)

# value test: notation is from `np.take_along_axis` docstring
Ni, Nk = x.shape[:n_axis], x.shape[n_axis+1:]
for ii in sh.ndindex(Ni):
for kk in sh.ndindex(Nk):
a_1d = x[ii + (slice(None),) + kk]
i_1d = indices[ii + (slice(None),) + kk]
o_1d = out[ii + (slice(None),) + kk]
for j in range(new_len):
assert o_1d[j] == a_1d[i_1d[j]], f'{ii=}, {kk=}, {j=}'
# value test: notation is from `np.take_along_axis` docstring
Ni, Nk = x.shape[:n_axis], x.shape[n_axis+1:]
for ii in sh.ndindex(Ni):
for kk in sh.ndindex(Nk):
a_1d = x[ii + (slice(None),) + kk]
i_1d = indices[ii + (slice(None),) + kk]
o_1d = out[ii + (slice(None),) + kk]
for j in range(new_len):
assert o_1d[j] == a_1d[i_1d[j]], f'{ii=}, {kk=}, {j=}'
except Exception as exc:
ph.add_note(exc, repro_snippet)
raise
Loading