@@ -1481,3 +1481,169 @@ def c_code(self, *args, **kwargs):
14811481
14821482
14831483betainc_der = BetaIncDer (upgrade_to_float_no_complex , name = "betainc_der" )
1484+
1485+
1486+ class Hyp2F1 (ScalarOp ):
1487+ """
1488+ Gaussian hypergeometric function ``2F1(a, b; c; z)``.
1489+
1490+ """
1491+
1492+ nin = 4
1493+ nfunc_spec = ("scipy.special.hyp2f1" , 4 , 1 )
1494+
1495+ @staticmethod
1496+ def st_impl (a , b , c , z ):
1497+ return scipy .special .hyp2f1 (a , b , c , z )
1498+
1499+ def impl (self , a , b , c , z ):
1500+ return Hyp2F1 .st_impl (a , b , c , z )
1501+
1502+ def grad (self , inputs , grads ):
1503+ a , b , c , z = inputs
1504+ (gz ,) = grads
1505+ return [
1506+ gz * hyp2f1_der (a , b , c , z , wrt = 0 ),
1507+ gz * hyp2f1_der (a , b , c , z , wrt = 1 ),
1508+ gz * hyp2f1_der (a , b , c , z , wrt = 2 ),
1509+ gz * ((a * b ) / c ) * hyp2f1 (a + 1 , b + 1 , c + 1 , z ),
1510+ ]
1511+
1512+ def c_code (self , * args , ** kwargs ):
1513+ raise NotImplementedError ()
1514+
1515+
1516+ hyp2f1 = Hyp2F1 (upgrade_to_float , name = "hyp2f1" )
1517+
1518+
1519+ class Hyp2F1Der (ScalarOp ):
1520+ """
1521+ Derivatives of the Gaussian Hypergeometric function ``2F1(a, b; c; z)`` with respect to one of the first 3 inputs.
1522+
1523+ Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp
1524+ """
1525+
1526+ nin = 5
1527+
1528+ def impl (self , a , b , c , z , wrt ):
1529+ def check_2f1_converges (a , b , c , z ) -> bool :
1530+ num_terms = 0
1531+ is_polynomial = False
1532+
1533+ def is_nonpositive_integer (x ):
1534+ return x <= 0 and x .is_integer ()
1535+
1536+ if is_nonpositive_integer (a ) and abs (a ) >= num_terms :
1537+ is_polynomial = True
1538+ num_terms = int (np .floor (abs (a )))
1539+ if is_nonpositive_integer (b ) and abs (b ) >= num_terms :
1540+ is_polynomial = True
1541+ num_terms = int (np .floor (abs (b )))
1542+
1543+ is_undefined = is_nonpositive_integer (c ) and abs (c ) <= num_terms
1544+
1545+ return not is_undefined and (
1546+ is_polynomial or np .abs (z ) < 1 or (np .abs (z ) == 1 and c > (a + b ))
1547+ )
1548+
1549+ def compute_grad_2f1 (a , b , c , z , wrt ):
1550+ """
1551+ Notes
1552+ -----
1553+ The algorithm can be derived by looking at the ratio of two successive terms in the series
1554+ β_{k+1}/β_{k} = A(k)/B(k)
1555+ β_{k+1} = A(k)/B(k) * β_{k}
1556+ d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule
1557+
1558+ In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k) / ((c + k) (1 + k))) * z
1559+
1560+ The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
1561+ by dropping the respective term
1562+ d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
1563+ d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
1564+ d/dc[A(k)/B(k)] = A(k)/B(k) * (c + k)
1565+
1566+ The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
1567+ tracking their signs.
1568+ """
1569+
1570+ wrt_a = wrt_b = False
1571+ if wrt == 0 :
1572+ wrt_a = True
1573+ elif wrt == 1 :
1574+ wrt_b = True
1575+ elif wrt != 2 :
1576+ raise ValueError (f"wrt must be 0, 1, or 2, got { wrt } " )
1577+
1578+ min_steps = 10 # https://github.com/stan-dev/math/issues/2857
1579+ max_steps = int (1e6 )
1580+ precision = 1e-14
1581+
1582+ res = 0
1583+
1584+ if z == 0 :
1585+ return res
1586+
1587+ log_g_old = - np .inf
1588+ log_t_old = 0.0
1589+ log_t_new = 0.0
1590+ sign_z = np .sign (z )
1591+ log_z = np .log (np .abs (z ))
1592+
1593+ log_g_old_sign = 1
1594+ log_t_old_sign = 1
1595+ log_t_new_sign = 1
1596+ sign_zk = sign_z
1597+
1598+ for k in range (max_steps ):
1599+ p = (a + k ) * (b + k ) / ((c + k ) * (k + 1 ))
1600+ if p == 0 :
1601+ return res
1602+ log_t_new += np .log (np .abs (p )) + log_z
1603+ log_t_new_sign = np .sign (p ) * log_t_new_sign
1604+
1605+ term = log_g_old_sign * log_t_old_sign * np .exp (log_g_old - log_t_old )
1606+ if wrt_a :
1607+ term += np .reciprocal (a + k )
1608+ elif wrt_b :
1609+ term += np .reciprocal (b + k )
1610+ else :
1611+ term -= np .reciprocal (c + k )
1612+
1613+ log_g_old = log_t_new + np .log (np .abs (term ))
1614+ log_g_old_sign = np .sign (term ) * log_t_new_sign
1615+ g_current = log_g_old_sign * np .exp (log_g_old ) * sign_zk
1616+ res += g_current
1617+
1618+ log_t_old = log_t_new
1619+ log_t_old_sign = log_t_new_sign
1620+ sign_zk *= sign_z
1621+
1622+ if k >= min_steps and np .abs (g_current ) <= precision :
1623+ return res
1624+
1625+ warnings .warn (
1626+ f"hyp2f1_der did not converge after { k } iterations" ,
1627+ RuntimeWarning ,
1628+ )
1629+ return np .nan
1630+
1631+ # TODO: We could implement the Euler transform to expand supported domain, as Stan does
1632+ if not check_2f1_converges (a , b , c , z ):
1633+ warnings .warn (
1634+ f"Hyp2F1 does not meet convergence conditions with given arguments a={ a } , b={ b } , c={ c } , z={ z } " ,
1635+ RuntimeWarning ,
1636+ )
1637+ return np .nan
1638+
1639+ return compute_grad_2f1 (a , b , c , z , wrt = wrt )
1640+
1641+ def __call__ (self , a , b , c , z , wrt ):
1642+ # This allows wrt to be a keyword argument
1643+ return super ().__call__ (a , b , c , z , wrt )
1644+
1645+ def c_code (self , * args , ** kwargs ):
1646+ raise NotImplementedError ()
1647+
1648+
1649+ hyp2f1_der = Hyp2F1Der (upgrade_to_float , name = "hyp2f1_der" )
0 commit comments