33from typing import Any , Dict , Optional , Sequence , Tuple , Union
44
55from . import _array_module as xp
6- from . import array_helpers as ah
76from . import dtype_helpers as dh
87from . import shape_helpers as sh
98from . import stubs
@@ -88,6 +87,40 @@ def assert_dtype(
8887 * ,
8988 repr_name : str = "out.dtype" ,
9089):
90+ """
91+ Assert the output dtype is as expected.
92+
93+ If expected=None, we infer the expected dtype as in_dtype, to test
94+ out_dtype, e.g.
95+
96+ >>> x = xp.arange(5, dtype=xp.uint8)
97+ >>> out = xp.abs(x)
98+ >>> assert_dtype('abs', x.dtype, out.dtype)
99+
100+ is equivalent to
101+
102+ >>> assert out.dtype == xp.uint8
103+
104+ Or for multiple input dtypes, the expected dtype is inferred from their
105+ resulting type promotion, e.g.
106+
107+ >>> x1 = xp.arange(5, dtype=xp.uint8)
108+ >>> x2 = xp.arange(5, dtype=xp.uint16)
109+ >>> out = xp.add(x1, x2)
110+ >>> assert_dtype('add', [x1.dtype, x2.dtype], out.dtype)
111+
112+ is equivalent to
113+
114+ >>> assert out.dtype == xp.uint16
115+
116+ We can also specify the expected dtype ourselves, e.g.
117+
118+ >>> x = xp.arange(5, dtype=xp.int8)
119+ >>> out = xp.sum(x)
120+ >>> default_int = xp.asarray(0).dtype
121+ >>> assert_dtype('sum', x, out.dtype, default_int)
122+
123+ """
91124 in_dtypes = in_dtype if isinstance (in_dtype , Sequence ) else [in_dtype ]
92125 f_in_dtypes = dh .fmt_types (tuple (in_dtypes ))
93126 f_out_dtype = dh .dtype_to_name [out_dtype ]
@@ -102,6 +135,14 @@ def assert_dtype(
102135
103136
104137def assert_kw_dtype (func_name : str , kw_dtype : DataType , out_dtype : DataType ):
138+ """
139+ Assert the output dtype is the passed keyword dtype, e.g.
140+
141+ >>> kw = {'dtype': xp.uint8}
142+ >>> out = xp.ones(5, **kw)
143+ >>> assert_kw_dtype('ones', kw['dtype'], out.dtype)
144+
145+ """
105146 f_kw_dtype = dh .dtype_to_name [kw_dtype ]
106147 f_out_dtype = dh .dtype_to_name [out_dtype ]
107148 msg = (
@@ -111,33 +152,54 @@ def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
111152 assert out_dtype == kw_dtype , msg
112153
113154
114- def assert_default_float (func_name : str , dtype : DataType ):
115- f_dtype = dh .dtype_to_name [dtype ]
155+ def assert_default_float (func_name : str , out_dtype : DataType ):
156+ """
157+ Assert the output dtype is the default float, e.g.
158+
159+ >>> out = xp.ones(5)
160+ >>> assert_default_float('ones', out.dtype)
161+
162+ """
163+ f_dtype = dh .dtype_to_name [out_dtype ]
116164 f_default = dh .dtype_to_name [dh .default_float ]
117165 msg = (
118166 f"out.dtype={ f_dtype } , should be default "
119167 f"floating-point dtype { f_default } [{ func_name } ()]"
120168 )
121- assert dtype == dh .default_float , msg
169+ assert out_dtype == dh .default_float , msg
122170
123171
124- def assert_default_int (func_name : str , dtype : DataType ):
125- f_dtype = dh .dtype_to_name [dtype ]
172+ def assert_default_int (func_name : str , out_dtype : DataType ):
173+ """
174+ Assert the output dtype is the default int, e.g.
175+
176+ >>> out = xp.full(5, 42)
177+ >>> assert_default_int('full', out.dtype)
178+
179+ """
180+ f_dtype = dh .dtype_to_name [out_dtype ]
126181 f_default = dh .dtype_to_name [dh .default_int ]
127182 msg = (
128183 f"out.dtype={ f_dtype } , should be default "
129184 f"integer dtype { f_default } [{ func_name } ()]"
130185 )
131- assert dtype == dh .default_int , msg
186+ assert out_dtype == dh .default_int , msg
187+
188+
189+ def assert_default_index (func_name : str , out_dtype : DataType , repr_name = "out.dtype" ):
190+ """
191+ Assert the output dtype is the default index dtype, e.g.
132192
193+ >>> out = xp.argmax(xp.arange(5))
194+ >>> assert_default_int('argmax', out.dtype)
133195
134- def assert_default_index ( func_name : str , dtype : DataType , repr_name = "out.dtype" ):
135- f_dtype = dh .dtype_to_name [dtype ]
196+ """
197+ f_dtype = dh .dtype_to_name [out_dtype ]
136198 msg = (
137199 f"{ repr_name } ={ f_dtype } , should be the default index dtype, "
138200 f"which is either int32 or int64 [{ func_name } ()]"
139201 )
140- assert dtype in (xp .int32 , xp .int64 ), msg
202+ assert out_dtype in (xp .int32 , xp .int64 ), msg
141203
142204
143205def assert_shape (
@@ -148,6 +210,13 @@ def assert_shape(
148210 repr_name = "out.shape" ,
149211 ** kw ,
150212):
213+ """
214+ Assert the output shape is as expected, e.g.
215+
216+ >>> out = xp.ones((3, 3, 3))
217+ >>> assert_shape('ones', out.shape, (3, 3, 3))
218+
219+ """
151220 if isinstance (out_shape , int ):
152221 out_shape = (out_shape ,)
153222 if isinstance (expected , int ):
@@ -168,6 +237,20 @@ def assert_result_shape(
168237 repr_name = "out.shape" ,
169238 ** kw ,
170239):
240+ """
241+ Assert the output shape is as expected.
242+
243+ If expected=None, we infer the expected shape as the result of broadcasting
244+ in_shapes, to test against out_shape, e.g.
245+
246+ >>> out = xp.add(xp.ones((3, 1)), xp.ones((1, 3)))
247+ >>> assert_shape('add', [(3, 1), (1, 3)], out.shape)
248+
249+ is equivalent to
250+
251+ >>> assert out.shape == (3, 3)
252+
253+ """
171254 if expected is None :
172255 expected = sh .broadcast_shapes (* in_shapes )
173256 f_in_shapes = " . " .join (str (s ) for s in in_shapes )
@@ -180,13 +263,28 @@ def assert_result_shape(
180263
181264def assert_keepdimable_shape (
182265 func_name : str ,
183- out_shape : Shape ,
184266 in_shape : Shape ,
267+ out_shape : Shape ,
185268 axes : Tuple [int , ...],
186269 keepdims : bool ,
187270 / ,
188271 ** kw ,
189272):
273+ """
274+ Assert the output shape from a keepdimable function is as expected, e.g.
275+
276+ >>> x = xp.asarray([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
277+ >>> out1 = xp.max(x, keepdims=False)
278+ >>> out2 = xp.max(x, keepdims=True)
279+ >>> assert_keepdimable_shape('max', x.shape, out1.shape, (0, 1), False)
280+ >>> assert_keepdimable_shape('max', x.shape, out2.shape, (0, 1), True)
281+
282+ is equivalent to
283+
284+ >>> assert out1.shape == ()
285+ >>> assert out2.shape == (1, 1)
286+
287+ """
190288 if keepdims :
191289 shape = tuple (1 if axis in axes else side for axis , side in enumerate (in_shape ))
192290 else :
@@ -197,6 +295,19 @@ def assert_keepdimable_shape(
197295def assert_0d_equals (
198296 func_name : str , x_repr : str , x_val : Array , out_repr : str , out_val : Array , ** kw
199297):
298+ """
299+ Assert a 0d array is as expected, e.g.
300+
301+ >>> x = xp.asarray([0, 1, 2])
302+ >>> res = xp.asarray(x, copy=True)
303+ >>> res[0] = 42
304+ >>> assert_0d_equals('__setitem__', 'x[0]', x[0], 'x[0]', res[0])
305+
306+ is equivalent to
307+
308+ >>> assert res[0] == x[0]
309+
310+ """
200311 msg = (
201312 f"{ out_repr } ={ out_val } , but should be { x_repr } ={ x_val } "
202313 f"[{ func_name } ({ fmt_kw (kw )} )]"
@@ -217,9 +328,21 @@ def assert_scalar_equals(
217328 repr_name : str = "out" ,
218329 ** kw ,
219330):
331+ """
332+ Assert a 0d array, convered to a scalar, is as expected, e.g.
333+
334+ >>> x = xp.ones(5, dtype=xp.uint8)
335+ >>> out = xp.sum(x)
336+ >>> assert_scalar_equals('sum', int, (), int(out), 5)
337+
338+ is equivalent to
339+
340+ >>> assert int(out) == 5
341+
342+ """
220343 repr_name = repr_name if idx == () else f"{ repr_name } [{ idx } ]"
221344 f_func = f"{ func_name } ({ fmt_kw (kw )} )"
222- if type_ is bool or type_ is int :
345+ if type_ in [ bool , int ] :
223346 msg = f"{ repr_name } ={ out } , but should be { expected } [{ f_func } ]"
224347 assert out == expected , msg
225348 elif math .isnan (expected ):
@@ -233,14 +356,37 @@ def assert_scalar_equals(
233356def assert_fill (
234357 func_name : str , fill_value : Scalar , dtype : DataType , out : Array , / , ** kw
235358):
359+ """
360+ Assert all elements of an array is as expected, e.g.
361+
362+ >>> out = xp.full(5, 42, dtype=xp.uint8)
363+ >>> assert_fill('full', 42, xp.uint8, out, 5)
364+
365+ is equivalent to
366+
367+ >>> assert xp.all(out == 42)
368+
369+ """
236370 msg = f"out not filled with { fill_value } [{ func_name } ({ fmt_kw (kw )} )]\n { out = } "
237371 if math .isnan (fill_value ):
238- assert ah .all (ah .isnan (out )), msg
372+ assert xp .all (xp .isnan (out )), msg
239373 else :
240- assert ah .all (ah .equal (out , ah .asarray (fill_value , dtype = dtype ))), msg
374+ assert xp .all (xp .equal (out , xp .asarray (fill_value , dtype = dtype ))), msg
241375
242376
243377def assert_array (func_name : str , out : Array , expected : Array , / , ** kw ):
378+ """
379+ Assert array is (strictly) as expected, e.g.
380+
381+ >>> x = xp.arange(5)
382+ >>> out = xp.asarray(x)
383+ >>> assert_array('asarray', out, x)
384+
385+ is equivalent to
386+
387+ >>> assert xp.all(out == x)
388+
389+ """
244390 assert_dtype (func_name , out .dtype , expected .dtype )
245391 assert_shape (func_name , out .shape , expected .shape , ** kw )
246392 f_func = f"[{ func_name } ({ fmt_kw (kw )} )]"
0 commit comments