1717)
1818def test_take (x , data ):
1919 # TODO:
20- # * negative axis
2120 # * negative indices
2221 # * different dtypes for indices
2322
2423 # axis is optional but only if x.ndim == 1
25- _axis_st = st .integers (0 , max (x .ndim - 1 , 0 ))
24+ _axis_st = st .integers (- x . ndim , max (x .ndim - 1 , 0 ))
2625 if x .ndim == 1 :
2726 kw = data .draw (hh .kwargs (axis = _axis_st ))
2827 else :
@@ -32,6 +31,7 @@ def test_take(x, data):
3231 st .lists (st .integers (0 , x .shape [axis ] - 1 ), min_size = 1 , unique = True ),
3332 label = "_indices" ,
3433 )
34+ n_axis = axis if axis >= 0 else x .ndim + axis
3535 indices = xp .asarray (_indices , dtype = dh .default_int )
3636 note (f"{ indices = } " )
3737
@@ -41,15 +41,15 @@ def test_take(x, data):
4141 ph .assert_shape (
4242 "take" ,
4343 out_shape = out .shape ,
44- expected = x .shape [:axis ] + (len (_indices ),) + x .shape [axis + 1 :],
44+ expected = x .shape [:n_axis ] + (len (_indices ),) + x .shape [n_axis + 1 :],
4545 kw = dict (
4646 x = x ,
4747 indices = indices ,
4848 axis = axis ,
4949 ),
5050 )
5151 out_indices = sh .ndindex (out .shape )
52- axis_indices = list (sh .axis_ndindex (x .shape , axis ))
52+ axis_indices = list (sh .axis_ndindex (x .shape , n_axis ))
5353 for axis_idx in axis_indices :
5454 f_axis_idx = sh .fmt_idx ("x" , axis_idx )
5555 for i in _indices :
@@ -69,7 +69,6 @@ def test_take(x, data):
6969 next (out_indices )
7070
7171
72-
7372@pytest .mark .unvectorized
7473@pytest .mark .min_version ("2024.12" )
7574@given (
0 commit comments