Skip to content

Commit f9cb757

Browse files
jnetzel1ricardoV94
andcommitted
Use automatic logprob for LogitNormal and enable icdf
Co-authored-by: Ricardo Vieira <28983449+ricardov94@users.noreply.github.com>
1 parent 41f33ad commit f9cb757

File tree

3 files changed

+43
-54
lines changed

3 files changed

+43
-54
lines changed

pymc/distributions/continuous.py

Lines changed: 19 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from pytensor.tensor.random.utils import normalize_size_param
5454
from pytensor.tensor.variable import TensorConstant, TensorVariable
5555

56+
from pymc.distributions.custom import CustomDist
5657
from pymc.logprob.abstract import _logprob_helper
5758
from pymc.logprob.basic import TensorLike, icdf
5859
from pymc.pytensorf import normalize_rng_param
@@ -92,7 +93,7 @@ def polyagamma_cdf(*args, **kwargs):
9293
from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous, SymbolicRandomVariable
9394
from pymc.distributions.shape_utils import implicit_size_from_params, rv_size_is_none
9495
from pymc.distributions.transforms import _default_transform
95-
from pymc.math import invlogit, logdiffexp, logit
96+
from pymc.math import invlogit, logdiffexp
9697

9798
__all__ = [
9899
"AsymmetricLaplace",
@@ -3603,28 +3604,7 @@ def icdf(value, mu, s):
36033604
)
36043605

36053606

3606-
class LogitNormalRV(SymbolicRandomVariable):
3607-
name = "logit_normal"
3608-
extended_signature = "[rng],[size],(),()->[rng],()"
3609-
_print_name = ("LogitNormal", "\\operatorname{LogitNormal}")
3610-
3611-
@classmethod
3612-
def rv_op(cls, mu, sigma, *, size=None, rng=None):
3613-
mu = pt.as_tensor(mu)
3614-
sigma = pt.as_tensor(sigma)
3615-
rng = normalize_rng_param(rng)
3616-
size = normalize_size_param(size)
3617-
3618-
next_rng, normal_draws = normal(loc=mu, scale=sigma, size=size, rng=rng).owner.outputs
3619-
draws = pt.expit(normal_draws)
3620-
3621-
return cls(
3622-
inputs=[rng, size, mu, sigma],
3623-
outputs=[next_rng, draws],
3624-
)(rng, size, mu, sigma)
3625-
3626-
3627-
class LogitNormal(UnitContinuous):
3607+
class LogitNormal:
36283608
r"""
36293609
Logit-Normal distribution.
36303610
@@ -3672,37 +3652,26 @@ class LogitNormal(UnitContinuous):
36723652
Defaults to 1.
36733653
"""
36743654

3675-
rv_type = LogitNormalRV
3676-
rv_op = LogitNormalRV.rv_op
3655+
@staticmethod
3656+
def logitnormal_dist(mu, sigma, size):
3657+
return invlogit(Normal.dist(mu=mu, sigma=sigma, size=size))
36773658

3678-
@classmethod
3679-
def dist(cls, mu=0, sigma=None, tau=None, **kwargs):
3659+
def __new__(cls, name, mu=0, sigma=None, tau=None, **kwargs):
36803660
_, sigma = get_tau_sigma(tau=tau, sigma=sigma)
3681-
return super().dist([mu, sigma], **kwargs)
3682-
3683-
def support_point(rv, size, mu, sigma):
3684-
median, _ = pt.broadcast_arrays(invlogit(mu), sigma)
3685-
if not rv_size_is_none(size):
3686-
median = pt.full(size, median)
3687-
return median
3688-
3689-
def logp(value, mu, sigma):
3690-
tau, _ = get_tau_sigma(sigma=sigma)
3691-
3692-
res = pt.switch(
3693-
pt.or_(pt.le(value, 0), pt.ge(value, 1)),
3694-
-np.inf,
3695-
(
3696-
-0.5 * tau * (logit(value) - mu) ** 2
3697-
+ 0.5 * pt.log(tau / (2.0 * np.pi))
3698-
- pt.log(value * (1 - value))
3699-
),
3661+
return CustomDist(
3662+
name,
3663+
mu,
3664+
sigma,
3665+
dist=cls.logitnormal_dist,
3666+
class_name="LogitNormal",
3667+
**kwargs,
37003668
)
37013669

3702-
return check_parameters(
3703-
res,
3704-
tau > 0,
3705-
msg="tau > 0",
3670+
@classmethod
3671+
def dist(cls, mu=0, sigma=None, tau=None, **kwargs):
3672+
_, sigma = get_tau_sigma(tau=tau, sigma=sigma)
3673+
return CustomDist.dist(
3674+
mu, sigma, dist=cls.logitnormal_dist, class_name="LogitNormal", **kwargs
37063675
)
37073676

37083677

pymc/logprob/transforms.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ class MeasurableTransform(MeasurableElemwise):
178178
Erf,
179179
Erfc,
180180
Erfcx,
181+
Sigmoid,
181182
)
182183

183184
# Cannot use `transform` as name because it would clash with the property added by
@@ -227,7 +228,7 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
227228
return pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian)
228229

