11import logging
2+ import warnings
23
34from collections .abc import Callable , Sequence
45from typing import Any , Literal
1415from pymc .model .transform .optimization import freeze_dims_and_data
1516from pymc .util import RandomState
1617from pytensor import Variable , graph_replace
17- from pytensor .compile import get_mode
1818from rich .box import SIMPLE_HEAD
1919from rich .console import Console
2020from rich .table import Table
@@ -99,6 +99,13 @@ class PyMCStateSpace:
9999 compute the observation errors. If False, these errors are deterministically zero; if True, they are sampled
100100 from a multivariate normal.
101101
102+ mode: str or Mode, optional
103+ Pytensor compile mode, used in auxiliary sampling methods such as ``sample_conditional_posterior`` and
104+ ``forecast``. The mode does **not** effect calls to ``pm.sample``.
105+
106+ Regardless of whether a mode is specified, it can always be overwritten via the ``compile_kwargs`` argument
107+ to all sampling methods.
108+
102109 Notes
103110 -----
104111 Based on the statsmodels statespace implementation https://github.com/statsmodels/statsmodels/blob/main/statsmodels/tsa/statespace/representation.py,
@@ -221,8 +228,8 @@ def __init__(
221228 filter_type : str = "standard" ,
222229 verbose : bool = True ,
223230 measurement_error : bool = False ,
231+ mode : str | None = None ,
224232 ):
225- self ._fit_mode : str | None = None
226233 self ._fit_coords : dict [str , Sequence [str ]] | None = None
227234 self ._fit_dims : dict [str , Sequence [str ]] | None = None
228235 self ._fit_data : pt .TensorVariable | None = None
@@ -237,6 +244,7 @@ def __init__(
237244 self .k_states = k_states
238245 self .k_posdef = k_posdef
239246 self .measurement_error = measurement_error
247+ self .mode = mode
240248
241249 # All models contain a state space representation and a Kalman filter
242250 self .ssm = PytensorRepresentation (k_endog , k_states , k_posdef )
@@ -819,11 +827,11 @@ def build_statespace_graph(
819827 self ,
820828 data : np .ndarray | pd .DataFrame | pt .TensorVariable ,
821829 register_data : bool = True ,
822- mode : str | None = None ,
823830 missing_fill_value : float | None = None ,
824831 cov_jitter : float | None = JITTER_DEFAULT ,
825832 mvn_method : Literal ["cholesky" , "eigh" , "svd" ] = "svd" ,
826833 save_kalman_filter_outputs_in_idata : bool = False ,
834+ mode : str | None = None ,
827835 ) -> None :
828836 """
829837 Given a parameter vector `theta`, constructs the full computational graph describing the state space model and
@@ -877,7 +885,25 @@ def build_statespace_graph(
877885 save_kalman_filter_outputs_in_idata: bool, optional, default=False
878886 If True, Kalman Filter outputs will be saved in the model as deterministics. Useful for debugging, but
879887 should not be necessary for the majority of users.
888+
889+ mode: str, optional
890+ Pytensor mode to use when compiling the graph. This will be saved as a model attribute and used when
891+ compiling sampling functions (e.g. ``sample_conditional_prior``).
892+
893+ .. deprecated:: 0.2.5
894+ The `mode` argument is deprecated and will be removed in a future version. Pass ``mode`` to the
895+ model constructor, or manually specify ``compile_kwargs`` in sampling functions instead.
896+
880897 """
898+ if mode is not None :
899+ warnings .warn (
900+ "The `mode` argument is deprecated and will be removed in a future version. "
901+ "Pass `mode` to the model constructor, or manually specify `compile_kwargs` in sampling functions"
902+ " instead." ,
903+ DeprecationWarning ,
904+ )
905+ self .mode = mode
906+
881907 pm_mod = modelcontext (None )
882908
883909 self ._insert_random_variables ()
@@ -898,7 +924,6 @@ def build_statespace_graph(
898924 filter_outputs = self .kalman_filter .build_graph (
899925 pt .as_tensor_variable (data ),
900926 * self .unpack_statespace (),
901- mode = mode ,
902927 missing_fill_value = missing_fill_value ,
903928 cov_jitter = cov_jitter ,
904929 )
@@ -909,7 +934,7 @@ def build_statespace_graph(
909934 filtered_covariances , predicted_covariances , observed_covariances = covs
910935 if save_kalman_filter_outputs_in_idata :
911936 smooth_states , smooth_covariances = self ._build_smoother_graph (
912- filtered_states , filtered_covariances , self .unpack_statespace (), mode = mode
937+ filtered_states , filtered_covariances , self .unpack_statespace ()
913938 )
914939 all_kf_outputs = [* states , smooth_states , * covs , smooth_covariances ]
915940 self ._register_kalman_filter_outputs_with_pymc_model (all_kf_outputs )
@@ -929,7 +954,6 @@ def build_statespace_graph(
929954
930955 self ._fit_coords = pm_mod .coords .copy ()
931956 self ._fit_dims = pm_mod .named_vars_to_dims .copy ()
932- self ._fit_mode = mode
933957
934958 def _build_smoother_graph (
935959 self ,
@@ -974,7 +998,7 @@ def _build_smoother_graph(
974998 * _ , T , Z , R , H , Q = matrices
975999
9761000 smooth_states , smooth_covariances = self .kalman_smoother .build_graph (
977- T , R , Q , filtered_states , filtered_covariances , mode = mode , cov_jitter = cov_jitter
1001+ T , R , Q , filtered_states , filtered_covariances , cov_jitter = cov_jitter
9781002 )
9791003 smooth_states .name = "smooth_states"
9801004 smooth_covariances .name = "smooth_covariances"
@@ -1092,7 +1116,6 @@ def _kalman_filter_outputs_from_dummy_graph(
10921116 R ,
10931117 H ,
10941118 Q ,
1095- mode = self ._fit_mode ,
10961119 )
10971120
10981121 filter_outputs .pop (- 1 )
@@ -1102,7 +1125,7 @@ def _kalman_filter_outputs_from_dummy_graph(
11021125 filtered_covariances , predicted_covariances , _ = covariances
11031126
11041127 [smoothed_states , smoothed_covariances ] = self .kalman_smoother .build_graph (
1105- T , R , Q , filtered_states , filtered_covariances , mode = self . _fit_mode
1128+ T , R , Q , filtered_states , filtered_covariances
11061129 )
11071130
11081131 grouped_outputs = [
@@ -1164,6 +1187,9 @@ def _sample_conditional(
11641187 _verify_group (group )
11651188 group_idata = getattr (idata , group )
11661189
1190+ compile_kwargs = kwargs .pop ("compile_kwargs" , {})
1191+ compile_kwargs .setdefault ("mode" , self .mode )
1192+
11671193 with pm .Model (coords = self ._fit_coords ) as forward_model :
11681194 (
11691195 [
@@ -1229,8 +1255,8 @@ def _sample_conditional(
12291255 for name in FILTER_OUTPUT_TYPES
12301256 for suffix in ["" , "_observed" ]
12311257 ],
1232- compile_kwargs = {"mode" : get_mode (self ._fit_mode )},
12331258 random_seed = random_seed ,
1259+ compile_kwargs = compile_kwargs ,
12341260 ** kwargs ,
12351261 )
12361262
@@ -1296,6 +1322,10 @@ def _sample_unconditional(
12961322 the latent state trajectories: `y[t] = Z @ x[t] + nu[t]`, where `nu ~ N(0, H)`.
12971323 """
12981324 _verify_group (group )
1325+
1326+ compile_kwargs = kwargs .pop ("compile_kwargs" , {})
1327+ compile_kwargs .setdefault ("mode" , self .mode )
1328+
12991329 group_idata = getattr (idata , group )
13001330 dims = None
13011331 temp_coords = self ._fit_coords .copy ()
@@ -1338,7 +1368,6 @@ def _sample_unconditional(
13381368 * matrices ,
13391369 steps = steps ,
13401370 dims = dims ,
1341- mode = self ._fit_mode ,
13421371 method = mvn_method ,
13431372 sequence_names = self .kalman_filter .seq_names ,
13441373 k_endog = self .k_endog ,
@@ -1354,8 +1383,8 @@ def _sample_unconditional(
13541383 idata_unconditional = pm .sample_posterior_predictive (
13551384 group_idata ,
13561385 var_names = [f"{ group } _latent" , f"{ group } _observed" ],
1357- compile_kwargs = {"mode" : self ._fit_mode },
13581386 random_seed = random_seed ,
1387+ compile_kwargs = compile_kwargs ,
13591388 ** kwargs ,
13601389 )
13611390
@@ -1583,7 +1612,7 @@ def sample_unconditional_posterior(
15831612 )
15841613
15851614 def sample_statespace_matrices (
1586- self , idata , matrix_names : str | list [str ] | None , group : str = "posterior"
1615+ self , idata , matrix_names : str | list [str ] | None , group : str = "posterior" , ** kwargs
15871616 ):
15881617 """
15891618 Draw samples of requested statespace matrices from provided idata
@@ -1600,12 +1629,18 @@ def sample_statespace_matrices(
16001629 group: str, one of "posterior" or "prior"
16011630 Whether to sample from priors or posteriors
16021631
1632+ kwargs:
1633+ Additional keyword arguments are passed to ``pymc.sample_posterior_predictive``
1634+
16031635 Returns
16041636 -------
16051637 idata_matrices: az.InterenceData
16061638 """
16071639 _verify_group (group )
16081640
1641+ compile_kwargs = kwargs .pop ("compile_kwargs" , {})
1642+ compile_kwargs .setdefault ("mode" , self .mode )
1643+
16091644 if matrix_names is None :
16101645 matrix_names = MATRIX_NAMES
16111646 elif isinstance (matrix_names , str ):
@@ -1636,8 +1671,9 @@ def sample_statespace_matrices(
16361671 matrix_idata = pm .sample_posterior_predictive (
16371672 idata if group == "posterior" else idata .prior ,
16381673 var_names = matrix_names ,
1639- compile_kwargs = {"mode" : self ._fit_mode },
16401674 extend_inferencedata = False ,
1675+ compile_kwargs = compile_kwargs ,
1676+ ** kwargs ,
16411677 )
16421678
16431679 return matrix_idata
@@ -2106,6 +2142,10 @@ def forecast(
21062142 filter_time_dim = TIME_DIM
21072143
21082144 _validate_filter_arg (filter_output )
2145+
2146+ compile_kwargs = kwargs .pop ("compile_kwargs" , {})
2147+ compile_kwargs .setdefault ("mode" , self .mode )
2148+
21092149 time_index = self ._get_fit_time_index ()
21102150
21112151 if start is None and verbose :
@@ -2192,7 +2232,6 @@ def forecast(
21922232 * matrices ,
21932233 steps = len (forecast_index ),
21942234 dims = dims ,
2195- mode = self ._fit_mode ,
21962235 sequence_names = self .kalman_filter .seq_names ,
21972236 k_endog = self .k_endog ,
21982237 append_x0 = False ,
@@ -2208,8 +2247,8 @@ def forecast(
22082247 idata_forecast = pm .sample_posterior_predictive (
22092248 idata ,
22102249 var_names = ["forecast_latent" , "forecast_observed" ],
2211- compile_kwargs = {"mode" : self ._fit_mode },
22122250 random_seed = random_seed ,
2251+ compile_kwargs = compile_kwargs ,
22132252 ** kwargs ,
22142253 )
22152254
@@ -2297,6 +2336,9 @@ def impulse_response_function(
22972336 n_options = sum (x is not None for x in options )
22982337 Q = None # No covariance matrix needed if a trajectory is provided. Will be overwritten later if needed.
22992338
2339+ compile_kwargs = kwargs .pop ("compile_kwargs" , {})
2340+ compile_kwargs .setdefault ("mode" , self .mode )
2341+
23002342 if n_options > 1 :
23012343 raise ValueError ("Specify exactly 0 or 1 of shock_size, shock_cov, or shock_trajectory" )
23022344 elif n_options == 1 :
@@ -2368,29 +2410,15 @@ def irf_step(shock, x, c, T, R):
23682410 non_sequences = [c , T , R ],
23692411 n_steps = n_steps ,
23702412 strict = True ,
2371- mode = self ._fit_mode ,
23722413 )
23732414
23742415 pm .Deterministic ("irf" , irf , dims = [TIME_DIM , ALL_STATE_DIM ])
23752416
2376- compile_kwargs = kwargs .get ("compile_kwargs" , {})
2377- if "mode" not in compile_kwargs .keys ():
2378- compile_kwargs = {"mode" : self ._fit_mode }
2379- else :
2380- mode = compile_kwargs .get ("mode" )
2381- if mode is not None and mode != self ._fit_mode :
2382- raise ValueError (
2383- f"User provided compile mode ({ mode } ) does not match the compile mode used to "
2384- f"construct the model ({ self ._fit_mode } )."
2385- )
2386-
2387- compile_kwargs .update ({"mode" : self ._fit_mode })
2388-
23892417 irf_idata = pm .sample_posterior_predictive (
23902418 idata ,
23912419 var_names = ["irf" ],
2392- compile_kwargs = compile_kwargs ,
23932420 random_seed = random_seed ,
2421+ compile_kwargs = compile_kwargs ,
23942422 ** kwargs ,
23952423 )
23962424
0 commit comments