From 25ddb5a082514bcc633160657247087e916c0796 Mon Sep 17 00:00:00 2001 From: atheendre130505 Date: Fri, 31 Oct 2025 11:30:49 +0530 Subject: [PATCH 1/2] [BUG] Replace deprecated batched_dot with pt.sum in KroneckerNormal - Fixes Issue #7878 - Replace pt.batched_dot(sqrt_quad.T, sqrt_quad.T) with pt.sum(sqrt_quad.T ** 2, axis=-1) - Computes squared norm per sample using modern PyTensor operations - Eliminates deprecation warnings and ensures future compatibility --- pymc/distributions/multivariate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index f76a98546e..46b9165645 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2124,8 +2124,8 @@ def logp(value, rng, size, mu, sigma, *covs): sqrt_quad = sqrt_quad / pt.sqrt(eigs[:, None]) logdet = pt.sum(pt.log(eigs)) - # Square each sample - quad = pt.batched_dot(sqrt_quad.T, sqrt_quad.T) + # Square each sample - compute squared norm for each sample + quad = pt.sum(sqrt_quad.T ** 2, axis=-1) if onedim: quad = quad[0] From 84f3fd1926dad201fe353dccd960c53506daae0a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 4 Nov 2025 14:44:43 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pymc/distributions/multivariate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 46b9165645..9435b40fa7 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2125,7 +2125,7 @@ def logp(value, rng, size, mu, sigma, *covs): logdet = pt.sum(pt.log(eigs)) # Square each sample - compute squared norm for each sample - quad = pt.sum(sqrt_quad.T ** 2, axis=-1) + quad = pt.sum(sqrt_quad.T**2, axis=-1) if onedim: quad = quad[0]