@@ -10,6 +10,7 @@
 from pytensor.graph.basic import Node
 
 floatX = pytensor.config.floatX
+COV_ZERO_TOL = 0
 
 lgss_shape_message = (
     "The LinearGaussianStateSpace distribution needs shape information to be constructed. "
@@ -157,8 +158,11 @@ def step_fn(*args):
             middle_rng, a_innovation = pm.MvNormal.dist(mu=0, cov=Q, rng=rng).owner.outputs
             next_rng, y_innovation = pm.MvNormal.dist(mu=0, cov=H, rng=middle_rng).owner.outputs
 
-            a_next = c + T @ a + R @ a_innovation
-            y_next = d + Z @ a_next + y_innovation
+            a_mu = c + T @ a
+            a_next = pt.switch(pt.all(pt.le(Q, COV_ZERO_TOL)), a_mu, a_mu + R @ a_innovation)
+
+            y_mu = d + Z @ a_next
+            y_next = pt.switch(pt.all(pt.le(H, COV_ZERO_TOL)), y_mu, y_mu + y_innovation)
 
             next_state = pt.concatenate([a_next, y_next], axis=0)
 
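For intuition, here is a minimal standalone sketch of the fallback used in step_fn above, with made-up matrices and a fixed vector standing in for the MvNormal innovation draw: when every element of Q is at or below COV_ZERO_TOL, the transition collapses to its deterministic mean.

    import numpy as np
    import pytensor.tensor as pt

    COV_ZERO_TOL = 0

    c = pt.as_tensor(np.array([0.1, 0.0]))
    T = pt.as_tensor(np.eye(2))
    R = pt.as_tensor(np.eye(2))
    a = pt.as_tensor(np.array([1.0, 2.0]))
    a_innovation = pt.as_tensor(np.array([0.5, -0.5]))  # stand-in for the MvNormal draw

    Q = pt.zeros((2, 2))  # all-zero state covariance -> deterministic transition
    a_mu = c + T @ a
    a_next = pt.switch(pt.all(pt.le(Q, COV_ZERO_TOL)), a_mu, a_mu + R @ a_innovation)
    print(a_next.eval())  # [1.1 2. ] -- the innovation term is dropped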
@@ -168,7 +172,11 @@ def step_fn(*args):
         Z_init = Z_ if Z_ in non_sequences else Z_[0]
         H_init = H_ if H_ in non_sequences else H_[0]
 
-        init_y_ = pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng)
+        init_y_ = pt.switch(
+            pt.all(pt.le(H_init, COV_ZERO_TOL)),
+            Z_init @ init_x_,
+            pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng),
+        )
         init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)
 
         statespace, updates = pytensor.scan(
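The guard used for the initial observation (and for Q and H in step_fn) is an elementwise comparison reduced with all(), so the deterministic branch is taken only when every entry of the covariance is at or below COV_ZERO_TOL; a single positive entry keeps the stochastic MvNormal branch. A quick check of that reduction with arbitrary 2x2 covariances:

    import numpy as np
    import pytensor.tensor as pt

    COV_ZERO_TOL = 0

    print(pt.all(pt.le(np.zeros((2, 2)), COV_ZERO_TOL)).eval())      # True  -> deterministic branch
    print(pt.all(pt.le(np.diag([0.0, 1e-8]), COV_ZERO_TOL)).eval())  # False -> keep the MvNormal draw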
@@ -216,6 +224,7 @@ def __new__(
         steps=None,
         mode=None,
         sequence_names=None,
+        k_endog=None,
         **kwargs,
     ):
         dims = kwargs.pop("dims", None)
@@ -239,11 +248,29 @@ def __new__(
             sequence_names=sequence_names,
             **kwargs,
         )
-
         k_states = T.type.shape[0]
 
-        latent_states = latent_obs_combined[..., :k_states]
-        obs_states = latent_obs_combined[..., k_states:]
+        if k_endog is None and k_states is None:
+            raise ValueError("Could not infer number of observed states, explicitly pass k_endog.")
+        if k_endog is not None and k_states is not None:
+            total_shape = latent_obs_combined.type.shape[-1]
+            inferred_endog = total_shape - k_states
+            if inferred_endog != k_endog:
+                raise ValueError(
+                    f"Inferred k_endog does not agree with provided value ({inferred_endog} != {k_endog}). "
+                    f"It is not necessary to provide k_endog when the value can be inferred."
+                )
+            latent_slice = slice(None, -k_endog)
+            obs_slice = slice(-k_endog, None)
+        elif k_endog is None:
+            latent_slice = slice(None, k_states)
+            obs_slice = slice(k_states, None)
+        else:
+            latent_slice = slice(None, -k_endog)
+            obs_slice = slice(-k_endog, None)
+
+        latent_states = latent_obs_combined[..., latent_slice]
+        obs_states = latent_obs_combined[..., obs_slice]
 
         latent_states = pm.Deterministic(f"{name}_latent", latent_states, dims=latent_dims)
         obs_states = pm.Deterministic(f"{name}_observed", obs_states, dims=obs_dims)
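For reference, a plain numpy sketch of the splitting logic with hypothetical sizes (3 latent states, 1 observed series, 100 steps): slicing the first k_states columns, as done when T has a known static shape, and slicing the last k_endog columns, as done when only k_endog is supplied, select the same blocks.

    import numpy as np

    latent_obs_combined = np.zeros((100, 4))  # stands in for the scan output: 3 latent + 1 observed
    k_states, k_endog = 3, 1

    # Slicing by k_states from the front and by k_endog from the back are equivalent splits
    assert np.array_equal(latent_obs_combined[..., :k_states], latent_obs_combined[..., :-k_endog])
    assert np.array_equal(latent_obs_combined[..., k_states:], latent_obs_combined[..., -k_endog:])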