1212from . import pytest_helpers as ph
1313from . import shape_helpers as sh
1414from . import xps
15+ from .test_operators_and_elementwise_functions import oneway_promotable_dtypes
1516from .typing import DataType , Scalar
1617
1718pytestmark = pytest .mark .ci
@@ -245,11 +246,25 @@ def test_asarray_scalars(shape, data):
245246 ph .assert_scalar_equals ("asarray" , scalar_type , idx , v , v_expect , ** kw )
246247
247248
248- @given (xps .arrays (dtype = xps .scalar_dtypes (), shape = hh .shapes ()), st .data ())
249- def test_asarray_arrays (x , data ):
250- # TODO: test other valid dtypes
249+ def scalar_eq (s1 : Scalar , s2 : Scalar ) -> bool :
250+ if math .isnan (s1 ):
251+ return math .isnan (s2 )
252+ else :
253+ return s1 == s2
254+
255+
256+ @given (
257+ shape = hh .shapes (),
258+ dtypes = oneway_promotable_dtypes (dh .all_dtypes ),
259+ data = st .data (),
260+ )
261+ def test_asarray_arrays (shape , dtypes , data ):
262+ x = data .draw (xps .arrays (dtype = dtypes .input_dtype , shape = shape ), label = "x" )
263+ dtypes_strat = st .just (dtypes .input_dtype )
264+ if dtypes .input_dtype == dtypes .result_dtype :
265+ dtypes_strat |= st .none ()
251266 kw = data .draw (
252- hh .kwargs (dtype = st . none () | st . just ( x . dtype ) , copy = st .none () | st .booleans ()),
267+ hh .kwargs (dtype = dtypes_strat , copy = st .none () | st .booleans ()),
253268 label = "kw" ,
254269 )
255270
@@ -261,27 +276,35 @@ def test_asarray_arrays(x, data):
261276 else :
262277 ph .assert_kw_dtype ("asarray" , dtype , out .dtype )
263278 ph .assert_shape ("asarray" , out .shape , x .shape )
264- if dtype is None or dtype == x .dtype :
265- ph .assert_array_elements ("asarray" , out , x , ** kw )
266- else :
267- pass # TODO
279+ ph .assert_array_elements ("asarray" , out , x , ** kw )
268280 copy = kw .get ("copy" , None )
269281 if copy is not None :
282+ stype = dh .get_scalar_type (x .dtype )
270283 idx = data .draw (xps .indices (x .shape , max_dims = 0 ), label = "mutating idx" )
271- _dtype = x .dtype if dtype is None else dtype
272- old_value = x [idx ]
284+ old_value = stype (x [idx ])
285+ scalar_strat = xps .from_dtype (dtypes .input_dtype ).filter (
286+ lambda n : not scalar_eq (n , old_value )
287+ )
273288 value = data .draw (
274- xps . arrays ( dtype = _dtype , shape = ()). filter (lambda y : y != old_value ),
289+ scalar_strat | scalar_strat . map (lambda n : xp . asarray ( n , dtype = x . dtype ) ),
275290 label = "mutating value" ,
276291 )
277292 x [idx ] = value
278293 note (f"mutated { x = } " )
294+ # sanity check
295+ ph .assert_scalar_equals (
296+ "__setitem__" , stype , idx , stype (x [idx ]), value , repr_name = "x"
297+ )
298+ new_out_value = stype (out [idx ])
299+ f_out = f"{ sh .fmt_idx ('out' , idx )} ={ new_out_value } "
279300 if copy :
280- assert not xp .all (
281- out == x
282- ), f"xp.all(out == x)=True, but should be False after x was mutated\n { out = } "
283- elif copy is False :
284- pass # TODO
301+ assert scalar_eq (
302+ new_out_value , old_value
303+ ), f"{ f_out } , but should be { old_value } even after x was mutated"
304+ else :
305+ assert scalar_eq (
306+ new_out_value , value
307+ ), f"{ f_out } , but should be { value } after x was mutated"
285308
286309
287310@given (hh .shapes (), hh .kwargs (dtype = st .none () | hh .shared_dtypes ))
0 commit comments