@@ -448,27 +448,49 @@ def test_iter():
448448 pytest .raises (TypeError , lambda : iter (ones ((3 , 3 ))))
449449
450450@pytest .mark .parametrize ("api_version" , ['2021.12' , '2022.12' , '2023.12' ])
451- def dlpack_2023_12 (api_version ):
451+ def test_dlpack_2023_12 (api_version ):
452452 if api_version == '2021.12' :
453453 with pytest .warns (UserWarning ):
454454 set_array_api_strict_flags (api_version = api_version )
455455 else :
456456 set_array_api_strict_flags (api_version = api_version )
457457
458458 a = asarray ([1 , 2 , 3 ], dtype = int8 )
459-
460- # Do not error
459+ # Never an error
461460 a .__dlpack__ ()
462- a .__dlpack__ (dl_device = CPU_DEVICE )
463- a .__dlpack__ (dl_device = None )
464- a .__dlpack__ (max_version = (1 , 0 ))
465- a .__dlpack__ (max_version = None )
466- a .__dlpack__ (copy = False )
467- a .__dlpack__ (copy = True )
468- a .__dlpack__ (copy = None )
469-
470- x = np .from_dlpack (a )
471- assert isinstance (x , np .ndarray )
472- assert x .dtype == np .int8
473- assert x .shape == (3 ,)
474- assert np .all (x == np .asarray ([1 , 2 , 3 ]))
461+
462+ if api_version < '2023.12' :
463+ pytest .raises (ValueError , lambda :
464+ a .__dlpack__ (dl_device = a .__dlpack_device__ ()))
465+ pytest .raises (ValueError , lambda :
466+ a .__dlpack__ (dl_device = None ))
467+ pytest .raises (ValueError , lambda :
468+ a .__dlpack__ (max_version = (1 , 0 )))
469+ pytest .raises (ValueError , lambda :
470+ a .__dlpack__ (max_version = None ))
471+ pytest .raises (ValueError , lambda :
472+ a .__dlpack__ (copy = False ))
473+ pytest .raises (ValueError , lambda :
474+ a .__dlpack__ (copy = True ))
475+ pytest .raises (ValueError , lambda :
476+ a .__dlpack__ (copy = None ))
477+ elif np .lib .NumpyVersion (np .__version__ ) < '2.1.0' :
478+ pytest .raises (NotImplementedError , lambda :
479+ a .__dlpack__ (dl_device = CPU_DEVICE ))
480+ a .__dlpack__ (dl_device = None )
481+ pytest .raises (NotImplementedError , lambda :
482+ a .__dlpack__ (max_version = (1 , 0 )))
483+ a .__dlpack__ (max_version = None )
484+ pytest .raises (NotImplementedError , lambda :
485+ a .__dlpack__ (copy = False ))
486+ pytest .raises (NotImplementedError , lambda :
487+ a .__dlpack__ (copy = True ))
488+ a .__dlpack__ (copy = None )
489+ else :
490+ a .__dlpack__ (dl_device = a .__dlpack_device__ ())
491+ a .__dlpack__ (dl_device = None )
492+ a .__dlpack__ (max_version = (1 , 0 ))
493+ a .__dlpack__ (max_version = None )
494+ a .__dlpack__ (copy = False )
495+ a .__dlpack__ (copy = True )
496+ a .__dlpack__ (copy = None )
0 commit comments