229230

230-
MONOTONICALLY_INCREASING_OPS = (Exp, Log, Add, Sinh, Tanh, ArcSinh, ArcCosh, ArcTanh, Erf)
231+
MONOTONICALLY_INCREASING_OPS = (Exp, Log, Add, Sinh, Tanh, ArcSinh, ArcCosh, ArcTanh, Erf, Sigmoid)
231232
MONOTONICALLY_DECREASING_OPS = (Erfc, Erfcx)
232233

233234

@@ -401,7 +402,7 @@ def measurable_special_log_to_log(fgraph, node):
401402
return [pt.log(inp) / pt.log(10)]
402403

403404

404-
@node_rewriter([expm1, sigmoid, exp2])
405+
@node_rewriter([expm1, exp2])
405406
def measurable_special_exp_to_exp(fgraph, node):
406407
"""Convert expm1, sigmoid, and exp2 of `MeasurableVariable`s to xp form."""
407408
if not filter_measurable_variables(node.inputs):
@@ -412,8 +413,6 @@ def measurable_special_exp_to_exp(fgraph, node):
412413
return [pt.exp(pt.log(2) * inp)]
413414
if isinstance(node.op.scalar_op, Expm1):
414415
return [pt.add(pt.exp(inp), -1)]
415-
if isinstance(node.op.scalar_op, Sigmoid):
416-
return [1 / (1 + pt.exp(-inp))]
417416

418417

419418
@node_rewriter([pow])
@@ -451,6 +450,7 @@ def measurable_power_exponent_to_exp(fgraph, node):
451450
erf,
452451
erfc,
453452
erfcx,
453+
sigmoid,
454454
]
455455
)
456456
def find_measurable_transforms(fgraph: FunctionGraph, node: Apply) -> list[Variable] | None:
@@ -526,6 +526,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Apply) -> list[Varia
526526
Erf: ErfTransform,
527527
Erfc: ErfcTransform,
528528
Erfcx: ErfcxTransform,
529+
Sigmoid: SigmoidTransform,
529530
}[type(scalar_op)]()
530531

531532
transform_op = MeasurableTransform(
@@ -968,6 +969,19 @@ def log_jac_det(self, value, *inputs):
968969
return pt.log(sigmoid_value) + pt.log1p(-sigmoid_value)
969970

970971

972+
class SigmoidTransform(Transform):
973+
name = "sigmoid"
974+
975+
def forward(self, value, *inputs):
976+
return sigmoid(value) # AKA invlogit/expit
977+
978+
def backward(self, value, *inputs):
979+
return pt.log(value) - pt.log1p(-value)
980+
981+
def log_jac_det(self, value, *inputs):
982+
return -pt.log(value) - pt.log1p(-value)
983+
984+
971985
class SimplexTransform(Transform):
972986
name = "simplex"
973987

tests/distributions/test_continuous.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,12 @@ def test_logitnormal(self):
872872
),
873873
decimal=select_by_precision(float64=6, float32=1),
874874
)
875+
check_icdf(
876+
pm.LogitNormal,
877+
{"mu": R, "sigma": Rplus},
878+
lambda q, mu, sigma: sp.expit(mu + sigma * st.norm.ppf(q)),
879+
decimal=select_by_precision(float64=12, float32=5),
880+
)
875881

876882
@pytest.mark.skipif(
877883
condition=(pytensor.config.floatX == "float32"),

0 commit comments

Comments
 (0)