@@ -35,39 +35,43 @@ def test_take(x, data):
3535 indices = xp .asarray (_indices , dtype = dh .default_int )
3636 note (f"{ indices = } " )
3737
38- out = xp .take (x , indices , ** kw )
39-
40- ph .assert_dtype ("take" , in_dtype = x .dtype , out_dtype = out .dtype )
41- ph .assert_shape (
42- "take" ,
43- out_shape = out .shape ,
44- expected = x .shape [:n_axis ] + (len (_indices ),) + x .shape [n_axis + 1 :],
45- kw = dict (
46- x = x ,
47- indices = indices ,
48- axis = axis ,
49- ),
50- )
51- out_indices = sh .ndindex (out .shape )
52- axis_indices = list (sh .axis_ndindex (x .shape , n_axis ))
53- for axis_idx in axis_indices :
54- f_axis_idx = sh .fmt_idx ("x" , axis_idx )
55- for i in _indices :
56- f_take_idx = sh .fmt_idx (f_axis_idx , i )
57- indexed_x = x [axis_idx ][i , ...]
58- for at_idx in sh .ndindex (indexed_x .shape ):
59- out_idx = next (out_indices )
60- ph .assert_0d_equals (
61- "take" ,
62- x_repr = sh .fmt_idx (f_take_idx , at_idx ),
63- x_val = indexed_x [at_idx ],
64- out_repr = sh .fmt_idx ("out" , out_idx ),
65- out_val = out [out_idx ],
66- )
67- # sanity check
68- with pytest .raises (StopIteration ):
69- next (out_indices )
38+ repro_snippet = ph .format_snippet (f"xp.take({ x !r} , { indices !r} , **kw) with { kw = } " )
39+ try :
40+ out = xp .take (x , indices , ** kw )
7041
42+ ph .assert_dtype ("take" , in_dtype = x .dtype , out_dtype = out .dtype )
43+ ph .assert_shape (
44+ "take" ,
45+ out_shape = out .shape ,
46+ expected = x .shape [:n_axis ] + (len (_indices ),) + x .shape [n_axis + 1 :],
47+ kw = dict (
48+ x = x ,
49+ indices = indices ,
50+ axis = axis ,
51+ ),
52+ )
53+ out_indices = sh .ndindex (out .shape )
54+ axis_indices = list (sh .axis_ndindex (x .shape , n_axis ))
55+ for axis_idx in axis_indices :
56+ f_axis_idx = sh .fmt_idx ("x" , axis_idx )
57+ for i in _indices :
58+ f_take_idx = sh .fmt_idx (f_axis_idx , i )
59+ indexed_x = x [axis_idx ][i , ...]
60+ for at_idx in sh .ndindex (indexed_x .shape ):
61+ out_idx = next (out_indices )
62+ ph .assert_0d_equals (
63+ "take" ,
64+ x_repr = sh .fmt_idx (f_take_idx , at_idx ),
65+ x_val = indexed_x [at_idx ],
66+ out_repr = sh .fmt_idx ("out" , out_idx ),
67+ out_val = out [out_idx ],
68+ )
69+ # sanity check
70+ with pytest .raises (StopIteration ):
71+ next (out_indices )
72+ except Exception as exc :
73+ exc .add_note (repro_snippet )
74+ raise
7175
7276@pytest .mark .unvectorized
7377@pytest .mark .min_version ("2024.12" )
@@ -103,26 +107,33 @@ def test_take_along_axis(x, data):
103107 )
104108 note (f"{ indices = } { idx_shape = } " )
105109
106- out = xp .take_along_axis (x , indices , ** axis_kw )
107-
108- ph .assert_dtype ("take_along_axis" , in_dtype = x .dtype , out_dtype = out .dtype )
109- ph .assert_shape (
110- "take_along_axis" ,
111- out_shape = out .shape ,
112- expected = x .shape [:n_axis ] + (new_len ,) + x .shape [n_axis + 1 :],
113- kw = dict (
114- x = x ,
115- indices = indices ,
116- axis = axis ,
117- ),
110+ repro_snippet = ph .format_snippet (
111+ f"xp.take_along_axis({ x !r} , { indices !r} , **axis_kw) with { axis_kw = } "
118112 )
113+ try :
114+ out = xp .take_along_axis (x , indices , ** axis_kw )
115+
116+ ph .assert_dtype ("take_along_axis" , in_dtype = x .dtype , out_dtype = out .dtype )
117+ ph .assert_shape (
118+ "take_along_axis" ,
119+ out_shape = out .shape ,
120+ expected = x .shape [:n_axis ] + (new_len ,) + x .shape [n_axis + 1 :],
121+ kw = dict (
122+ x = x ,
123+ indices = indices ,
124+ axis = axis ,
125+ ),
126+ )
119127
120- # value test: notation is from `np.take_along_axis` docstring
121- Ni , Nk = x .shape [:n_axis ], x .shape [n_axis + 1 :]
122- for ii in sh .ndindex (Ni ):
123- for kk in sh .ndindex (Nk ):
124- a_1d = x [ii + (slice (None ),) + kk ]
125- i_1d = indices [ii + (slice (None ),) + kk ]
126- o_1d = out [ii + (slice (None ),) + kk ]
127- for j in range (new_len ):
128- assert o_1d [j ] == a_1d [i_1d [j ]], f'{ ii = } , { kk = } , { j = } '
128+ # value test: notation is from `np.take_along_axis` docstring
129+ Ni , Nk = x .shape [:n_axis ], x .shape [n_axis + 1 :]
130+ for ii in sh .ndindex (Ni ):
131+ for kk in sh .ndindex (Nk ):
132+ a_1d = x [ii + (slice (None ),) + kk ]
133+ i_1d = indices [ii + (slice (None ),) + kk ]
134+ o_1d = out [ii + (slice (None ),) + kk ]
135+ for j in range (new_len ):
136+ assert o_1d [j ] == a_1d [i_1d [j ]], f'{ ii = } , { kk = } , { j = } '
137+ except Exception as exc :
138+ exc .add_note (repro_snippet )
139+ raise
0 commit comments