11import operator
22from builtins import all as all_
33
4- import numpy .testing
4+ from numpy .testing import assert_raises , suppress_warnings
55import numpy as np
66import pytest
77
2929
3030import array_api_strict
3131
32- def assert_raises (exception , func , msg = None ):
33- with numpy .testing .assert_raises (exception , msg = msg ):
34- func ()
35-
3632def test_validate_index ():
3733 # The indexing tests in the official array API test suite test that the
3834 # array object correctly handles the subset of indices that are required
@@ -94,7 +90,7 @@ def test_validate_index():
9490
9591def test_operators ():
9692 # For every operator, we test that it works for the required type
97- # combinations and assert_raises TypeError otherwise
93+ # combinations and raises TypeError otherwise
9894 binary_op_dtypes = {
9995 "__add__" : "numeric" ,
10096 "__and__" : "integer_or_boolean" ,
@@ -115,7 +111,6 @@ def test_operators():
115111 "__truediv__" : "floating" ,
116112 "__xor__" : "integer_or_boolean" ,
117113 }
118- comparison_ops = ["__eq__" , "__ne__" , "__le__" , "__ge__" , "__lt__" , "__gt__" ]
119114 # Recompute each time because of in-place ops
120115 def _array_vals ():
121116 for d in _integer_dtypes :
@@ -129,7 +124,7 @@ def _array_vals():
129124 BIG_INT = int (1e30 )
130125 for op , dtypes in binary_op_dtypes .items ():
131126 ops = [op ]
132- if op not in comparison_ops :
127+ if op not in [ "__eq__" , "__ne__" , "__le__" , "__ge__" , "__lt__" , "__gt__" ] :
133128 rop = "__r" + op [2 :]
134129 iop = "__i" + op [2 :]
135130 ops += [rop , iop ]
@@ -160,16 +155,16 @@ def _array_vals():
160155 or a .dtype in _complex_floating_dtypes and type (s ) in [complex , float , int ]
161156 )):
162157 if a .dtype in _integer_dtypes and s == BIG_INT :
163- assert_raises (OverflowError , lambda : getattr (a , _op )(s ), _op )
158+ assert_raises (OverflowError , lambda : getattr (a , _op )(s ))
164159 else :
165160 # Only test for no error
166- with numpy . testing . suppress_warnings () as sup :
161+ with suppress_warnings () as sup :
167162 # ignore warnings from pow(BIG_INT)
168163 sup .filter (RuntimeWarning ,
169164 "invalid value encountered in power" )
170165 getattr (a , _op )(s )
171166 else :
172- assert_raises (TypeError , lambda : getattr (a , _op )(s ), _op )
167+ assert_raises (TypeError , lambda : getattr (a , _op )(s ))
173168
174169 # Test array op array.
175170 for _op in ops :
@@ -178,25 +173,25 @@ def _array_vals():
178173 # See the promotion table in NEP 47 or the array
179174 # API spec page on type promotion. Mixed kind
180175 # promotion is not defined.
181- if (op not in comparison_ops and
182- (x .dtype == uint64 and y .dtype in [int8 , int16 , int32 , int64 ]
183- or y .dtype == uint64 and x .dtype in [int8 , int16 , int32 , int64 ]
184- or x .dtype in _integer_dtypes and y .dtype not in _integer_dtypes
185- or y .dtype in _integer_dtypes and x .dtype not in _integer_dtypes
186- or x .dtype in _boolean_dtypes and y .dtype not in _boolean_dtypes
187- or y .dtype in _boolean_dtypes and x .dtype not in _boolean_dtypes
188- or x .dtype in _floating_dtypes and y .dtype not in _floating_dtypes
189- or y .dtype in _floating_dtypes and x .dtype not in _floating_dtypes
190- )):
191- assert_raises (TypeError , lambda : getattr (x , _op )(y ), _op )
176+ if (x .dtype == uint64 and y .dtype in [int8 , int16 , int32 , int64 ]
177+ or y .dtype == uint64 and x .dtype in [int8 , int16 , int32 , int64 ]
178+ or x .dtype in _integer_dtypes and y .dtype not in _integer_dtypes
179+ or y .dtype in _integer_dtypes and x .dtype not in _integer_dtypes
180+ or x .dtype in _boolean_dtypes and y .dtype not in _boolean_dtypes
181+ or y .dtype in _boolean_dtypes and x .dtype not in _boolean_dtypes
182+ or x .dtype in _floating_dtypes and y .dtype not in _floating_dtypes
183+ or y .dtype in _floating_dtypes and x .dtype not in _floating_dtypes
184+ ):
185+ assert_raises (TypeError , lambda : getattr (x , _op )(y ))
192186 # Ensure in-place operators only promote to the same dtype as the left operand.
193187 elif (
194188 _op .startswith ("__i" )
195189 and result_type (x .dtype , y .dtype ) != x .dtype
196190 ):
197- assert_raises (TypeError , lambda : getattr (x , _op )(y ), _op )
191+ assert_raises (TypeError , lambda : getattr (x , _op )(y ))
198192 # Ensure only those dtypes that are required for every operator are allowed.
199- elif (dtypes == "all"
193+ elif (dtypes == "all" and (x .dtype in _boolean_dtypes and y .dtype in _boolean_dtypes
194+ or x .dtype in _numeric_dtypes and y .dtype in _numeric_dtypes )
200195 or (dtypes == "real numeric" and x .dtype in _real_numeric_dtypes and y .dtype in _real_numeric_dtypes )
201196 or (dtypes == "numeric" and x .dtype in _numeric_dtypes and y .dtype in _numeric_dtypes )
202197 or dtypes == "integer" and x .dtype in _integer_dtypes and y .dtype in _integer_dtypes
@@ -207,7 +202,7 @@ def _array_vals():
207202 ):
208203 getattr (x , _op )(y )
209204 else :
210- assert_raises (TypeError , lambda : getattr (x , _op )(y ), ( x , _op , y ) )
205+ assert_raises (TypeError , lambda : getattr (x , _op )(y ))
211206
212207 unary_op_dtypes = {
213208 "__abs__" : "numeric" ,
@@ -226,7 +221,7 @@ def _array_vals():
226221 # Only test for no error
227222 getattr (a , op )()
228223 else :
229- assert_raises (TypeError , lambda : getattr (a , op )(), _op )
224+ assert_raises (TypeError , lambda : getattr (a , op )())
230225
231226 # Finally, matmul() must be tested separately, because it works a bit
232227 # different from the other operations.
@@ -245,9 +240,9 @@ def _matmul_array_vals():
245240 or type (s ) == int and a .dtype in _integer_dtypes ):
246241 # Type promotion is valid, but @ is not allowed on 0-D
247242 # inputs, so the error is a ValueError
248- assert_raises (ValueError , lambda : getattr (a , _op )(s ), _op )
243+ assert_raises (ValueError , lambda : getattr (a , _op )(s ))
249244 else :
250- assert_raises (TypeError , lambda : getattr (a , _op )(s ), _op )
245+ assert_raises (TypeError , lambda : getattr (a , _op )(s ))
251246
252247 for x in _matmul_array_vals ():
253248 for y in _matmul_array_vals ():
@@ -361,17 +356,20 @@ def test_allow_newaxis():
361356
362357def test_disallow_flat_indexing_with_newaxis ():
363358 a = ones ((3 , 3 , 3 ))
364- assert_raises (IndexError , lambda : a [None , 0 , 0 ])
359+ with pytest .raises (IndexError ):
360+ a [None , 0 , 0 ]
365361
366362def test_disallow_mask_with_newaxis ():
367363 a = ones ((3 , 3 , 3 ))
368- assert_raises (IndexError , lambda : a [None , asarray (True )])
364+ with pytest .raises (IndexError ):
365+ a [None , asarray (True )]
369366
370367@pytest .mark .parametrize ("shape" , [(), (5 ,), (3 , 3 , 3 )])
371368@pytest .mark .parametrize ("index" , ["string" , False , True ])
372369def test_error_on_invalid_index (shape , index ):
373370 a = ones (shape )
374- assert_raises (IndexError , lambda : a [index ])
371+ with pytest .raises (IndexError ):
372+ a [index ]
375373
376374def test_mask_0d_array_without_errors ():
377375 a = ones (())
@@ -382,8 +380,10 @@ def test_mask_0d_array_without_errors():
382380)
383381def test_error_on_invalid_index_with_ellipsis (i ):
384382 a = ones ((3 , 3 , 3 ))
385- assert_raises (IndexError , lambda : a [..., i ])
386- assert_raises (IndexError , lambda : a [i , ...])
383+ with pytest .raises (IndexError ):
384+ a [..., i ]
385+ with pytest .raises (IndexError ):
386+ a [i , ...]
387387
388388def test_array_keys_use_private_array ():
389389 """
@@ -400,7 +400,8 @@ def test_array_keys_use_private_array():
400400
401401 a = ones ((0 ,), dtype = bool_ )
402402 key = ones ((0 , 0 ), dtype = bool_ )
403- assert_raises (IndexError , lambda : a [key ])
403+ with pytest .raises (IndexError ):
404+ a [key ]
404405
405406def test_array_namespace ():
406407 a = ones ((3 , 3 ))
@@ -421,16 +422,16 @@ def test_array_namespace():
421422 assert a .__array_namespace__ (api_version = "2021.12" ) is array_api_strict
422423 assert array_api_strict .__array_api_version__ == "2021.12"
423424
424- assert_raises (ValueError , lambda : a .__array_namespace__ (api_version = "2021.11" ))
425- assert_raises (ValueError , lambda : a .__array_namespace__ (api_version = "2024.12" ))
425+ pytest . raises (ValueError , lambda : a .__array_namespace__ (api_version = "2021.11" ))
426+ pytest . raises (ValueError , lambda : a .__array_namespace__ (api_version = "2024.12" ))
426427
427428def test_iter ():
428- assert_raises (TypeError , lambda : iter (asarray (3 )))
429+ pytest . raises (TypeError , lambda : iter (asarray (3 )))
429430 assert list (ones (3 )) == [asarray (1. ), asarray (1. ), asarray (1. )]
430431 assert all_ (isinstance (a , Array ) for a in iter (ones (3 )))
431432 assert all_ (a .shape == () for a in iter (ones (3 )))
432433 assert all_ (a .dtype == float64 for a in iter (ones (3 )))
433- assert_raises (TypeError , lambda : iter (ones ((3 , 3 ))))
434+ pytest . raises (TypeError , lambda : iter (ones ((3 , 3 ))))
434435
435436@pytest .mark .parametrize ("api_version" , ['2021.12' , '2022.12' , '2023.12' ])
436437def dlpack_2023_12 (api_version ):
@@ -446,17 +447,17 @@ def dlpack_2023_12(api_version):
446447
447448
448449 exception = NotImplementedError if api_version >= '2023.12' else ValueError
449- assert_raises (exception , lambda :
450+ pytest . raises (exception , lambda :
450451 a .__dlpack__ (dl_device = CPU_DEVICE ))
451- assert_raises (exception , lambda :
452+ pytest . raises (exception , lambda :
452453 a .__dlpack__ (dl_device = None ))
453- assert_raises (exception , lambda :
454+ pytest . raises (exception , lambda :
454455 a .__dlpack__ (max_version = (1 , 0 )))
455- assert_raises (exception , lambda :
456+ pytest . raises (exception , lambda :
456457 a .__dlpack__ (max_version = None ))
457- assert_raises (exception , lambda :
458+ pytest . raises (exception , lambda :
458459 a .__dlpack__ (copy = False ))
459- assert_raises (exception , lambda :
460+ pytest . raises (exception , lambda :
460461 a .__dlpack__ (copy = True ))
461- assert_raises (exception , lambda :
462+ pytest . raises (exception , lambda :
462463 a .__dlpack__ (copy = None ))
0 commit comments