1717)
1818def test_take (x , data ):
1919 # TODO:
20- # * negative axis
2120 # * negative indices
2221 # * different dtypes for indices
23- axis = data .draw (st .integers (0 , max (x .ndim - 1 , 0 )), label = "axis" )
22+
23+ # axis is optional but only if x.ndim == 1
24+ _axis_st = st .integers (- x .ndim , max (x .ndim - 1 , 0 ))
25+ if x .ndim == 1 :
26+ kw = data .draw (hh .kwargs (axis = _axis_st ))
27+ else :
28+ kw = {"axis" : data .draw (_axis_st )}
29+ axis = kw .get ("axis" , 0 )
2430 _indices = data .draw (
2531 st .lists (st .integers (0 , x .shape [axis ] - 1 ), min_size = 1 , unique = True ),
2632 label = "_indices" ,
2733 )
34+ n_axis = axis if axis >= 0 else x .ndim + axis
2835 indices = xp .asarray (_indices , dtype = dh .default_int )
2936 note (f"{ indices = } " )
3037
31- out = xp .take (x , indices , axis = axis )
38+ out = xp .take (x , indices , ** kw )
3239
3340 ph .assert_dtype ("take" , in_dtype = x .dtype , out_dtype = out .dtype )
3441 ph .assert_shape (
3542 "take" ,
3643 out_shape = out .shape ,
37- expected = x .shape [:axis ] + (len (_indices ),) + x .shape [axis + 1 :],
44+ expected = x .shape [:n_axis ] + (len (_indices ),) + x .shape [n_axis + 1 :],
3845 kw = dict (
3946 x = x ,
4047 indices = indices ,
4148 axis = axis ,
4249 ),
4350 )
4451 out_indices = sh .ndindex (out .shape )
45- axis_indices = list (sh .axis_ndindex (x .shape , axis ))
52+ axis_indices = list (sh .axis_ndindex (x .shape , n_axis ))
4653 for axis_idx in axis_indices :
4754 f_axis_idx = sh .fmt_idx ("x" , axis_idx )
4855 for i in _indices :
@@ -62,7 +69,6 @@ def test_take(x, data):
6269 next (out_indices )
6370
6471
65-
6672@pytest .mark .unvectorized
6773@pytest .mark .min_version ("2024.12" )
6874@given (
0 commit comments