|
3 | 3 | import array_api_strict as xp |
4 | 4 |
|
5 | 5 | from array_api_strict import ArrayAPIStrictFlags |
6 | | -from array_api_strict._flags import draft_version |
7 | 6 |
|
8 | 7 |
|
9 | 8 | def test_where_with_scalars(): |
10 | 9 | x = xp.asarray([1, 2, 3, 1]) |
11 | 10 |
|
12 | 11 | # Versions up to and including 2023.12 don't support scalar arguments |
13 | | - with pytest.raises(AttributeError, match="object has no attribute 'dtype'"): |
14 | | - xp.where(x == 1, 42, 44) |
| 12 | + with ArrayAPIStrictFlags(api_version='2023.12'): |
| 13 | + with pytest.raises(AttributeError, match="object has no attribute 'dtype'"): |
| 14 | + xp.where(x == 1, 42, 44) |
15 | 15 |
|
16 | 16 | # Versions after 2023.12 support scalar arguments |
17 | | - with (pytest.warns( |
18 | | - UserWarning, |
19 | | - match="The 2024.12 version of the array API specification is in draft status" |
20 | | - ), |
21 | | - ArrayAPIStrictFlags(api_version=draft_version), |
22 | | - ): |
23 | | - x_where = xp.where(x == 1, xp.asarray(42), 44) |
24 | | - |
25 | | - expected = xp.asarray([42, 44, 44, 42]) |
26 | | - assert xp.all(x_where == expected) |
27 | | - |
28 | | - # The spec does not allow both x1 and x2 to be scalars |
29 | | - with pytest.raises(ValueError, match="One of"): |
30 | | - xp.where(x == 1, 42, 44) |
| 17 | + x_where = xp.where(x == 1, xp.asarray(42), 44) |
| 18 | + |
| 19 | + expected = xp.asarray([42, 44, 44, 42]) |
| 20 | + assert xp.all(x_where == expected) |
| 21 | + |
| 22 | + # The spec does not allow both x1 and x2 to be scalars |
| 23 | + with pytest.raises(ValueError, match="One of"): |
| 24 | + xp.where(x == 1, 42, 44) |
0 commit comments