diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py
index fd6e414605..886364ee55 100644
--- a/pymc/distributions/continuous.py
+++ b/pymc/distributions/continuous.py
@@ -53,6 +53,7 @@
 from pytensor.tensor.random.utils import normalize_size_param
 from pytensor.tensor.variable import TensorConstant, TensorVariable
 
+from pymc.distributions.custom import CustomDist
 from pymc.logprob.abstract import _logprob_helper
 from pymc.logprob.basic import TensorLike, icdf
 from pymc.pytensorf import normalize_rng_param
@@ -92,7 +93,7 @@ def polyagamma_cdf(*args, **kwargs):
 from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous, SymbolicRandomVariable
 from pymc.distributions.shape_utils import implicit_size_from_params, rv_size_is_none
 from pymc.distributions.transforms import _default_transform
-from pymc.math import invlogit, logdiffexp, logit
+from pymc.math import invlogit, logdiffexp
 
 __all__ = [
     "AsymmetricLaplace",
@@ -3603,28 +3604,7 @@ def icdf(value, mu, s):
         )
 
 
-class LogitNormalRV(SymbolicRandomVariable):
-    name = "logit_normal"
-    extended_signature = "[rng],[size],(),()->[rng],()"
-    _print_name = ("LogitNormal", "\\operatorname{LogitNormal}")
-
-    @classmethod
-    def rv_op(cls, mu, sigma, *, size=None, rng=None):
-        mu = pt.as_tensor(mu)
-        sigma = pt.as_tensor(sigma)
-        rng = normalize_rng_param(rng)
-        size = normalize_size_param(size)
-
-        next_rng, normal_draws = normal(loc=mu, scale=sigma, size=size, rng=rng).owner.outputs
-        draws = pt.expit(normal_draws)
-
-        return cls(
-            inputs=[rng, size, mu, sigma],
-            outputs=[next_rng, draws],
-        )(rng, size, mu, sigma)
-
-
-class LogitNormal(UnitContinuous):
+class LogitNormal:
     r"""
     Logit-Normal distribution.
 
@@ -3672,37 +3652,26 @@ class LogitNormal(UnitContinuous):
         Defaults to 1.
     """
 
-    rv_type = LogitNormalRV
-    rv_op = LogitNormalRV.rv_op
+    @staticmethod
+    def logitnormal_dist(mu, sigma, size):
+        return invlogit(Normal.dist(mu=mu, sigma=sigma, size=size))
 
-    @classmethod
-    def dist(cls, mu=0, sigma=None, tau=None, **kwargs):
+    def __new__(cls, name, mu=0, sigma=None, tau=None, **kwargs):
         _, sigma = get_tau_sigma(tau=tau, sigma=sigma)
-        return super().dist([mu, sigma], **kwargs)
-
-    def support_point(rv, size, mu, sigma):
-        median, _ = pt.broadcast_arrays(invlogit(mu), sigma)
-        if not rv_size_is_none(size):
-            median = pt.full(size, median)
-        return median
-
-    def logp(value, mu, sigma):
-        tau, _ = get_tau_sigma(sigma=sigma)
-
-        res = pt.switch(
-            pt.or_(pt.le(value, 0), pt.ge(value, 1)),
-            -np.inf,
-            (
-                -0.5 * tau * (logit(value) - mu) ** 2
-                + 0.5 * pt.log(tau / (2.0 * np.pi))
-                - pt.log(value * (1 - value))
-            ),
+        return CustomDist(
+            name,
+            mu,
+            sigma,
+            dist=cls.logitnormal_dist,
+            class_name="LogitNormal",
+            **kwargs,
         )
 
-        return check_parameters(
-            res,
-            tau > 0,
-            msg="tau > 0",
+    @classmethod
+    def dist(cls, mu=0, sigma=None, tau=None, **kwargs):
+        _, sigma = get_tau_sigma(tau=tau, sigma=sigma)
+        return CustomDist.dist(
+            mu, sigma, dist=cls.logitnormal_dist, class_name="LogitNormal", **kwargs
         )
 
 
diff --git a/pymc/distributions/moments/means.py b/pymc/distributions/moments/means.py
index 0e3129935e..34687a7ba2 100644
--- a/pymc/distributions/moments/means.py
+++ b/pymc/distributions/moments/means.py
@@ -59,7 +59,6 @@
     HalfFlatRV,
     HalfStudentTRV,
     KumaraswamyRV,
-    LogitNormalRV,
     MoyalRV,
     PolyaGammaRV,
     RiceRV,
@@ -290,11 +289,6 @@ def logistic_mean(op, rv, rng, size, mu, s):
     return maybe_resize(pt.broadcast_arrays(mu, s)[0], size)
 
 
-@_mean.register(LogitNormalRV)
-def logitnormal_mean(op, rv, rng, size, mu, sigma):
-    raise UndefinedMomentException("The mean of the LogitNormal distribution is undefined")
-
-
 @_mean.register(LogNormalRV)
 def lognormal_mean(op, rv, rng, size, mu, sigma):
     return maybe_resize(pt.exp(mu + 0.5 * sigma**2), size)
diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py
index 8b5eac7b16..8d2bbacd26 100644
--- a/pymc/logprob/transforms.py
+++ b/pymc/logprob/transforms.py
@@ -178,6 +178,7 @@ class MeasurableTransform(MeasurableElemwise):
         Erf,
         Erfc,
         Erfcx,
+        Sigmoid,
     )
 
     # Cannot use `transform` as name because it would clash with the property added by
@@ -227,7 +228,7 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
     return pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian)
 
 
-MONOTONICALLY_INCREASING_OPS = (Exp, Log, Add, Sinh, Tanh, ArcSinh, ArcCosh, ArcTanh, Erf)
+MONOTONICALLY_INCREASING_OPS = (Exp, Log, Add, Sinh, Tanh, ArcSinh, ArcCosh, ArcTanh, Erf, Sigmoid)
 MONOTONICALLY_DECREASING_OPS = (Erfc, Erfcx)
 
 
@@ -300,7 +301,18 @@ def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs)
         value = pt.switch(pt.lt(scale, 0), 1 - value, value)
     elif isinstance(op.scalar_op, Pow):
         if op.transform_elemwise.power < 0:
-            raise NotImplementedError
+            # Note: Negative even powers will be rejected below when inverting the transform
+            # For the remaining negative powers the function is decreasing with a jump around 0
+            # We adjust the value with the mass below zero.
+            # For non-negative RVs with cdf(0)=0, it simplifies to 1 - value
+            cdf_zero = pt.exp(_logcdf_helper(measurable_input, 0))
+            # Use nan to not mask invalid values accidentally
+            value = pt.switch((value >= 0) & (value <= 1), value, np.nan)
+            value = pt.switch(
+                (cdf_zero > 0) & (value < cdf_zero),
+                cdf_zero - value,
+                1 + cdf_zero - value,
+            )
     else:
         raise NotImplementedError
 
diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py
index 7209382666..e1e9b467d5 100644
--- a/tests/distributions/test_continuous.py
+++ b/tests/distributions/test_continuous.py
@@ -872,6 +872,12 @@ def test_logitnormal(self):
             ),
             decimal=select_by_precision(float64=6, float32=1),
         )
+        check_icdf(
+            pm.LogitNormal,
+            {"mu": R, "sigma": Rplus},
+            lambda q, mu, sigma: sp.expit(mu + sigma * st.norm.ppf(q)),
+            decimal=select_by_precision(float64=12, float32=5),
+        )
 
     @pytest.mark.skipif(
         condition=(pytensor.config.floatX == "float32"),
diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py
index 691c696e8f..c9aeaa8abf 100644
--- a/tests/logprob/test_transforms.py
+++ b/tests/logprob/test_transforms.py
@@ -379,9 +379,7 @@ def test_reciprocal_rv_transform(self, numerator):
         x_vv = x_rv.clone()
         x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv))
         x_logcdf_fn = pytensor.function([x_vv], logcdf(x_rv, x_vv))
-
-        with pytest.raises(NotImplementedError):
-            icdf(x_rv, x_vv)
+        x_icdf_fn = pytensor.function([x_vv], icdf(x_rv, x_vv))
 
         x_test_val = np.r_[-0.5, 1.5]
         np.testing.assert_allclose(
@@ -392,6 +390,10 @@
             x_logcdf_fn(x_test_val),
             sp.stats.invgamma(shape, scale=scale * numerator).logcdf(x_test_val),
         )
+        np.testing.assert_allclose(
+            x_icdf_fn(x_test_val),
+            sp.stats.invgamma(shape, scale=scale * numerator).ppf(x_test_val),
+        )
 
     def test_reciprocal_real_rv_transform(self):
         # 1 / Cauchy(mu, sigma) = Cauchy(mu / (mu^2 + sigma^2), sigma / (mu^2 + sigma^2))
@@ -406,8 +408,10 @@ def test_reciprocal_real_rv_transform(self):
             logcdf(test_rv, test_value).eval(),
             sp.stats.cauchy(1 / 5, 2 / 5).logcdf(test_value),
         )
-        with pytest.raises(NotImplementedError):
-            icdf(test_rv, test_value)
+        np.testing.assert_allclose(
+            icdf(test_rv, test_value).eval(),
+            sp.stats.cauchy(1 / 5, 2 / 5).ppf(test_value),
+        )
 
     def test_sqr_transform(self):
         # The square of a normal with unit variance is a noncentral chi-square with 1 df and nc = mean ** 2
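Below is a minimal sanity-check sketch (not part of the patch) of how the newly derived icdfs can be exercised; it mirrors the comparisons made by the added tests, assumes the patch is applied and pytensor runs in float64, and the variable names and test values are illustrative only:

import numpy as np
from scipy import special, stats

import pymc as pm
from pymc.logprob.basic import icdf

# LogitNormal is now invlogit(Normal) built with CustomDist, so its icdf is derived
# through the measurable Sigmoid transform: expit(mu + sigma * Phi^{-1}(q)).
mu, sigma, q = 0.5, 1.2, 0.3
np.testing.assert_allclose(
    icdf(pm.LogitNormal.dist(mu=mu, sigma=sigma), q).eval(),
    special.expit(mu + sigma * stats.norm.ppf(q)),
    rtol=1e-6,
)

# Negative-power transforms now support icdf as well: 1 / Cauchy(1, 2) is Cauchy(1/5, 2/5),
# with the quantile recovered through the cdf(0) adjustment in measurable_transform_icdf.
reciprocal_rv = 1 / pm.Cauchy.dist(alpha=1, beta=2)
np.testing.assert_allclose(
    icdf(reciprocal_rv, q).eval(),
    stats.cauchy(1 / 5, 2 / 5).ppf(q),
    rtol=1e-6,
)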