Added ZeroSumNormal Distribution #4776
@@ -65,6 +65,7 @@
     "Lognormal",
     "ChiSquared",
     "HalfNormal",
+    "ZeroSumNormal",
     "Wald",
     "Pareto",
     "InverseGamma",

@@ -924,6 +925,67 @@ def logcdf(self, value):
         )

+
+class ZeroSumNormal(Continuous):
+    def __init__(self, sigma=1, zerosum_dims=None, zerosum_axes=None, **kwargs):
+        shape = kwargs.get("shape", ())
+        dims = kwargs.get("dims", None)
+        if isinstance(shape, int):
+            shape = (shape,)
+
+        if isinstance(dims, str):
+            dims = (dims,)
+
+        self.mu = self.median = self.mode = tt.zeros(shape)
+        self.sigma = tt.as_tensor_variable(sigma)
+
+        if zerosum_dims is None and zerosum_axes is None:
+            if shape:
+                zerosum_axes = (-1,)
+            else:
+                zerosum_axes = ()
+
+        if isinstance(zerosum_axes, int):
+            zerosum_axes = (zerosum_axes,)
+
+        if isinstance(zerosum_dims, str):
+            zerosum_dims = (zerosum_dims,)
+
+        if zerosum_axes is not None and zerosum_dims is not None:
+            raise ValueError("Only one of zerosum_axes and zerosum_dims can be specified.")
+
+        if zerosum_dims is not None:
+            if dims is None:
+                raise ValueError("zerosum_dims can only be used with the dims kwargs.")
+            zerosum_axes = []
+            for dim in zerosum_dims:
+                zerosum_axes.append(dims.index(dim))
+        self.zerosum_axes = [a if a >= 0 else len(shape) + a for a in zerosum_axes]
Member
Enforcing positive axes here leads to problems when you draw samples from the prior predictive. It's better to replace this line with this:

Suggested change
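Not the suggested replacement itself, but a minimal sketch of the prior-predictive issue the comment describes (all shapes and numbers here are hypothetical):

```python
import numpy as np

# An RV of shape (3, 4) whose last axis should sum to zero.
rv = np.random.randn(3, 4)
rv = rv - rv.mean(axis=-1, keepdims=True)  # zero-sum along axis -1 (== axis 1 here)

# Prior predictive sampling prepends a draws dimension, which shifts positive axis indices.
draws = np.broadcast_to(rv, (500, 3, 4))

print(np.allclose(draws.mean(axis=1), 0))   # False: positive index 1 now points at the wrong axis
print(np.allclose(draws.mean(axis=-1), 0))  # True: a negative index still hits the constrained axis
```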
+
+        if "transform" not in kwargs or kwargs["transform"] == None:
+            kwargs["transform"] = transforms.ZeroSumTransform(zerosum_axes)
+
+        super().__init__(**kwargs)
+
+    def logp(self, value):
+        return Normal.dist(sigma=self.sigma).logp(value)
Suggested change
-        return Normal.dist(sigma=self.sigma).logp(value)
+        zerosums = [tt.all(tt.abs_(tt.mean(x, axis=axis)) <= 1e-9) for axis in self.zerosum_axes]
+        return bound(
+            pm.Normal.dist(sigma=self.sigma).logp(x),
+            tt.all(self.sigma > 0),
+            broadcast_conditions=False,
+            *zerosums,
+        )
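For orientation, a rough sketch of how the distribution added in this diff might be used, assuming the sigma, dims, zerosum_dims, and zerosum_axes keywords shown above (the model, coords, and variable names are made up):

```python
import pymc3 as pm

# Hypothetical model: one effect per group, constrained to sum to zero
# across the "group" dimension.
coords = {"group": ["a", "b", "c", "d"]}

with pm.Model(coords=coords) as model:
    group_effect = pm.ZeroSumNormal(
        "group_effect", sigma=1.0, dims="group", zerosum_dims="group"
    )
    # Equivalently, the constraint can be given by axis instead of by dim name:
    # group_effect = pm.ZeroSumNormal("group_effect", sigma=1.0, shape=4, zerosum_axes=-1)
```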
I came across this wiki section that talks about the degenerate MvNormal case (which is what we have with the ZeroSumNormal). We could use that formula as the expected logp value and test if the logp that we are using in the distribution matches it. The expected logp would look something like this:

def pseudo_log_det(A, tol=1e-13):
    v, w = np.linalg.eigh(A)
    return np.sum(np.log(np.where(np.abs(v) >= tol, v, 1)), axis=-1)

def logp(value, sigma):
    n = value.shape[-1]
    cov = np.asarray(sigma)[..., None, None]**2 * (np.eye(n) - np.ones((n, n)) / n)
    psdet = 0.5 * pseudo_log_det(2 * np.pi * cov)
    exp = 0.5 * (value[..., None, :] @ np.linalg.pinv(cov) @ value[..., None])[..., 0, 0]
    return np.where(np.abs(np.sum(value, axis=-1)) < 1e-9, -psdet - exp, -np.inf)
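A quick, hypothetical way to exercise this expected logp (it assumes the two functions above plus import numpy as np; the values are arbitrary):

```python
import numpy as np

value = np.array([0.5, -1.0, 0.25, 0.25])  # sums to zero
print(logp(value, sigma=1.0))               # finite log density
print(logp(value + 1.0, sigma=1.0))         # -inf: the zero-sum constraint is violated
```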
I ran a few tests with the logp, and it looks like the logp that we are using in this PR doesn't match what one would expect from a degenerate multivariate normal distribution. In my comment above, I posted what a degenerate MvNormal logp looks like. For this particular problem, where we know that we have only one eigenvector with zero eigenvalue, we can re-write the logp as:

def logp(value, sigma):
    n = value.shape[-1]
    cov = np.asarray(sigma)[..., None, None]**2 * (np.eye(n) - np.ones((n, n)) / n)
    v, w = np.linalg.eigh(cov)
    psdet = 0.5 * (np.sum(np.log(v[..., 1:])) + (n - 1) * np.log(2 * np.pi))
    cov_pinv = w[:, 1:] @ np.diag(1 / v[1:]) @ w[:, 1:].T
    exp = 0.5 * (value[..., None, :] @ cov_pinv @ value[..., None])[..., 0, 0]
    return np.where(np.abs(np.sum(value, axis=-1)) < 1e-9, -psdet - exp, -np.inf)

This is different from the logp that we are currently using in this PR. The difference is in the normalization constant: psdet = 0.5 * (np.sum(np.log(v[..., 1:])) + (n - 1) * np.log(2 * np.pi)). In particular, since all eigenvalues v except the first one are the same and equal to sigma**2, psdet = (n - 1) * (0.5 * np.log(2 * np.pi) + np.log(sigma)). Whereas with the assumed pm.Normal.dist(sigma=self.sigma).logp(x), the normalization factor we are getting is psdet = n * (0.5 * np.log(2 * np.pi) + np.log(sigma)). This means that we have to multiply the logp that we are using by (n-1)/n (in the case where only one axis sums to zero) to get the correct log probability density. I'll check what happens when more than one axis has to sum to zero.
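To make the difference in normalization constants concrete, a small numeric check that just follows the two formulas from the comment above (n and sigma are arbitrary illustrative values):

```python
import numpy as np

n, sigma = 4, 1.5

# Constant from the degenerate MvNormal: n - 1 nonzero eigenvalues, each equal to sigma**2.
psdet_degenerate = (n - 1) * (0.5 * np.log(2 * np.pi) + np.log(sigma))

# Constant implied by summing n iid Normal log densities, as the PR's logp currently does.
psdet_iid = n * (0.5 * np.log(2 * np.pi) + np.log(sigma))

# The two differ by exactly one factor of 0.5*log(2*pi) + log(sigma).
print(psdet_iid - psdet_degenerate)  # approx. 1.32 == 0.5*np.log(2*np.pi) + np.log(1.5)
```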
@@ -98,7 +98,7 @@ def __new__(cls, name, *args, **kwargs):
                 raise TypeError("observed needs to be data but got: {}".format(type(data)))
             total_size = kwargs.pop("total_size", None)
 
-            dims = kwargs.pop("dims", None)
+            dims = kwargs["dims"] if "dims" in kwargs else None
             has_shape = "shape" in kwargs
             shape = kwargs.pop("shape", None)
             if dims is not None:
I think it makes no sense to have a ZeroSumNormal when shape=() or None. In that case, the RV should also be exactly equal to zero. I think that we should test if shape is None or len(shape) == 0 and raise a ValueError in that case. Something that says ZeroSumNormal is defined only for RVs that are not scalar.
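A minimal sketch of the check this comment asks for, written here as a standalone hypothetical helper (the name _check_zerosum_shape and the error message wording are made up; in practice the guard would sit inside ZeroSumNormal.__init__ after shape has been normalized to a tuple):

```python
def _check_zerosum_shape(shape):
    """Hypothetical helper: reject scalar shapes for a zero-sum RV."""
    if shape is None or len(shape) == 0:
        raise ValueError(
            "ZeroSumNormal is only defined for non-scalar random variables; "
            "got shape={}.".format(shape)
        )
    return shape

# _check_zerosum_shape((4,)) passes; _check_zerosum_shape(()) raises ValueError.
```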