@@ -19,10 +19,18 @@ def non_complex_dtypes():
1919 return xps .boolean_dtypes () | hh .real_dtypes
2020
2121
22+ def numeric_dtypes ():
23+ return xps .boolean_dtypes () | hh .real_dtypes | hh .complex_dtypes
24+
25+
2226def float32 (n : Union [int , float ]) -> float :
2327 return struct .unpack ("!f" , struct .pack ("!f" , float (n )))[0 ]
2428
2529
30+ def _float_match_complex (complex_dtype ):
31+ return xp .float32 if complex_dtype == xp .complex64 else xp .float64
32+
33+
2634@given (
2735 x_dtype = non_complex_dtypes (),
2836 dtype = non_complex_dtypes (),
@@ -107,7 +115,7 @@ def test_broadcast_to(x, data):
107115 # TODO: test values
108116
109117
110- @given (_from = non_complex_dtypes (), to = non_complex_dtypes (), data = st .data ())
118+ @given (_from = numeric_dtypes (), to = numeric_dtypes (), data = st .data ())
111119def test_can_cast (_from , to , data ):
112120 from_ = data .draw (
113121 st .just (_from ) | hh .arrays (dtype = _from , shape = hh .shapes ()), label = "from_"
@@ -127,8 +135,15 @@ def test_can_cast(_from, to, data):
127135 break
128136 assert same_family is not None # sanity check
129137 if same_family :
130- from_min , from_max = dh .dtype_ranges [_from ]
131- to_min , to_max = dh .dtype_ranges [to ]
138+ from_dtype = (_float_match_complex (_from )
139+ if _from in (xp .complex64 , xp .complex128 )
140+ else _from )
141+ to_dtype = (_float_match_complex (to )
142+ if to in (xp .complex64 , xp .complex128 )
143+ else to )
144+
145+ from_min , from_max = dh .dtype_ranges [from_dtype ]
146+ to_min , to_max = dh .dtype_ranges [to_dtype ]
132147 expected = from_min >= to_min and from_max <= to_max
133148 else :
134149 expected = False
@@ -139,6 +154,7 @@ def test_can_cast(_from, to, data):
139154 assert out == expected , f"{ out = } , but should be { expected } { f_func } "
140155
141156
157+
142158@pytest .mark .parametrize ("dtype" , dh .real_float_dtypes )
143159def test_finfo (dtype ):
144160 out = xp .finfo (dtype )
0 commit comments