66from pymc import intX
77from pymc .distributions .dist_math import check_parameters
88from pymc .distributions .distribution import Continuous , SymbolicRandomVariable
9- from pymc .distributions .multivariate import MvNormal
109from pymc .distributions .shape_utils import get_support_shape_1d
1110from pymc .logprob .abstract import _logprob
1211from pytensor .graph .basic import Node
13- from pytensor .tensor .random .basic import MvNormalRV
1412
1513floatX = pytensor .config .floatX
1614COV_ZERO_TOL = 0
@@ -49,23 +47,6 @@ def make_signature(sequence_names):
4947 return f"{ signature } ,[rng]->[rng],({ time } ,{ state_and_obs } )"
5048
5149
52- class MvNormalSVDRV (MvNormalRV ):
53- name = "multivariate_normal"
54- signature = "(n),(n,n)->(n)"
55- dtype = "floatX"
56- _print_name = ("MultivariateNormal" , "\\ operatorname{MultivariateNormal}" )
57-
58-
59- class MvNormalSVD (MvNormal ):
60- """Dummy distribution intended to be rewritten into a JAX multivariate_normal with method="svd".
61-
62- A JAX MvNormal robust to low-rank covariance matrices
63- """
64-
65- # TODO: Remove this entirely on next PyMC release; method will be exposed directly in MvNormal
66- rv_op = MvNormalSVDRV (method = "svd" )
67-
68-
6950class LinearGaussianStateSpaceRV (SymbolicRandomVariable ):
7051 default_output = 1
7152 _print_name = ("LinearGuassianStateSpace" , "\\ operatorname{LinearGuassianStateSpace}" )
@@ -223,8 +204,12 @@ def step_fn(*args):
223204 k = T .shape [0 ]
224205 a = state [:k ]
225206
226- middle_rng , a_innovation = MvNormalSVD .dist (mu = 0 , cov = Q , rng = rng ).owner .outputs
227- next_rng , y_innovation = MvNormalSVD .dist (mu = 0 , cov = H , rng = middle_rng ).owner .outputs
207+ middle_rng , a_innovation = pm .MvNormal .dist (
208+ mu = 0 , cov = Q , rng = rng , method = "svd"
209+ ).owner .outputs
210+ next_rng , y_innovation = pm .MvNormal .dist (
211+ mu = 0 , cov = H , rng = middle_rng , method = "svd"
212+ ).owner .outputs
228213
229214 a_mu = c + T @ a
230215 a_next = a_mu + R @ a_innovation
@@ -239,8 +224,8 @@ def step_fn(*args):
239224 Z_init = Z_ if Z_ in non_sequences else Z_ [0 ]
240225 H_init = H_ if H_ in non_sequences else H_ [0 ]
241226
242- init_x_ = MvNormalSVD . dist (a0_ , P0_ , rng = rng )
243- init_y_ = MvNormalSVD . dist (Z_init @ init_x_ , H_init , rng = rng )
227+ init_x_ = pm . MvNormal . dist (a0_ , P0_ , rng = rng , method = "svd" )
228+ init_y_ = pm . MvNormal . dist (Z_init @ init_x_ , H_init , rng = rng , method = "svd" )
244229
245230 init_dist_ = pt .concatenate ([init_x_ , init_y_ ], axis = 0 )
246231
@@ -400,7 +385,7 @@ def rv_op(cls, mus, covs, logp, size=None):
400385 rng = pytensor .shared (np .random .default_rng ())
401386
402387 def step (mu , cov , rng ):
403- new_rng , mvn = MvNormalSVD . dist (mu = mu , cov = cov , rng = rng ).owner .outputs
388+ new_rng , mvn = pm . MvNormal . dist (mu = mu , cov = cov , rng = rng , method = "svd" ).owner .outputs
404389 return mvn , {rng : new_rng }
405390
406391 mvn_seq , updates = pytensor .scan (
0 commit comments