@@ -448,7 +448,7 @@ def test_iter():
448448 pytest .raises (TypeError , lambda : iter (ones ((3 , 3 ))))
449449
450450@pytest .mark .parametrize ("api_version" , ['2021.12' , '2022.12' , '2023.12' ])
451- def dlpack_2023_12 (api_version ):
451+ def test_dlpack_2023_12 (api_version ):
452452 if api_version == '2021.12' :
453453 with pytest .warns (UserWarning ):
454454 set_array_api_strict_flags (api_version = api_version )
@@ -459,25 +459,35 @@ def dlpack_2023_12(api_version):
459459 # Never an error
460460 a .__dlpack__ ()
461461
462-
463- if np .__version__ < '2.1' :
464- exception = NotImplementedError if api_version >= '2023.12' else ValueError
465- pytest .raises (exception , lambda :
466- a .__dlpack__ (dl_device = CPU_DEVICE ))
467- pytest .raises (exception , lambda :
462+ if api_version < '2023.12' :
463+ pytest .raises (ValueError , lambda :
464+ a .__dlpack__ (dl_device = a .__dlpack_device__ ()))
465+ pytest .raises (ValueError , lambda :
468466 a .__dlpack__ (dl_device = None ))
469- pytest .raises (exception , lambda :
467+ pytest .raises (ValueError , lambda :
470468 a .__dlpack__ (max_version = (1 , 0 )))
471- pytest .raises (exception , lambda :
469+ pytest .raises (ValueError , lambda :
472470 a .__dlpack__ (max_version = None ))
473- pytest .raises (exception , lambda :
474- a .__dlpack__ (copy = False ))
475- pytest .raises (exception , lambda :
476- a .__dlpack__ (copy = True ))
477- pytest .raises (exception , lambda :
478- a .__dlpack__ (copy = None ))
471+ pytest .raises (ValueError , lambda :
472+ a .__dlpack__ (copy = False ))
473+ pytest .raises (ValueError , lambda :
474+ a .__dlpack__ (copy = True ))
475+ pytest .raises (ValueError , lambda :
476+ a .__dlpack__ (copy = None ))
477+ elif np .lib .NumpyVersion (np .__version__ ) < '2.1.0' :
478+ pytest .raises (NotImplementedError , lambda :
479+ a .__dlpack__ (dl_device = CPU_DEVICE ))
480+ a .__dlpack__ (dl_device = None )
481+ pytest .raises (NotImplementedError , lambda :
482+ a .__dlpack__ (max_version = (1 , 0 )))
483+ a .__dlpack__ (max_version = None )
484+ pytest .raises (NotImplementedError , lambda :
485+ a .__dlpack__ (copy = False ))
486+ pytest .raises (NotImplementedError , lambda :
487+ a .__dlpack__ (copy = True ))
488+ a .__dlpack__ (copy = None )
479489 else :
480- a .__dlpack__ (dl_device = CPU_DEVICE )
490+ a .__dlpack__ (dl_device = a . __dlpack_device__ () )
481491 a .__dlpack__ (dl_device = None )
482492 a .__dlpack__ (max_version = (1 , 0 ))
483493 a .__dlpack__ (max_version = None )
0 commit comments