11import operator
22from builtins import all as all_
33
4- from numpy .testing import assert_raises , suppress_warnings
4+ import numpy .testing
55import numpy as np
66import pytest
77
2929
3030import array_api_strict
3131
32+ def assert_raises (exception , func , msg = None ):
33+ with numpy .testing .assert_raises (exception , msg = msg ):
34+ func ()
35+
3236def test_validate_index ():
3337 # The indexing tests in the official array API test suite test that the
3438 # array object correctly handles the subset of indices that are required
@@ -90,7 +94,7 @@ def test_validate_index():
9094
9195def test_operators ():
9296 # For every operator, we test that it works for the required type
93- # combinations and raises TypeError otherwise
97+ # combinations and assert_raises TypeError otherwise
9498 binary_op_dtypes = {
9599 "__add__" : "numeric" ,
96100 "__and__" : "integer_or_boolean" ,
@@ -111,6 +115,7 @@ def test_operators():
111115 "__truediv__" : "floating" ,
112116 "__xor__" : "integer_or_boolean" ,
113117 }
118+ comparison_ops = ["__eq__" , "__ne__" , "__le__" , "__ge__" , "__lt__" , "__gt__" ]
114119 # Recompute each time because of in-place ops
115120 def _array_vals ():
116121 for d in _integer_dtypes :
@@ -124,7 +129,7 @@ def _array_vals():
124129 BIG_INT = int (1e30 )
125130 for op , dtypes in binary_op_dtypes .items ():
126131 ops = [op ]
127- if op not in [ "__eq__" , "__ne__" , "__le__" , "__ge__" , "__lt__" , "__gt__" ] :
132+ if op not in comparison_ops :
128133 rop = "__r" + op [2 :]
129134 iop = "__i" + op [2 :]
130135 ops += [rop , iop ]
@@ -155,16 +160,16 @@ def _array_vals():
155160 or a .dtype in _complex_floating_dtypes and type (s ) in [complex , float , int ]
156161 )):
157162 if a .dtype in _integer_dtypes and s == BIG_INT :
158- assert_raises (OverflowError , lambda : getattr (a , _op )(s ))
163+ assert_raises (OverflowError , lambda : getattr (a , _op )(s ), _op )
159164 else :
160165 # Only test for no error
161- with suppress_warnings () as sup :
166+ with numpy . testing . suppress_warnings () as sup :
162167 # ignore warnings from pow(BIG_INT)
163168 sup .filter (RuntimeWarning ,
164169 "invalid value encountered in power" )
165170 getattr (a , _op )(s )
166171 else :
167- assert_raises (TypeError , lambda : getattr (a , _op )(s ))
172+ assert_raises (TypeError , lambda : getattr (a , _op )(s ), _op )
168173
169174 # Test array op array.
170175 for _op in ops :
@@ -173,25 +178,25 @@ def _array_vals():
173178 # See the promotion table in NEP 47 or the array
174179 # API spec page on type promotion. Mixed kind
175180 # promotion is not defined.
176- if (x .dtype == uint64 and y .dtype in [int8 , int16 , int32 , int64 ]
177- or y .dtype == uint64 and x .dtype in [int8 , int16 , int32 , int64 ]
178- or x .dtype in _integer_dtypes and y .dtype not in _integer_dtypes
179- or y .dtype in _integer_dtypes and x .dtype not in _integer_dtypes
180- or x .dtype in _boolean_dtypes and y .dtype not in _boolean_dtypes
181- or y .dtype in _boolean_dtypes and x .dtype not in _boolean_dtypes
182- or x .dtype in _floating_dtypes and y .dtype not in _floating_dtypes
183- or y .dtype in _floating_dtypes and x .dtype not in _floating_dtypes
184- ):
185- assert_raises (TypeError , lambda : getattr (x , _op )(y ))
181+ if (op not in comparison_ops and
182+ (x .dtype == uint64 and y .dtype in [int8 , int16 , int32 , int64 ]
183+ or y .dtype == uint64 and x .dtype in [int8 , int16 , int32 , int64 ]
184+ or x .dtype in _integer_dtypes and y .dtype not in _integer_dtypes
185+ or y .dtype in _integer_dtypes and x .dtype not in _integer_dtypes
186+ or x .dtype in _boolean_dtypes and y .dtype not in _boolean_dtypes
187+ or y .dtype in _boolean_dtypes and x .dtype not in _boolean_dtypes
188+ or x .dtype in _floating_dtypes and y .dtype not in _floating_dtypes
189+ or y .dtype in _floating_dtypes and x .dtype not in _floating_dtypes
190+ )):
191+ assert_raises (TypeError , lambda : getattr (x , _op )(y ), _op )
186192 # Ensure in-place operators only promote to the same dtype as the left operand.
187193 elif (
188194 _op .startswith ("__i" )
189195 and result_type (x .dtype , y .dtype ) != x .dtype
190196 ):
191- assert_raises (TypeError , lambda : getattr (x , _op )(y ))
197+ assert_raises (TypeError , lambda : getattr (x , _op )(y ), _op )
192198 # Ensure only those dtypes that are required for every operator are allowed.
193- elif (dtypes == "all" and (x .dtype in _boolean_dtypes and y .dtype in _boolean_dtypes
194- or x .dtype in _numeric_dtypes and y .dtype in _numeric_dtypes )
199+ elif (dtypes == "all"
195200 or (dtypes == "real numeric" and x .dtype in _real_numeric_dtypes and y .dtype in _real_numeric_dtypes )
196201 or (dtypes == "numeric" and x .dtype in _numeric_dtypes and y .dtype in _numeric_dtypes )
197202 or dtypes == "integer" and x .dtype in _integer_dtypes and y .dtype in _integer_dtypes
@@ -202,7 +207,7 @@ def _array_vals():
202207 ):
203208 getattr (x , _op )(y )
204209 else :
205- assert_raises (TypeError , lambda : getattr (x , _op )(y ))
210+ assert_raises (TypeError , lambda : getattr (x , _op )(y ), ( x , _op , y ) )
206211
207212 unary_op_dtypes = {
208213 "__abs__" : "numeric" ,
@@ -221,7 +226,7 @@ def _array_vals():
221226 # Only test for no error
222227 getattr (a , op )()
223228 else :
224- assert_raises (TypeError , lambda : getattr (a , op )())
229+ assert_raises (TypeError , lambda : getattr (a , op )(), _op )
225230
226231 # Finally, matmul() must be tested separately, because it works a bit
227232 # different from the other operations.
@@ -240,9 +245,9 @@ def _matmul_array_vals():
240245 or type (s ) == int and a .dtype in _integer_dtypes ):
241246 # Type promotion is valid, but @ is not allowed on 0-D
242247 # inputs, so the error is a ValueError
243- assert_raises (ValueError , lambda : getattr (a , _op )(s ))
248+ assert_raises (ValueError , lambda : getattr (a , _op )(s ), _op )
244249 else :
245- assert_raises (TypeError , lambda : getattr (a , _op )(s ))
250+ assert_raises (TypeError , lambda : getattr (a , _op )(s ), _op )
246251
247252 for x in _matmul_array_vals ():
248253 for y in _matmul_array_vals ():
@@ -356,20 +361,17 @@ def test_allow_newaxis():
356361
357362def test_disallow_flat_indexing_with_newaxis ():
358363 a = ones ((3 , 3 , 3 ))
359- with pytest .raises (IndexError ):
360- a [None , 0 , 0 ]
364+ assert_raises (IndexError , lambda : a [None , 0 , 0 ])
361365
362366def test_disallow_mask_with_newaxis ():
363367 a = ones ((3 , 3 , 3 ))
364- with pytest .raises (IndexError ):
365- a [None , asarray (True )]
368+ assert_raises (IndexError , lambda : a [None , asarray (True )])
366369
367370@pytest .mark .parametrize ("shape" , [(), (5 ,), (3 , 3 , 3 )])
368371@pytest .mark .parametrize ("index" , ["string" , False , True ])
369372def test_error_on_invalid_index (shape , index ):
370373 a = ones (shape )
371- with pytest .raises (IndexError ):
372- a [index ]
374+ assert_raises (IndexError , lambda : a [index ])
373375
374376def test_mask_0d_array_without_errors ():
375377 a = ones (())
@@ -380,10 +382,8 @@ def test_mask_0d_array_without_errors():
380382)
381383def test_error_on_invalid_index_with_ellipsis (i ):
382384 a = ones ((3 , 3 , 3 ))
383- with pytest .raises (IndexError ):
384- a [..., i ]
385- with pytest .raises (IndexError ):
386- a [i , ...]
385+ assert_raises (IndexError , lambda : a [..., i ])
386+ assert_raises (IndexError , lambda : a [i , ...])
387387
388388def test_array_keys_use_private_array ():
389389 """
@@ -400,8 +400,7 @@ def test_array_keys_use_private_array():
400400
401401 a = ones ((0 ,), dtype = bool_ )
402402 key = ones ((0 , 0 ), dtype = bool_ )
403- with pytest .raises (IndexError ):
404- a [key ]
403+ assert_raises (IndexError , lambda : a [key ])
405404
406405def test_array_namespace ():
407406 a = ones ((3 , 3 ))
@@ -422,16 +421,16 @@ def test_array_namespace():
422421 assert a .__array_namespace__ (api_version = "2021.12" ) is array_api_strict
423422 assert array_api_strict .__array_api_version__ == "2021.12"
424423
425- pytest . raises (ValueError , lambda : a .__array_namespace__ (api_version = "2021.11" ))
426- pytest . raises (ValueError , lambda : a .__array_namespace__ (api_version = "2024.12" ))
424+ assert_raises (ValueError , lambda : a .__array_namespace__ (api_version = "2021.11" ))
425+ assert_raises (ValueError , lambda : a .__array_namespace__ (api_version = "2024.12" ))
427426
428427def test_iter ():
429- pytest . raises (TypeError , lambda : iter (asarray (3 )))
428+ assert_raises (TypeError , lambda : iter (asarray (3 )))
430429 assert list (ones (3 )) == [asarray (1. ), asarray (1. ), asarray (1. )]
431430 assert all_ (isinstance (a , Array ) for a in iter (ones (3 )))
432431 assert all_ (a .shape == () for a in iter (ones (3 )))
433432 assert all_ (a .dtype == float64 for a in iter (ones (3 )))
434- pytest . raises (TypeError , lambda : iter (ones ((3 , 3 ))))
433+ assert_raises (TypeError , lambda : iter (ones ((3 , 3 ))))
435434
436435@pytest .mark .parametrize ("api_version" , ['2021.12' , '2022.12' , '2023.12' ])
437436def dlpack_2023_12 (api_version ):
@@ -447,17 +446,17 @@ def dlpack_2023_12(api_version):
447446
448447
449448 exception = NotImplementedError if api_version >= '2023.12' else ValueError
450- pytest . raises (exception , lambda :
449+ assert_raises (exception , lambda :
451450 a .__dlpack__ (dl_device = CPU_DEVICE ))
452- pytest . raises (exception , lambda :
451+ assert_raises (exception , lambda :
453452 a .__dlpack__ (dl_device = None ))
454- pytest . raises (exception , lambda :
453+ assert_raises (exception , lambda :
455454 a .__dlpack__ (max_version = (1 , 0 )))
456- pytest . raises (exception , lambda :
455+ assert_raises (exception , lambda :
457456 a .__dlpack__ (max_version = None ))
458- pytest . raises (exception , lambda :
457+ assert_raises (exception , lambda :
459458 a .__dlpack__ (copy = False ))
460- pytest . raises (exception , lambda :
459+ assert_raises (exception , lambda :
461460 a .__dlpack__ (copy = True ))
462- pytest . raises (exception , lambda :
461+ assert_raises (exception , lambda :
463462 a .__dlpack__ (copy = None ))
0 commit comments