@@ -1543,6 +1543,12 @@ def test_basic(self, xp: ModuleType):
15431543 expect = xp .asarray (3.0 , dtype = xp .float64 )
15441544 xp_assert_close (actual , expect )
15451545
1546+ def test_xp (self , xp : ModuleType ):
1547+ x = xp .asarray ([1 , 2 , 3 , 4 , 5 ])
1548+ actual = quantile (x , 0.5 , xp = xp )
1549+ expect = xp .asarray (3.0 , dtype = xp .float64 )
1550+ xp_assert_close (actual , expect )
1551+
15461552 def test_multiple_quantiles (self , xp : ModuleType ):
15471553 x = xp .asarray ([1 , 2 , 3 , 4 , 5 ])
15481554 actual = quantile (x , xp .asarray ([0.25 , 0.5 , 0.75 ]))
@@ -1729,15 +1735,49 @@ def test_invalid_q(self, xp: ModuleType):
17291735 ):
17301736 _ = quantile (x , - 0.5 )
17311737
1738+ def test_invalid_shape (self , xp : ModuleType ):
1739+ with pytest .raises (TypeError , match = "at least 1-dimensional" ):
1740+ _ = quantile (xp .asarray (3.0 ), 0.5 )
1741+ with pytest .raises (ValueError , match = "not compatible with the dimension" ):
1742+ _ = quantile (xp .asarray ([3.0 ]), 0.5 , axis = 1 )
1743+ # with weights:
1744+ method = "inverted_cdf"
1745+ shape = (2 , 3 , 4 )
1746+ with pytest .raises (ValueError , match = "dimension of `a` must be 1 or 2" ):
1747+ _ = quantile (
1748+ xp .ones (shape ), 0.5 , axis = 1 , weights = xp .ones (shape ), method = method
1749+ )
1750+ with pytest .raises (TypeError , match = "Axis must be specified" ):
1751+ _ = quantile (xp .ones ((2 , 3 )), 0.5 , weights = xp .ones (3 ), method = method )
1752+ with pytest .raises (ValueError , match = "Shape of weights must be consistent" ):
1753+ _ = quantile (
1754+ xp .ones ((2 , 3 )), 0.5 , axis = 0 , weights = xp .ones (3 ), method = method
1755+ )
1756+ with pytest .raises (ValueError , match = "Axis must be specified" ):
1757+ _ = quantile (xp .ones ((2 , 3 )), 0.5 , weights = xp .ones ((2 , 3 )), method = method )
1758+
1759+ def test_invalid_dtype (self , xp : ModuleType ):
1760+ with pytest .raises (ValueError , match = "`a` must have real dtype" ):
1761+ _ = quantile (xp .ones (5 , dtype = xp .bool ), 0.5 )
1762+
1763+ with pytest .raises (ValueError , match = "`q` must have real floating dtype" ):
1764+ _ = quantile (xp .ones (5 ), xp .asarray ([0 , 1 ]))
1765+
1766+ def test_invalid_method (self , xp : ModuleType ):
1767+ with pytest .raises (ValueError , match = "`method` must be one of" ):
1768+ _ = quantile (xp .ones (5 ), 0.5 , method = "invalid" )
1769+ # TODO: with weights?
1770+
1771+ def test_invalid_nan_policy (self , xp : ModuleType ):
1772+ with pytest .raises (ValueError , match = "`nan_policy` must be one of" ):
1773+ _ = quantile (xp .ones (5 ), 0.5 , nan_policy = "invalid" )
1774+
1775+ with pytest .raises (ValueError , match = "must be 'propagate'" ):
1776+ _ = quantile (xp .ones (5 ), 0.5 , nan_policy = "omit" )
1777+
17321778 def test_device (self , xp : ModuleType , device : Device ):
17331779 if hasattr (device , "type" ) and device .type == "meta" : # pyright: ignore[reportAttributeAccessIssue]
17341780 pytest .xfail ("No Tensor.item() on meta device" )
17351781 x = xp .asarray ([1 , 2 , 3 , 4 , 5 ], device = device )
17361782 actual = quantile (x , 0.5 )
17371783 assert get_device (actual ) == device
1738-
1739- def test_xp (self , xp : ModuleType ):
1740- x = xp .asarray ([1 , 2 , 3 , 4 , 5 ])
1741- actual = quantile (x , 0.5 , xp = xp )
1742- expect = xp .asarray (3.0 , dtype = xp .float64 )
1743- xp_assert_close (actual , expect )
0 commit comments