@@ -336,6 +336,7 @@ def test_xp(self, xp: ModuleType):
336336class TestIsClose :
337337 # FIXME use lazywhere to avoid warnings on inf
338338 @pytest .mark .filterwarnings ("ignore:invalid value encountered" )
339+ @pytest .mark .parametrize ("swap" , [False , True ])
339340 @pytest .mark .parametrize (
340341 ("a" , "b" ),
341342 [
@@ -353,9 +354,9 @@ class TestIsClose:
353354 (float ("inf" ), float ("inf" )),
354355 (float ("inf" ), 100.0 ),
355356 (float ("inf" ), float ("-inf" )),
357+ (float ("-inf" ), float ("-inf" )),
356358 (float ("nan" ), float ("nan" )),
357- (float ("nan" ), 0.0 ),
358- (0.0 , float ("nan" )),
359+ (float ("nan" ), 100.0 ),
359360 (1e6 , 1e6 + 1 ), # True - within rtol
360361 (1e6 , 1e6 + 100 ), # False - outside rtol
361362 (1e-6 , 1.1e-6 ), # False - outside atol
@@ -364,19 +365,20 @@ class TestIsClose:
364365 (1e6 + 0j , 1e6 + 100j ), # False - outside rtol
365366 ],
366367 )
367- def test_basic (self , a : float , b : float , xp : ModuleType ):
368+ def test_basic (self , a : float , b : float , swap : bool , xp : ModuleType ):
369+ if swap :
370+ b , a = a , b
368371 a_xp = xp .asarray (a )
369372 b_xp = xp .asarray (b )
370373
371374 xp_assert_equal (isclose (a_xp , b_xp ), xp .asarray (np .isclose (a , b )))
372375
373376 with warnings .catch_warnings ():
374377 warnings .simplefilter ("ignore" )
375- r_xp = xp .asarray (np .arange (10 ), dtype = a_xp .dtype )
376- ar_xp = a_xp * r_xp
377- br_xp = b_xp * r_xp
378378 ar_np = a * np .arange (10 )
379379 br_np = b * np .arange (10 )
380+ ar_xp = xp .asarray (ar_np )
381+ br_xp = xp .asarray (br_np )
380382
381383 xp_assert_equal (isclose (ar_xp , br_xp ), xp .asarray (np .isclose (ar_np , br_np )))
382384
@@ -395,14 +397,14 @@ def test_broadcast(self, dtype: str, xp: ModuleType):
395397 # FIXME use lazywhere to avoid warnings on inf
396398 @pytest .mark .filterwarnings ("ignore:invalid value encountered" )
397399 def test_some_inf (self , xp : ModuleType ):
398- a = xp .asarray ([0.0 , 1.0 , float ( " inf" ), float ( " inf" ), float ( " inf" ) ])
399- b = xp .asarray ([1e-9 , 1.0 , float ( " inf" ), float ( "- inf" ) , 2.0 ])
400+ a = xp .asarray ([0.0 , 1.0 , xp . inf , xp . inf , xp . inf ])
401+ b = xp .asarray ([1e-9 , 1.0 , xp . inf , - xp . inf , 2.0 ])
400402 actual = isclose (a , b )
401403 xp_assert_equal (actual , xp .asarray ([True , True , True , False , False ]))
402404
403405 def test_equal_nan (self , xp : ModuleType ):
404- a = xp .asarray ([float ( " nan" ), float ( " nan" ) , 1.0 ])
405- b = xp .asarray ([float ( " nan" ) , 1.0 , float ( " nan" ) ])
406+ a = xp .asarray ([xp . nan , xp . nan , 1.0 ])
407+ b = xp .asarray ([xp . nan , 1.0 , xp . nan ])
406408 xp_assert_equal (isclose (a , b ), xp .asarray ([False , False , False ]))
407409 xp_assert_equal (isclose (a , b , equal_nan = True ), xp .asarray ([True , False , False ]))
408410
0 commit comments