@@ -112,6 +112,7 @@ def test_asarray_cross_library(source_library, target_library, request):
112112
113113 assert is_tgt_type (b ), f"Expected { b } to be a { tgt_lib .ndarray } , but was { type (b )} "
114114
115+
115116@pytest .mark .parametrize ("library" , wrapped_libraries )
116117def test_asarray_copy (library ):
117118 # Note, we have this test here because the test suite currently doesn't
@@ -130,41 +131,57 @@ def test_asarray_copy(library):
130131 else :
131132 supports_copy_false = True
132133
134+ # Tests for copy=True
133135 a = asarray ([1 ])
134136 b = asarray (a , copy = True )
135137 assert is_lib_func (b )
136138 a [0 ] = 0
137139 assert all (b [0 ] == 1 )
138140 assert all (a [0 ] == 0 )
139141
142+ a = asarray ([1 ])
143+ b = asarray (a , copy = True , dtype = a .dtype )
144+ assert is_lib_func (b )
145+ a [0 ] = 0
146+ assert all (b [0 ] == 1 )
147+ assert all (a [0 ] == 0 )
148+
149+ # Tests for copy=False
140150 a = asarray ([1 ])
141151 if supports_copy_false :
142152 b = asarray (a , copy = False )
143153 assert is_lib_func (b )
144154 a [0 ] = 0
145155 assert all (b [0 ] == 0 )
146156 else :
147- pytest .raises (NotImplementedError , lambda : asarray (a , copy = False ))
157+ with pytest .raises (NotImplementedError ):
158+ asarray (a , copy = False )
148159
149160 a = asarray ([1 ])
150161 if supports_copy_false :
151- pytest .raises (ValueError , lambda : asarray ( a , copy = False ,
152- dtype = xp .float64 ) )
162+ with pytest .raises (ValueError ):
163+ asarray ( a , copy = False , dtype = xp .float64 )
153164 else :
154- pytest .raises (NotImplementedError , lambda : asarray (a , copy = False , dtype = xp .float64 ))
165+ with pytest .raises (NotImplementedError ):
166+ asarray (a , copy = False , dtype = xp .float64 )
155167
168+ # Tests for copy=None
169+ # Do not test whether the buffer is shared or not after copy=None.
170+ # A library should have the freedom to alter its behaviour
171+ # without treating it as a breaking change.
156172 a = asarray ([1 ])
157173 b = asarray (a , copy = None )
158174 assert is_lib_func (b )
159175 a [0 ] = 0
160- assert all (b [0 ] == 0 )
176+ assert all (( b [0 ] == 1.0 ) | ( b [ 0 ] == 0.0 ) )
161177
162178 a = asarray ([1.0 ], dtype = xp .float32 )
163179 assert a .dtype == xp .float32
164180 b = asarray (a , dtype = xp .float64 , copy = None )
165181 assert is_lib_func (b )
166182 assert b .dtype == xp .float64
167183 a [0 ] = 0.0
184+ # dtype change must always trigger a copy
168185 assert all (b [0 ] == 1.0 )
169186
170187 a = asarray ([1.0 ], dtype = xp .float64 )
@@ -173,16 +190,18 @@ def test_asarray_copy(library):
173190 assert is_lib_func (b )
174191 assert b .dtype == xp .float64
175192 a [0 ] = 0.0
176- assert all (b [0 ] == 0.0 )
193+ assert all (( b [0 ] == 1.0 ) | ( b [ 0 ] == 0.0 ) )
177194
178195 # Python built-in types
179196 for obj in [True , 0 , 0.0 , 0j , [0 ], [[0 ]]]:
180197 asarray (obj , copy = True ) # No error
181198 asarray (obj , copy = None ) # No error
182199 if supports_copy_false :
183- pytest .raises (ValueError , lambda : asarray (obj , copy = False ))
200+ with pytest .raises (ValueError ):
201+ asarray (obj , copy = False )
184202 else :
185- pytest .raises (NotImplementedError , lambda : asarray (obj , copy = False ))
203+ with pytest .raises (NotImplementedError ):
204+ asarray (obj , copy = False )
186205
187206 # Use the standard library array to test the buffer protocol
188207 a = array .array ('f' , [1.0 ])
@@ -198,14 +217,11 @@ def test_asarray_copy(library):
198217 a [0 ] = 0.0
199218 assert all (b [0 ] == 0.0 )
200219 else :
201- pytest .raises (NotImplementedError , lambda : asarray (a , copy = False ))
220+ with pytest .raises (NotImplementedError ):
221+ asarray (a , copy = False )
202222
203223 a = array .array ('f' , [1.0 ])
204224 b = asarray (a , copy = None )
205225 assert is_lib_func (b )
206226 a [0 ] = 0.0
207- if library == 'cupy' :
208- # A copy is required for libraries where the default device is not CPU
209- assert all (b [0 ] == 1.0 )
210- else :
211- assert all (b [0 ] == 0.0 )
227+ assert all ((b [0 ] == 1.0 ) | (b [0 ] == 0.0 ))
0 commit comments