From f50af8a51e50e66844a2fb2cae155069ad5fa842 Mon Sep 17 00:00:00 2001
From: Rob Zinkov
Date: Wed, 29 Oct 2025 02:09:07 +0100
Subject: [PATCH 01/11] Approximation's sample method uses model contexts

This also deprecates `self.model`
---
 pymc/variational/opvi.py | 38 ++++++++++++++++++++++++++++++--------
 1 file changed, 30 insertions(+), 8 deletions(-)

diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py
index 3cd5cc3dcf..0b4ee49e88 100644
--- a/pymc/variational/opvi.py
+++ b/pymc/variational/opvi.py
@@ -70,7 +70,7 @@
 from pymc.backends.ndarray import NDArray
 from pymc.blocking import DictToArrayBijection
 from pymc.initial_point import make_initial_point_fn
-from pymc.model import modelcontext
+from pymc.model import Model, modelcontext
 from pymc.pytensorf import (
     SeedSequenceSeed,
     compile,
@@ -1246,12 +1246,21 @@ def __init__(self, groups, model=None):
             else:
                 rest.__init_group__(unseen_free_RVs)
                 self.groups.append(rest)
-        self.model = model
+        self._model = model

     @property
     def has_logq(self):
         return all(self.collect("has_logq"))

+    @property
+    def model(self):
+        warnings.warn(
+            "`model` field is deprecated and will be removed in future versions. Use "
+            "a model context instead.",
+            DeprecationWarning,
+        )
+        return self._model
+
     def collect(self, item):
         return [getattr(g, item) for g in self.groups]
@@ -1538,9 +1547,12 @@ def vars_names(vs):
         return found

     @node_property
-    def sample_dict_fn(self):
+    def sample_dict_fn(self, model=None):
         s = pt.iscalar()
-        names = [self.model.rvs_to_values[v].name for v in self.model.free_RVs]
+
+        model = modelcontext(model)
+
+        names = [model.rvs_to_values[v].name for v in model.free_RVs]
         sampled = [self.rslice(name) for name in names]
         sampled = self.set_size_and_deterministic(sampled, s, 0)
         sample_fn = compile([s], sampled)
@@ -1556,7 +1568,13 @@ def inner(draws=100, *, random_seed: SeedSequenceSeed = None):
         return inner

     def sample(
-        self, draws=500, *, random_seed: RandomState = None, return_inferencedata=True, **kwargs
+        self,
+        draws=500,
+        *,
+        model: Model | None = None,
+        random_seed: RandomState = None,
+        return_inferencedata=True,
+        **kwargs,
     ):
         """Draw samples from variational posterior.

         Parameters
         ----------
         draws : int
             Number of random samples.
+        model : Model (optional if in ``with`` context)
+            Model to be used to generate samples.
         random_seed : int, RandomState or Generator, optional
             Seed for the random number generator.
         return_inferencedata : bool
@@ -1577,16 +1597,18 @@
         # TODO: add tests for include_transformed case
         kwargs["log_likelihood"] = False

+        model = modelcontext(model)
+
         if random_seed is not None:
             (random_seed,) = _get_seeds_per_chain(random_seed, 1)
-        samples: dict = self.sample_dict_fn(draws, random_seed=random_seed)
+        samples: dict = self.sample_dict_fn(draws, model=model, random_seed=random_seed)
         points = (
             {name: np.asarray(records[i]) for name, records in samples.items()}
             for i in range(draws)
         )
         trace = NDArray(
-            model=self.model,
+            model=model,
             test_point={name: records[0] for name, records in samples.items()},
         )
         try:
@@ -1600,7 +1622,7 @@
         if not return_inferencedata:
             return multi_trace
         else:
-            return pm.to_inference_data(multi_trace, model=self.model, **kwargs)
+            return pm.to_inference_data(multi_trace, model=model, **kwargs)

     @property
     def ndim(self):

From a044e76743a796b623115cec4e343307da9b9336 Mon Sep 17 00:00:00 2001
From: Rob Zinkov
Date: Wed, 29 Oct 2025 02:46:46 +0100
Subject: [PATCH 02/11] Update rslice and sample_node
---
 pymc/variational/opvi.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py
index 0b4ee49e88..4350b610f5 100644
--- a/pymc/variational/opvi.py
+++ b/pymc/variational/opvi.py
@@ -1491,12 +1491,14 @@ def get_optimization_replacements(self, s, d):
         return repl

     @pytensor.config.change_flags(compute_test_value="off")
-    def sample_node(self, node, size=None, deterministic=False, more_replacements=None):
+    def sample_node(self, node, model=None, size=None, deterministic=False, more_replacements=None):
         """Sample given node or nodes over shared posterior.

         Parameters
         ----------
         node: PyTensor Variables (or PyTensor expressions)
+        model : Model (optional if in ``with`` context)
+            Model to be used to generate samples.
         size: None or scalar
             number of samples
         more_replacements: `dict`
@@ -1510,11 +1512,14 @@ def sample_node(self, node, size=None, deterministic=False, more_replacements=No
             sampled node(s) with replacements
         """
         node_in = node
+
+        model = modelcontext(model)
+
         if more_replacements:
             node = graph_replace(node, more_replacements, strict=False)
         if not isinstance(node, list | tuple):
             node = [node]
-        node = self.model.replace_rvs_by_values(node)
+        node = model.replace_rvs_by_values(node)
         if not isinstance(node_in, list | tuple):
             node = node[0]
         if size is None:
@@ -1525,14 +1530,14 @@
         try_to_set_test_value(node_in, node_out, size)
         return node_out

-    def rslice(self, name):
+    def rslice(self, name, model):
         """*Dev* - vectorized sampling for named random variable without call to `pytensor.scan`.

         This node still needs :func:`set_size_and_deterministic` to be evaluated.
""" def vars_names(vs): - return {self.model.rvs_to_values[v].name for v in vs} + return {model.rvs_to_values[v].name for v in vs} for vars_, random, ordering in zip( self.collect("group"), self.symbolic_randoms, self.collect("ordering") @@ -1553,7 +1558,7 @@ def sample_dict_fn(self, model=None): model = modelcontext(model) names = [model.rvs_to_values[v].name for v in model.free_RVs] - sampled = [self.rslice(name) for name in names] + sampled = [self.rslice(name, model) for name in names] sampled = self.set_size_and_deterministic(sampled, s, 0) sample_fn = compile([s], sampled) rng_nodes = find_rng_nodes(sampled) From adcdf90537cb69401952a6b129c5e1125515777e Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Fri, 31 Oct 2025 02:29:23 +0100 Subject: [PATCH 03/11] Make tests pass --- pymc/sampling/mcmc.py | 6 +++--- pymc/variational/operators.py | 2 +- pymc/variational/opvi.py | 28 +++++++++++++--------------- tests/variational/test_inference.py | 13 +++++++------ tests/variational/test_opvi.py | 6 +++--- 5 files changed, 27 insertions(+), 28 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index de341c68cd..3e2dea607f 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1648,7 +1648,7 @@ def model_logp_fn(ip: PointType) -> np.ndarray: compile_kwargs=compile_kwargs, ) approx_sample = approx.sample( - draws=chains, random_seed=random_seed_list[0], return_inferencedata=False + draws=chains, model=model, random_seed=random_seed_list[0], return_inferencedata=False ) initial_points = [approx_sample[i] for i in range(chains)] std_apoint = approx.std.eval() @@ -1672,7 +1672,7 @@ def model_logp_fn(ip: PointType) -> np.ndarray: compile_kwargs=compile_kwargs, ) approx_sample = approx.sample( - draws=chains, random_seed=random_seed_list[0], return_inferencedata=False + draws=chains, model=model, random_seed=random_seed_list[0], return_inferencedata=False ) initial_points = [approx_sample[i] for i in range(chains)] cov = approx.std.eval() ** 2 @@ -1690,7 +1690,7 @@ def model_logp_fn(ip: PointType) -> np.ndarray: compile_kwargs=compile_kwargs, ) approx_sample = approx.sample( - draws=chains, random_seed=random_seed_list[0], return_inferencedata=False + draws=chains, model=model, random_seed=random_seed_list[0], return_inferencedata=False ) initial_points = [approx_sample[i] for i in range(chains)] cov = approx.std.eval() ** 2 diff --git a/pymc/variational/operators.py b/pymc/variational/operators.py index 502fe13ab9..6981dedc53 100644 --- a/pymc/variational/operators.py +++ b/pymc/variational/operators.py @@ -142,7 +142,7 @@ def __init__(self, approx, temperature=1): def apply(self, f): # f: kernel function for KSD f(histogram) -> (k(x,.), \nabla_x k(x,.)) - if _known_scan_ignored_inputs([self.approx.model.logp()]): + if _known_scan_ignored_inputs([self.approx._model.logp()]): raise NotImplementedInference( "SVGD does not currently support Minibatch or Simulator RV" ) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 4350b610f5..fe712c248f 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1289,7 +1289,7 @@ def symbolic_normalizing_constant(self): obs.owner.inputs[1:], constant_fold([obs.owner.inputs[0].shape], raise_not_constant=False), ) - for obs in self.model.observed_RVs + for obs in self._model.observed_RVs if isinstance(obs.owner.op, MinibatchRandomVariable) ] ) @@ -1315,7 +1315,7 @@ def logq_norm(self): def _sized_symbolic_varlogp_and_datalogp(self): """*Dev* - computes sampled prior term from model via 
         varlogp_s, datalogp_s = self.symbolic_sample_over_posterior(
-            [self.model.varlogp, self.model.datalogp]
+            [self._model.varlogp, self._model.datalogp]
         )
         return varlogp_s, datalogp_s  # both shape (s,)
@@ -1352,7 +1352,7 @@ def datalogp(self):
     @node_property
     def _single_symbolic_varlogp_and_datalogp(self):
         """*Dev* - computes sampled prior term from model via `pytensor.scan`."""
-        varlogp, datalogp = self.symbolic_single_sample([self.model.varlogp, self.model.datalogp])
+        varlogp, datalogp = self.symbolic_single_sample([self._model.varlogp, self._model.datalogp])
         return varlogp, datalogp

     @node_property
@@ -1491,14 +1491,12 @@ def get_optimization_replacements(self, s, d):
         return repl

     @pytensor.config.change_flags(compute_test_value="off")
-    def sample_node(self, node, model=None, size=None, deterministic=False, more_replacements=None):
+    def sample_node(self, node, size=None, deterministic=False, more_replacements=None):
         """Sample given node or nodes over shared posterior.

         Parameters
         ----------
         node: PyTensor Variables (or PyTensor expressions)
-        model : Model (optional if in ``with`` context)
-            Model to be used to generate samples.
         size: None or scalar
             number of samples
         more_replacements: `dict`
@@ -1513,7 +1511,7 @@ def sample_node(self, node, model=None, size=None, deterministic=False, more_rep
         """
         node_in = node

-        model = modelcontext(model)
+        model = self._model

         if more_replacements:
             node = graph_replace(node, more_replacements, strict=False)
@@ -1552,18 +1550,18 @@ def vars_names(vs):
         return found

     @node_property
-    def sample_dict_fn(self, model=None):
+    def sample_dict_fn(self):
         s = pt.iscalar()

-        model = modelcontext(model)
+        def inner(draws=100, *, model=None, random_seed: SeedSequenceSeed = None):
+            model = modelcontext(model)

-        names = [model.rvs_to_values[v].name for v in model.free_RVs]
-        sampled = [self.rslice(name, model) for name in names]
-        sampled = self.set_size_and_deterministic(sampled, s, 0)
-        sample_fn = compile([s], sampled)
-        rng_nodes = find_rng_nodes(sampled)
+            names = [model.rvs_to_values[v].name for v in model.free_RVs]
+            sampled = [self.rslice(name, model) for name in names]
+            sampled = self.set_size_and_deterministic(sampled, s, 0)
+            sample_fn = compile([s], sampled)
+            rng_nodes = find_rng_nodes(sampled)

-        def inner(draws=100, *, random_seed: SeedSequenceSeed = None):
             if random_seed is not None:
                 reseed_rngs(rng_nodes, random_seed)
             _samples = sample_fn(draws)
diff --git a/tests/variational/test_inference.py b/tests/variational/test_inference.py
index 10b824179d..8effb4ed91 100644
--- a/tests/variational/test_inference.py
+++ b/tests/variational/test_inference.py
@@ -41,7 +41,7 @@ def test_fit_with_nans(score):
         mean = inp * coef
         pm.Normal("y", mean, 0.1, observed=y)
     with pytest.raises(FloatingPointError) as e:
-        advi = pm.fit(100, score=score, obj_optimizer=pm.adam(learning_rate=float("nan")))
+        pm.fit(100, score=score, obj_optimizer=pm.adam(learning_rate=float("nan")))


 @pytest.fixture(scope="module", params=[True, False], ids=["mini", "full"])
@@ -174,8 +174,8 @@ def fit_kwargs(inference, use_minibatch):
     return _select[(type(inference), key)]


-def test_fit_oo(inference, fit_kwargs, simple_model_data):
-    trace = inference.fit(**fit_kwargs).sample(10000)
+def test_fit_oo(simple_model, inference, fit_kwargs, simple_model_data):
+    trace = inference.fit(**fit_kwargs).sample(10000, model=simple_model)
     mu_post = simple_model_data["mu_post"]
     d = simple_model_data["d"]
     np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_post, rtol=0.05)
@@
-202,7 +202,8 @@ def test_fit_start(inference_spec, simple_model): inference = inference_spec(**kw) try: - trace = inference.fit(n=0).sample(10000) + with simple_model: + trace = inference.fit(n=0).sample(10000) except NotImplementedInference as e: pytest.skip(str(e)) @@ -269,7 +270,7 @@ def binomial_model_inference(binomial_model, inference_spec): def test_replacements(binomial_model_inference): d = pytensor.shared(1) approx = binomial_model_inference.approx - p = approx.model.p + p = approx._model.p p_t = p**3 p_s = approx.sample_node(p_t) assert not any( @@ -309,7 +310,7 @@ def test_sample_replacements(binomial_model_inference): i = pt.iscalar() i.tag.test_value = 1 approx = binomial_model_inference.approx - p = approx.model.p + p = approx._model.p p_t = p**3 p_s = approx.sample_node(p_t, size=100) if pytensor.config.compute_test_value != "off": diff --git a/tests/variational/test_opvi.py b/tests/variational/test_opvi.py index 0f40572f72..ffbfbbf090 100644 --- a/tests/variational/test_opvi.py +++ b/tests/variational/test_opvi.py @@ -213,7 +213,7 @@ def test_pickle_approx(three_var_approx): dump = cloudpickle.dumps(three_var_approx) new = cloudpickle.loads(dump) - assert new.sample(1) + assert new.sample(1, model=new._model) def test_pickle_single_group(three_var_approx_single_group_mf): @@ -221,11 +221,11 @@ def test_pickle_single_group(three_var_approx_single_group_mf): dump = cloudpickle.dumps(three_var_approx_single_group_mf) new = cloudpickle.loads(dump) - assert new.sample(1) + assert new.sample(1, model=new._model) def test_sample_simple(three_var_approx): - trace = three_var_approx.sample(100, return_inferencedata=False) + trace = three_var_approx.sample(100, model=three_var_approx._model, return_inferencedata=False) assert set(trace.varnames) == {"one", "one_log__", "three", "two"} assert len(trace) == 100 assert trace[0]["one"].shape == (10, 2) From 6a6b5912f2343795da0bd26ddfaa1fbd542d8147 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Fri, 31 Oct 2025 15:26:12 +0100 Subject: [PATCH 04/11] Vibe coded --- pymc/variational/approximations.py | 24 +- pymc/variational/operators.py | 4 +- pymc/variational/opvi.py | 610 ++++++++++++++++++++++++++-- tests/variational/test_inference.py | 148 +++---- tests/variational/test_opvi.py | 101 +++-- 5 files changed, 722 insertions(+), 165 deletions(-) diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py index 29b7093108..6f012333a9 100644 --- a/pymc/variational/approximations.py +++ b/pymc/variational/approximations.py @@ -85,6 +85,13 @@ def create_shared_params(self, start=None, start_sigma=None): # by `self.ordering`. In the cases I looked into these turn out to be the same, but there may be edge cases or # future code changes that break this assumption. start = self._prepare_start(start) + # Ensure start is a 1D array and matches ddim + start = np.asarray(start).flatten() + if start.size != self.ddim: + raise ValueError( + f"Start array size mismatch: got {start.size}, expected {self.ddim}. " + f"Start shape: {start.shape if hasattr(start, 'shape') else 'unknown'}" + ) rho1 = np.zeros((self.ddim,)) if start_sigma is not None: @@ -139,6 +146,13 @@ def __init_group__(self, group): def create_shared_params(self, start=None): start = self._prepare_start(start) + # Ensure start is a 1D array and matches ddim + start = np.asarray(start).flatten() + if start.size != self.ddim: + raise ValueError( + f"Start array size mismatch: got {start.size}, expected {self.ddim}. 
" + f"Start shape: {start.shape if hasattr(start, 'shape') else 'unknown'}" + ) n = self.ddim L_tril = np.eye(n)[np.tril_indices(n)].astype(pytensor.config.floatX) return {"mu": pytensor.shared(start, "mu"), "L_tril": pytensor.shared(L_tril, "L_tril")} @@ -233,6 +247,8 @@ def create_shared_params(self, trace=None, size=None, jitter=1, start=None): return {"histogram": pytensor.shared(pm.floatX(histogram), "histogram")} def _check_trace(self): + from pymc.model import modelcontext + trace = self._kwargs.get("trace", None) if isinstance(trace, InferenceData): raise NotImplementedError( @@ -240,10 +256,10 @@ def _check_trace(self): " Pass `pm.sample(return_inferencedata=False)` to get a `MultiTrace` to use with `Empirical`." " Please help us to refactor: https://github.com/pymc-devs/pymc/issues/5884" ) - elif trace is not None and not all( - self.model.rvs_to_values[var].name in trace.varnames for var in self.group - ): - raise ValueError("trace has not all free RVs in the group") + elif trace is not None: + model = modelcontext(None) + if not all(model.rvs_to_values[var].name in trace.varnames for var in self.group): + raise ValueError("trace has not all free RVs in the group") def randidx(self, size=None): if size is None: diff --git a/pymc/variational/operators.py b/pymc/variational/operators.py index 6981dedc53..951f521d51 100644 --- a/pymc/variational/operators.py +++ b/pymc/variational/operators.py @@ -19,6 +19,7 @@ import pymc as pm +from pymc.model import modelcontext from pymc.variational import opvi from pymc.variational.opvi import ( NotImplementedInference, @@ -142,7 +143,8 @@ def __init__(self, approx, temperature=1): def apply(self, f): # f: kernel function for KSD f(histogram) -> (k(x,.), \nabla_x k(x,.)) - if _known_scan_ignored_inputs([self.approx._model.logp()]): + model = modelcontext(None) + if _known_scan_ignored_inputs([model.logp()]): raise NotImplementedInference( "SVGD does not currently support Minibatch or Simulator RV" ) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index fe712c248f..78c6b3ee10 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -82,7 +82,6 @@ RandomState, WithMemoization, _get_seeds_per_chain, - locally_cachedmethod, makeiter, ) from pymc.variational.minibatch_rv import MinibatchRandomVariable, get_scaling @@ -145,17 +144,43 @@ def inner(*args, **kwargs): def node_property(f): """Wrap method to accessible tensor.""" + from collections import defaultdict + + from cachetools import LRUCache, cachedmethod + + def self_cache_fn(f_name): + def cf(self): + return self.__dict__.setdefault("_cache", defaultdict(lambda: LRUCache(128)))[f_name] + + return cf + + def cache_key_with_model(*args, **kwargs): + """Cache key that includes the current model context.""" + from pymc.util import hash_key + + try: + model = modelcontext(None) + # Include the model in the cache key to avoid using cached values + # from a different model context + return ( + *tuple(hash_key(*args, **kwargs)), + id(model), + ) + except TypeError: + # If no model context, just use regular hash_key + return tuple(hash_key(*args, **kwargs)) + if isinstance(f, str): def wrapper(fn): ff = append_name(f)(fn) f_ = pytensor.config.change_flags(compute_test_value="off")(ff) - return property(locally_cachedmethod(f_)) + return property(cachedmethod(self_cache_fn(f_.__name__), key=cache_key_with_model)(f_)) return wrapper else: f_ = pytensor.config.change_flags(compute_test_value="off")(f) - return property(locally_cachedmethod(f_)) + return 
property(cachedmethod(self_cache_fn(f_.__name__), key=cache_key_with_model)(f_)) @pytensor.config.change_flags(compute_test_value="ignore") @@ -497,7 +522,7 @@ def __init__(self, approx): varlogp_norm = property(lambda self: self.approx.varlogp_norm) datalogp_norm = property(lambda self: self.approx.datalogp_norm) logq_norm = property(lambda self: self.approx.logq_norm) - model = property(lambda self: self.approx.model) + model = property(lambda self: modelcontext(None)) def apply(self, f): # pragma: no cover R"""Operator itself. @@ -770,7 +795,6 @@ def __init__( self._vfam = vfam self.rng = np.random.RandomState(random_seed) model = modelcontext(model) - self.model = model self.group = group self.user_params = params self._user_params = None @@ -783,17 +807,42 @@ def __init__( self.__init_group__(self.group) def _prepare_start(self, start=None): + model = modelcontext(None) + # If start is already an array, we need to ensure it's flattened and matches ddim + if isinstance(start, np.ndarray): + start_flat = start.flatten() + if start_flat.size != self.ddim: + raise ValueError( + f"Mismatch in start array size: got {start_flat.size}, expected {self.ddim}. " + f"Start array shape: {start.shape}, flattened size: {start_flat.size}" + ) + return start_flat + # Otherwise, get initial point from model and filter by group variables ipfn = make_initial_point_fn( - model=self.model, + model=model, overrides=start, jitter_rvs={}, return_transformed=True, ) start = ipfn(self.rng.randint(2**30, dtype=np.int64)) - group_vars = {self.model.rvs_to_values[v].name for v in self.group} + group_vars = {model.rvs_to_values[v].name for v in self.group} start = {k: v for k, v in start.items() if k in group_vars} - start = DictToArrayBijection.map(start).data - return start + if not start: + raise ValueError( + f"No matching variables found in initial point for group variables: {group_vars}. " + f"Initial point keys: {list(ipfn(self.rng.randint(2**30, dtype=np.int64)).keys())}" + ) + start_raveled = DictToArrayBijection.map(start) + # Ensure we have a 1D array that matches self.ddim + start_data = start_raveled.data + expected_size = self.ddim + if start_data.size != expected_size: + raise ValueError( + f"Mismatch in start array size: got {start_data.size}, expected {expected_size}. " + f"Group variables: {group_vars}, Start dict keys: {list(start.keys())}, " + f"This might indicate an issue with the model context or group initialization." + ) + return start_data @classmethod def get_param_spec_for(cls, **kwargs): @@ -867,9 +916,38 @@ def __init_group__(self, group): """Initialize the group.""" if not group: raise GroupError("Got empty group") + model = modelcontext(None) + + # If self.group is already set (from unpickling), we might need to rebuild it + # to map old variables to new ones in the current model context + if self.group is not None: + # Check if any variables in self.group don't belong to the current model + # If so, rebuild the group by matching variable names + needs_rebuild = False + for var in self.group: + # Check if variable is in the current model's free_RVs + if var not in model.free_RVs: + needs_rebuild = True + break + + if needs_rebuild: + # Rebuild group by matching variable names + var_name_map = {var.name: var for var in model.free_RVs} + new_group = [] + for old_var in self.group: + if old_var.name in var_name_map: + new_group.append(var_name_map[old_var.name]) + else: + raise ValueError( + f"Variable '{old_var.name}' from unpickled group not found in current model. 
" + f"Available variables: {list(var_name_map.keys())}" + ) + self.group = new_group + if self.group is None: # delayed init self.group = group + self.symbolic_initial = self._initial_type( self.__class__.__name__ + "_symbolic_initial_tensor" ) @@ -878,15 +956,18 @@ def __init_group__(self, group): # so I have to to it by myself # 1) we need initial point (transformed space) - model_initial_point = self.model.initial_point(0) + model_initial_point = model.initial_point(0) # 2) we'll work with a single group, a subset of the model # here we need to create a mapping to replace value_vars with slices from the approximation + # Clear old replacements/ordering before rebuilding + self.replacements = collections.OrderedDict() + self.ordering = collections.OrderedDict() start_idx = 0 for var in self.group: if var.type.numpy_dtype.name in discrete_types: raise ParametrizationError(f"Discrete variables are not supported by VI: {var}") # 3) This is the way to infer shape and dtype of the variable - value_var = self.model.rvs_to_values[var] + value_var = model.rvs_to_values[var] test_var = model_initial_point[value_var.name] shape = test_var.shape size = test_var.size @@ -902,16 +983,40 @@ def __init_group__(self, group): ) start_idx += size + def __setstate__(self, state): + """Restore state after unpickling and clear cache.""" + super().__setstate__(state) + # Clear cached state after unpickling since cached values may reference + # variables from a different model context + self._clear_cached_state() + # Recreate _kwargs if it was deleted by _finalize_init, needed for rebuild + if not hasattr(self, "_kwargs"): + self._kwargs = {} + def _finalize_init(self): """*Dev* - clean up after init.""" del self._kwargs + def _clear_cached_state(self, *, reset_shared=False): + """Reset cached structures that depend on the current model context.""" + if hasattr(self, "_cache"): + del self._cache + self.replacements = collections.OrderedDict() + self.ordering = collections.OrderedDict() + for attr in ("symbolic_initial", "input"): + if attr in self.__dict__: + delattr(self, attr) + if reset_shared: + self.shared_params = None + @property def params_dict(self): # prefixed are correctly reshaped if self._user_params is not None: return self._user_params else: + if self.shared_params is None and self.group is not None: + _refresh_group_for_model(self, modelcontext(None)) return self.shared_params @property @@ -1166,12 +1271,13 @@ def var_to_data(self, shared: pt.TensorVariable) -> xarray.Dataset: """Take a flat 1-dimensional tensor variable and maps it to an xarray data set based on the information in `self.ordering`.""" # This is somewhat similar to `DictToArrayBijection.rmap`, which doesn't work here since we don't have # `RaveledVars` and need to take the information from `self.ordering` instead + model = modelcontext(None) shared_nda = shared.eval() result = {} for name, s, shape, dtype in self.ordering.values(): - dims = self.model.named_vars_to_dims.get(name, None) + dims = model.named_vars_to_dims.get(name, None) if dims is not None: - coords = {d: np.array(self.model.coords[d]) for d in dims} + coords = {d: np.array(model.coords[d]) for d in dims} else: coords = None values = shared_nda[s].reshape(shape).astype(dtype) @@ -1193,6 +1299,84 @@ def std_data(self) -> xarray.Dataset: group_for_short_name = Group.group_for_short_name +def _map_group_vars_to_model(group_vars, model): + if not group_vars: + return [] + var_name_map = {var.name: var for var in model.free_RVs} + mapped = [] + for var in group_vars: + 
if var in model.free_RVs: + mapped.append(var) + else: + mapped_var = var_name_map.get(var.name) + if mapped_var is not None: + mapped.append(mapped_var) + return mapped + + +def _refresh_group_for_model(group, model, group_vars=None): + if group_vars is None: + group_vars = group.group or [] + mapped_group = _map_group_vars_to_model(group_vars, model) + if mapped_group: + group_vars = mapped_group + if not group_vars: + group.group = group_vars + return group.group + if group.shared_params is None: + if not hasattr(group, "_kwargs"): + group._kwargs = {} + original_user_params = group.user_params + group._clear_cached_state(reset_shared=True) + group.user_params = None + group._user_params = None + group.group = None + group.__init_group__(list(group_vars)) + if original_user_params is not None: + group.user_params = original_user_params + else: + group.group = list(group_vars) + if "symbolic_initial" not in group.__dict__: + group.symbolic_initial = group._initial_type( + group.__class__.__name__ + "_symbolic_initial_tensor" + ) + if "input" not in group.__dict__: + group.input = group._input_type(group.__class__.__name__ + "_symbolic_input") + _rebuild_group_mappings(group, model) + return group.group + + +def _rebuild_group_mappings(group, model): + if not group.group: + group.replacements = collections.OrderedDict() + group.ordering = collections.OrderedDict() + return + model_initial_point = model.initial_point(0) + replacements = collections.OrderedDict() + ordering = collections.OrderedDict() + start_idx = 0 + for var in group.group: + if var.type.numpy_dtype.name in discrete_types: + raise ParametrizationError(f"Discrete variables are not supported by VI: {var}") + value_var = model.rvs_to_values[var] + test_var = model_initial_point[value_var.name] + shape = test_var.shape + size = test_var.size + dtype = test_var.dtype + vr = group.input[..., start_idx : start_idx + size].reshape(shape).astype(dtype) + vr.name = value_var.name + "_vi_replacement" + replacements[value_var] = vr + ordering[value_var.name] = ( + value_var.name, + slice(start_idx, start_idx + size), + shape, + dtype, + ) + start_idx += size + group.replacements = replacements + group.ordering = ordering + + class Approximation(WithMemoization): """**Wrapper for grouped approximations**. 
@@ -1219,6 +1403,16 @@ class Approximation(WithMemoization): :class:`Group` """ + def __setstate__(self, state): + """Restore state after unpickling and clear cache.""" + super().__setstate__(state) + # Clear cache after unpickling since cached values may reference + # variables from a different model context + # _cache is removed during pickling by WithMemoization.__getstate__, + # so it shouldn't exist after unpickling, but ensure it's deleted if it does + if hasattr(self, "_cache"): + del self._cache + def __init__(self, groups, model=None): self._scale_cost_to_minibatch = pytensor.shared(np.int8(1)) model = modelcontext(model) @@ -1234,9 +1428,10 @@ def __init__(self, groups, model=None): else: rest = g else: - if set(g.group) & seen: + final_group = _refresh_group_for_model(g, model) + if set(final_group) & seen: raise GroupError("Found duplicates in groups") - seen.update(g.group) + seen.update(final_group) self.groups.append(g) # List iteration to preserve order for reproducibility between runs unseen_free_RVs = [var for var in model.free_RVs if var not in seen] @@ -1244,9 +1439,8 @@ def __init__(self, groups, model=None): if rest is None: raise GroupError("No approximation is specified for the rest variables") else: - rest.__init_group__(unseen_free_RVs) + _refresh_group_for_model(rest, model, unseen_free_RVs) self.groups.append(rest) - self._model = model @property def has_logq(self): @@ -1259,9 +1453,18 @@ def model(self): "a model context instead.", DeprecationWarning, ) - return self._model + return modelcontext(None) + + def _ensure_groups_ready(self, model=None): + try: + model = modelcontext(model) + except TypeError: + return + for g in self.groups: + _refresh_group_for_model(g, model) def collect(self, item): + self._ensure_groups_ready() return [getattr(g, item) for g in self.groups] inputs = property(lambda self: self.collect("input")) @@ -1282,6 +1485,7 @@ def symbolic_normalizing_constant(self): Here the effect is controlled by `self.scale_cost_to_minibatch`. 
""" + model = modelcontext(None) t = pt.max( self.collect("symbolic_normalizing_constant") + [ @@ -1289,7 +1493,7 @@ def symbolic_normalizing_constant(self): obs.owner.inputs[1:], constant_fold([obs.owner.inputs[0].shape], raise_not_constant=False), ) - for obs in self._model.observed_RVs + for obs in model.observed_RVs if isinstance(obs.owner.op, MinibatchRandomVariable) ] ) @@ -1314,9 +1518,8 @@ def logq_norm(self): @node_property def _sized_symbolic_varlogp_and_datalogp(self): """*Dev* - computes sampled prior term from model via `pytensor.scan`.""" - varlogp_s, datalogp_s = self.symbolic_sample_over_posterior( - [self._model.varlogp, self._model.datalogp] - ) + model = modelcontext(None) + varlogp_s, datalogp_s = self.symbolic_sample_over_posterior([model.varlogp, model.datalogp]) return varlogp_s, datalogp_s # both shape (s,) @node_property @@ -1352,7 +1555,8 @@ def datalogp(self): @node_property def _single_symbolic_varlogp_and_datalogp(self): """*Dev* - computes sampled prior term from model via `pytensor.scan`.""" - varlogp, datalogp = self.symbolic_single_sample([self._model.varlogp, self._model.datalogp]) + model = modelcontext(None) + varlogp, datalogp = self.symbolic_single_sample([model.varlogp, model.datalogp]) return varlogp, datalogp @node_property @@ -1511,7 +1715,7 @@ def sample_node(self, node, size=None, deterministic=False, more_replacements=No """ node_in = node - model = self._model + model = modelcontext(None) if more_replacements: node = graph_replace(node, more_replacements, strict=False) @@ -1528,20 +1732,16 @@ def sample_node(self, node, size=None, deterministic=False, more_replacements=No try_to_set_test_value(node_in, node_out, size) return node_out - def rslice(self, name, model): + def rslice(self, name, model=None): """*Dev* - vectorized sampling for named random variable without call to `pytensor.scan`. This node still needs :func:`set_size_and_deterministic` to be evaluated. 
""" + model = modelcontext(model) - def vars_names(vs): - return {model.rvs_to_values[v].name for v in vs} - - for vars_, random, ordering in zip( - self.collect("group"), self.symbolic_randoms, self.collect("ordering") - ): - if name in vars_names(vars_): - name_, slc, shape, dtype = ordering[name] + for random, ordering in zip(self.symbolic_randoms, self.collect("ordering")): + if name in ordering: + _name, slc, shape, dtype = ordering[name] found = random[..., slc].reshape((random.shape[0], *shape)).astype(dtype) found.name = name + "_vi_random_slice" break @@ -1554,19 +1754,80 @@ def sample_dict_fn(self): s = pt.iscalar() def inner(draws=100, *, model=None, random_seed: SeedSequenceSeed = None): - model = modelcontext(model) + from pymc.sampling.forward import compile_forward_sampling_function - names = [model.rvs_to_values[v].name for v in model.free_RVs] - sampled = [self.rslice(name, model) for name in names] - sampled = self.set_size_and_deterministic(sampled, s, 0) - sample_fn = compile([s], sampled) - rng_nodes = find_rng_nodes(sampled) + model = modelcontext(model) - if random_seed is not None: - reseed_rngs(rng_nodes, random_seed) - _samples = sample_fn(draws) + # Get all variable names that exist in the approximation + approx_names = set() + for ordering in self.collect("ordering"): + approx_names.update(ordering.keys()) + + # Get all variable names from the model + model_names = {model.rvs_to_values[v].name: v for v in model.free_RVs} + + # Separate variables into those in approximation and those not + approx_var_names = sorted(approx_names & set(model_names.keys())) + forward_var_names = sorted(set(model_names.keys()) - approx_names) + + # Sample variables from approximation + all_samples = {} + if approx_var_names: + sampled = [self.rslice(name, model) for name in approx_var_names] + sampled = self.set_size_and_deterministic(sampled, s, 0) + sample_fn = compile([s], sampled) + rng_nodes = find_rng_nodes(sampled) + + if random_seed is not None: + reseed_rngs(rng_nodes, random_seed) + _samples = sample_fn(draws) + all_samples.update(dict(zip(approx_var_names, _samples))) + + # Forward sample variables not in approximation, conditioned on approximation variables + if forward_var_names: + forward_vars = [model_names[name] for name in forward_var_names] + approx_vars = [model_names[name] for name in approx_var_names] + + # Compile forward sampling function + # Variables in vars_in_trace become inputs to the function + # so we can pass the sampled approximation values as inputs + sampler_fn, _ = compile_forward_sampling_function( + outputs=forward_vars, + vars_in_trace=approx_vars, + basic_rvs=model.basic_RVs, + givens_dict=None, + random_seed=random_seed, + ) - return dict(zip(names, _samples)) + # Get the value variables that will be inputs + approx_value_vars = [model.rvs_to_values[var] for var in approx_vars] + + # Forward sample for each draw, conditioned on approximation samples + forward_samples_list = [] + for i in range(draws): + # Create inputs dict with sampled approximation values for this draw + # Use variable names as keys since the compiled function expects string keywords + inputs_dict = {} + for var, value_var in zip(approx_vars, approx_value_vars): + sampled_value = all_samples[value_var.name][i] + inputs_dict[value_var.name] = sampled_value + + # Forward sample with these fixed values + forward_samples = sampler_fn(**inputs_dict) + if isinstance(forward_samples, list): + forward_samples_list.append(forward_samples) + else: + 
forward_samples_list.append([forward_samples]) + + # Stack results + if forward_samples_list: + if isinstance(forward_samples_list[0], list): + for j, name in enumerate(forward_var_names): + all_samples[name] = np.stack([draw[j] for draw in forward_samples_list]) + else: + all_samples[forward_var_names[0]] = np.stack(forward_samples_list) + + return all_samples return inner @@ -1605,14 +1866,273 @@ def sample( if random_seed is not None: (random_seed,) = _get_seeds_per_chain(random_seed, 1) samples: dict = self.sample_dict_fn(draws, model=model, random_seed=random_seed) + + # Get the variables that correspond to our samples + # We need to find all variables in the model that match our sample names + # This includes both transformed (e.g., 'one_log__') and untransformed (e.g., 'one') variables + sample_names = sorted(samples.keys()) # Use sorted order for consistency + + # Build lookup from unobserved_value_vars first (prioritize these) + # This includes both transformed and untransformed variables + var_name_to_var = {} + for var in model.unobserved_value_vars: + if var.name not in var_name_to_var: + var_name_to_var[var.name] = var + + # Also include named_vars (for variables like 'one' that might not be in unobserved_value_vars) + for name, var in model.named_vars.items(): + if name not in var_name_to_var: + var_name_to_var[name] = var + + # Collect variables in the order matching sample_names, plus untransformed versions + # Use model.unobserved_value_vars as source of truth - it includes both transformed + # and untransformed variables + sample_vars = [] + sample_var_names = set() # Track which var names we've added + + # First, add variables that are in samples (in sample_names order) + for name in sample_names: + if name in var_name_to_var: + var = var_name_to_var[name] + sample_vars.append(var) + sample_var_names.add(var.name) + + # Then, add any variables from model.unobserved_value_vars that aren't already included + # This ensures we include both transformed (e.g., 'one_log__') and untransformed (e.g., 'one') + for var in model.unobserved_value_vars: + if var.name not in sample_var_names: + sample_vars.append(var) + sample_var_names.add(var.name) + + # Create points as OrderedDict to preserve order matching sample_names + # This ensures fn(*point.values()) gets values in the correct order + from collections import OrderedDict + points = ( - {name: np.asarray(records[i]) for name, records in samples.items()} + OrderedDict((name, np.asarray(samples[name][i])) for name in sample_names) for i in range(draws) ) + # Create test_point and var_shapes using the actual sample shapes to ensure trace setup matches + # Key var_shapes by trace variable names (var.name from sample_vars), not sample_names + test_point = OrderedDict() + var_shapes = {} + var_dtypes = {} + trace_varnames = [var.name for var in sample_vars] + + # Create mapping from trace variable names to sample names + # This handles cases where sample_names might differ from trace variable names + trace_to_sample = {} + sample_to_trace = {} # Reverse mapping + for i, var in enumerate(sample_vars): + trace_name = var.name + # Try to find matching sample by name first + if trace_name in sample_names and trace_name in samples: + sample_name = trace_name + elif i < len(sample_names): + # Fall back to index alignment + sample_name = sample_names[i] + else: + continue + trace_to_sample[trace_name] = sample_name + sample_to_trace[sample_name] = trace_name + + # Build test_point, var_shapes, and var_dtypes using actual sample shapes + # 
Use trace variable names as keys (what trace expects) + # CRITICAL: var_shapes must use trace variable names (var.name) as keys + # and shapes must match the actual samples, not model variable shapes + # For untransformed variables not in samples, we'll compute them from transformed ones + initial_point = model.initial_point() # Get once to compute untransformed vars + + for var in sample_vars: + trace_name = var.name + if trace_name in trace_to_sample: + # Variable is in samples, use the sample value + sample_name = trace_to_sample[trace_name] + if sample_name in samples: + first_sample = np.asarray(samples[sample_name][0]) + test_point[trace_name] = first_sample + var_shapes[trace_name] = first_sample.shape + var_dtypes[trace_name] = first_sample.dtype + else: + # Shouldn't happen, but use initial_point as fallback + if trace_name in initial_point: + test_point[trace_name] = initial_point[trace_name] + var_shapes[trace_name] = initial_point[trace_name].shape + var_dtypes[trace_name] = initial_point[trace_name].dtype + elif trace_name in samples: + # Direct match in samples (shouldn't happen if trace_to_sample is correct) + first_sample = np.asarray(samples[trace_name][0]) + test_point[trace_name] = first_sample + var_shapes[trace_name] = first_sample.shape + var_dtypes[trace_name] = first_sample.dtype + else: + # Variable not in samples - it's an untransformed variable we need to compute + # Use model.initial_point to get the shape (it includes both transformed and untransformed) + if trace_name in initial_point: + test_point[trace_name] = initial_point[trace_name] + var_shapes[trace_name] = initial_point[trace_name].shape + var_dtypes[trace_name] = initial_point[trace_name].dtype + else: + # Variable not in initial_point either - compute shape from model + # We need to compute this variable from transformed variables + # Use a test computation to get the shape + try: + # Try to compute the variable using model.compile_fn to get its shape + # Build test point with transformed variables from samples + test_point_dict = {} + for v in model.value_vars: + if v.name in samples: + test_point_dict[v.name] = samples[v.name][0] + + test_compute_fn = model.compile_fn( + [var], inputs=model.value_vars, on_unused_input="ignore", point_fn=True + ) + # Get shape from the computed value + test_result = test_compute_fn(test_point_dict) + if isinstance(test_result, list | tuple): + test_value = test_result[0] if len(test_result) > 0 else None + else: + test_value = test_result + + if test_value is not None: + test_point[trace_name] = test_value + var_shapes[trace_name] = test_value.shape + var_dtypes[trace_name] = test_value.dtype + else: + raise ValueError(f"Could not compute shape for {trace_name}") + except Exception as e: + # If computation fails, try to get shape from initial_point computation + # Compute initial_point again with all variables + try: + full_initial_point = model.initial_point() + if trace_name in full_initial_point: + test_point[trace_name] = full_initial_point[trace_name] + var_shapes[trace_name] = full_initial_point[trace_name].shape + var_dtypes[trace_name] = full_initial_point[trace_name].dtype + else: + # Last resort: skip if we truly can't determine shape + # But this shouldn't happen for variables in model.unobserved_value_vars + raise ValueError( + f"Could not determine shape for {trace_name}. " + f"Variable should be in model.unobserved_value_vars. 
" + f"Initial point keys: {list(full_initial_point.keys())}" + ) from e + except Exception as e2: + raise ValueError( + f"Could not determine shape for {trace_name}. " + f"Variable is in sample_vars but not in samples or initial_point." + ) from e2 + + # Create a custom fn that returns values in the order matching trace.varnames + # trace.record calls fn(*point.values()), so we need to return values in trace.varnames order + # point.values() is in sample_names order, but trace.varnames is trace_varnames + # We need to reorder values from sample_names order to trace_varnames order + # For untransformed variables not in samples, we need to compute them from transformed ones + vars_to_compute = [ + var + for var in sample_vars + if var.name not in samples and var.name not in trace_to_sample + ] + + if vars_to_compute: + # We have untransformed variables to compute - need a proper fn that transforms them + # Use model's compile_fn to compute untransformed variables from transformed ones + computed_var_names = [var.name for var in vars_to_compute] + # Get the transformed variables that are inputs (value_vars) + input_vars = model.value_vars + # Compile a function to compute untransformed variables + # Use point_fn=True to get a PointFunc that accepts a dict + compute_fn = model.compile_fn( + vars_to_compute, + inputs=input_vars, + on_unused_input="ignore", + point_fn=True, + ) + else: + compute_fn = None + computed_var_names = [] + input_vars = [] + + def identity_fn(*args): + """Return values reordered to match trace.varnames.""" + # args comes from point.values() in sample_names order + # Build mapping from sample_names to values + value_dict = dict(zip(sample_names, args)) + + # Build input dict for compute_fn if needed + if compute_fn is not None: + # Map sample names to value_vars for compute_fn + # compute_fn expects a dict (point_fn=True) + compute_input_dict = {} + for var in input_vars: + if var.name in value_dict: + compute_input_dict[var.name] = value_dict[var.name] + + # Call with dict (PointFunc expects a dict) + # PointFunc returns a list/tuple of values in the order of outputs + computed_values = compute_fn(compute_input_dict) + if isinstance(computed_values, list | tuple): + # computed_values is in the order of vars_to_compute + if len(computed_values) != len(computed_var_names): + raise ValueError( + f"Mismatch: compute_fn returned {len(computed_values)} values, " + f"expected {len(computed_var_names)}. " + f"computed_var_names: {computed_var_names}" + ) + computed_dict = dict(zip(computed_var_names, computed_values)) + else: + # Single output case + if len(computed_var_names) != 1: + raise ValueError( + f"compute_fn returned single value but expected {len(computed_var_names)} values. " + f"computed_var_names: {computed_var_names}" + ) + computed_dict = {computed_var_names[0]: computed_values} + value_dict.update(computed_dict) + + # Return values in trace_varnames order, mapping trace names to sample names + # CRITICAL: We must return exactly len(trace_varnames) values, one for each variable + result = [] + for trace_name in trace_varnames: + if trace_name in trace_to_sample: + # Variable is mapped to a sample name + sample_name = trace_to_sample[trace_name] + if sample_name in value_dict: + result.append(value_dict[sample_name]) + else: + # Sample name not in value_dict - shouldn't happen + raise ValueError( + f"Sample name '{sample_name}' for trace variable '{trace_name}' not found in value_dict. 
" + f"Available keys: {list(value_dict.keys())}" + ) + elif trace_name in value_dict: + # Direct match (e.g., computed untransformed variable) + result.append(value_dict[trace_name]) + else: + # Variable not in value_dict - missing value! + raise ValueError( + f"Trace variable '{trace_name}' not found in value_dict. " + f"Available keys: {list(value_dict.keys())}, " + f"trace_varnames: {trace_varnames}, " + f"vars_to_compute: {computed_var_names if compute_fn is not None else []}" + ) + + # Ensure we return exactly len(trace_varnames) values + if len(result) != len(trace_varnames): + raise ValueError( + f"Mismatch in result length: got {len(result)}, expected {len(trace_varnames)}. " + f"trace_varnames: {trace_varnames}, result length: {len(result)}" + ) + return tuple(result) + trace = NDArray( model=model, - test_point={name: records[0] for name, records in samples.items()}, + vars=sample_vars, + test_point=test_point, + fn=identity_fn, + var_shapes=var_shapes, + var_dtypes=var_dtypes, ) try: trace.setup(draws=draws, chain=0) diff --git a/tests/variational/test_inference.py b/tests/variational/test_inference.py index 8effb4ed91..bdd7a5ed79 100644 --- a/tests/variational/test_inference.py +++ b/tests/variational/test_inference.py @@ -175,7 +175,8 @@ def fit_kwargs(inference, use_minibatch): def test_fit_oo(simple_model, inference, fit_kwargs, simple_model_data): - trace = inference.fit(**fit_kwargs).sample(10000, model=simple_model) + with simple_model: + trace = inference.fit(**fit_kwargs).sample(10000) mu_post = simple_model_data["mu_post"] d = simple_model_data["d"] np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_post, rtol=0.05) @@ -244,10 +245,11 @@ def test_fit_fn_text(method, kwargs, error): pm.fit(10, method=method, **kwargs) -def test_profile(inference): +def test_profile(inference, simple_model): if type(inference) in {SVGD, ASVGD}: pytest.skip("Not Implemented Inference") - inference.run_profiling(n=100).summary() + with simple_model: + inference.run_profiling(n=100).summary() @pytest.fixture(scope="module") @@ -267,64 +269,66 @@ def binomial_model_inference(binomial_model, inference_spec): @pytest.mark.xfail("pytensor.config.warn_float64 == 'raise'", reason="too strict float32") -def test_replacements(binomial_model_inference): +def test_replacements(binomial_model_inference, binomial_model): d = pytensor.shared(1) approx = binomial_model_inference.approx - p = approx._model.p - p_t = p**3 - p_s = approx.sample_node(p_t) - assert not any( - isinstance(n.owner.op, pytensor.tensor.random.basic.BetaRV) - for n in pytensor.graph.ancestors([p_s]) - if n.owner - ), "p should be replaced" - if pytensor.config.compute_test_value != "off": - assert p_s.tag.test_value.shape == p_t.tag.test_value.shape - sampled = [pm.draw(p_s) for _ in range(100)] - assert any(map(operator.ne, sampled[1:], sampled[:-1])) # stochastic - p_z = approx.sample_node(p_t, deterministic=False, size=10) - assert p_z.shape.eval() == (10,) - try: - p_z = approx.sample_node(p_t, deterministic=True, size=10) + with binomial_model: + p = binomial_model.p + p_t = p**3 + p_s = approx.sample_node(p_t) + assert not any( + isinstance(n.owner.op, pytensor.tensor.random.basic.BetaRV) + for n in pytensor.graph.ancestors([p_s]) + if n.owner + ), "p should be replaced" + if pytensor.config.compute_test_value != "off": + assert p_s.tag.test_value.shape == p_t.tag.test_value.shape + sampled = [pm.draw(p_s) for _ in range(100)] + assert any(map(operator.ne, sampled[1:], sampled[:-1])) # stochastic + p_z = 
approx.sample_node(p_t, deterministic=False, size=10) assert p_z.shape.eval() == (10,) - except opvi.NotImplementedInference: - pass - - try: - p_d = approx.sample_node(p_t, deterministic=True) - sampled = [pm.draw(p_d) for _ in range(100)] + try: + p_z = approx.sample_node(p_t, deterministic=True, size=10) + assert p_z.shape.eval() == (10,) + except opvi.NotImplementedInference: + pass + + try: + p_d = approx.sample_node(p_t, deterministic=True) + sampled = [pm.draw(p_d) for _ in range(100)] + assert all(map(operator.eq, sampled[1:], sampled[:-1])) # deterministic + except opvi.NotImplementedInference: + pass + + p_r = approx.sample_node(p_t, deterministic=d) + d.set_value(1) + sampled = [pm.draw(p_r) for _ in range(100)] assert all(map(operator.eq, sampled[1:], sampled[:-1])) # deterministic - except opvi.NotImplementedInference: - pass - - p_r = approx.sample_node(p_t, deterministic=d) - d.set_value(1) - sampled = [pm.draw(p_r) for _ in range(100)] - assert all(map(operator.eq, sampled[1:], sampled[:-1])) # deterministic - d.set_value(0) - sampled = [pm.draw(p_r) for _ in range(100)] - assert any(map(operator.ne, sampled[1:], sampled[:-1])) # stochastic + d.set_value(0) + sampled = [pm.draw(p_r) for _ in range(100)] + assert any(map(operator.ne, sampled[1:], sampled[:-1])) # stochastic -def test_sample_replacements(binomial_model_inference): +def test_sample_replacements(binomial_model_inference, binomial_model): i = pt.iscalar() i.tag.test_value = 1 approx = binomial_model_inference.approx - p = approx._model.p - p_t = p**3 - p_s = approx.sample_node(p_t, size=100) - if pytensor.config.compute_test_value != "off": - assert p_s.tag.test_value.shape == (100, *p_t.tag.test_value.shape) - sampled = p_s.eval() - assert any(map(operator.ne, sampled[1:], sampled[:-1])) # stochastic - assert sampled.shape[0] == 100 - - p_d = approx.sample_node(p_t, size=i) - sampled = p_d.eval({i: 100}) - assert any(map(operator.ne, sampled[1:], sampled[:-1])) # deterministic - assert sampled.shape[0] == 100 - sampled = p_d.eval({i: 101}) - assert sampled.shape[0] == 101 + with binomial_model: + p = binomial_model.p + p_t = p**3 + p_s = approx.sample_node(p_t, size=100) + if pytensor.config.compute_test_value != "off": + assert p_s.tag.test_value.shape == (100, *p_t.tag.test_value.shape) + sampled = p_s.eval() + assert any(map(operator.ne, sampled[1:], sampled[:-1])) # stochastic + assert sampled.shape[0] == 100 + + p_d = approx.sample_node(p_t, size=i) + sampled = p_d.eval({i: 100}) + assert any(map(operator.ne, sampled[1:], sampled[:-1])) # deterministic + assert sampled.shape[0] == 100 + sampled = p_d.eval({i: 101}) + assert sampled.shape[0] == 101 def test_remove_scan_op(): @@ -355,7 +359,7 @@ def test_var_replacement(): def test_clear_cache(): - with pm.Model(): + with pm.Model() as model: pm.Normal("n", 0, 1) inference = ADVI() inference.fit(n=10) @@ -365,19 +369,21 @@ def test_clear_cache(): assert all(len(c) == 0 for c in inference.approx._cache.values()) new_a = cloudpickle.loads(cloudpickle.dumps(inference.approx)) assert not hasattr(new_a, "_cache") - inference_new = pm.KLqp(new_a) - inference_new.fit(n=10) - assert any(len(c) != 0 for c in inference_new.approx._cache.values()) - inference_new.approx._cache.clear() - assert all(len(c) == 0 for c in inference_new.approx._cache.values()) + with model: + inference_new = pm.KLqp(new_a) + inference_new.fit(n=10) + assert any(len(c) != 0 for c in inference_new.approx._cache.values()) + inference_new.approx._cache.clear() + assert all(len(c) == 0 for c in 
inference_new.approx._cache.values()) -def test_fit_data(inference, fit_kwargs, simple_model_data): - fitted = inference.fit(**fit_kwargs) - mu_post = simple_model_data["mu_post"] - d = simple_model_data["d"] - np.testing.assert_allclose(fitted.mean_data["mu"].values, mu_post, rtol=0.05) - np.testing.assert_allclose(fitted.std_data["mu"], np.sqrt(1.0 / d), rtol=0.2) +def test_fit_data(inference, fit_kwargs, simple_model_data, simple_model): + with simple_model: + fitted = inference.fit(**fit_kwargs) + mu_post = simple_model_data["mu_post"] + d = simple_model_data["d"] + np.testing.assert_allclose(fitted.mean_data["mu"].values, mu_post, rtol=0.05) + np.testing.assert_allclose(fitted.std_data["mu"], np.sqrt(1.0 / d), rtol=0.2) @pytest.fixture @@ -441,13 +447,13 @@ def test_fit_data_coords(hierarchical_model, hierarchical_model_data): with hierarchical_model: fitted = pm.fit(1) - for data in [fitted.mean_data, fitted.std_data]: - assert set(data.keys()) == {"sigma_group_mu_log__", "sigma_log__", "group_mu", "mu"} - assert data["group_mu"].shape == hierarchical_model_data["group_shape"] - assert list(data["group_mu"].coords.keys()) == list( - hierarchical_model_data["group_coords"].keys() - ) - assert data["mu"].shape == () + for data in [fitted.mean_data, fitted.std_data]: + assert set(data.keys()) == {"sigma_group_mu_log__", "sigma_log__", "group_mu", "mu"} + assert data["group_mu"].shape == hierarchical_model_data["group_shape"] + assert list(data["group_mu"].coords.keys()) == list( + hierarchical_model_data["group_coords"].keys() + ) + assert data["mu"].shape == () def test_multiple_minibatch_variables(): diff --git a/tests/variational/test_opvi.py b/tests/variational/test_opvi.py index ffbfbbf090..41460f14d5 100644 --- a/tests/variational/test_opvi.py +++ b/tests/variational/test_opvi.py @@ -184,48 +184,55 @@ def test_init_groups(three_var_model, raises, grouping): ids=lambda t: ", ".join(f"{k.__name__}: {v[0]}" for k, v in t[1].items()), ) def three_var_groups(request, three_var_model): - kw, grouping = request.param - approxes, groups = zip(*grouping.items()) - groups, gkwargs = zip(*groups) - groups = [ - list(map(ft.partial(getattr, three_var_model), g)) if g is not None else None - for g in groups - ] - inited_groups = [ - a(group=g, model=three_var_model, **gk) for a, g, gk in zip(approxes, groups, gkwargs) - ] + with three_var_model: + kw, grouping = request.param + approxes, groups = zip(*grouping.items()) + groups, gkwargs = zip(*groups) + groups = [ + list(map(ft.partial(getattr, three_var_model), g)) if g is not None else None + for g in groups + ] + inited_groups = [ + a(group=g, model=three_var_model, **gk) for a, g, gk in zip(approxes, groups, gkwargs) + ] return inited_groups @pytest.fixture def three_var_approx(three_var_model, three_var_groups): - approx = opvi.Approximation(three_var_groups, model=three_var_model) + with three_var_model: + approx = opvi.Approximation(three_var_groups, model=three_var_model) return approx @pytest.fixture def three_var_approx_single_group_mf(three_var_model): - return MeanField(model=three_var_model) + with three_var_model: + approx = MeanField(model=three_var_model) + return approx -def test_pickle_approx(three_var_approx): +def test_pickle_approx(three_var_approx, three_var_model): import cloudpickle dump = cloudpickle.dumps(three_var_approx) new = cloudpickle.loads(dump) - assert new.sample(1, model=new._model) + with three_var_model: + assert new.sample(1) -def test_pickle_single_group(three_var_approx_single_group_mf): +def 
test_pickle_single_group(three_var_approx_single_group_mf, three_var_model): import cloudpickle dump = cloudpickle.dumps(three_var_approx_single_group_mf) new = cloudpickle.loads(dump) - assert new.sample(1, model=new._model) + with three_var_model: + assert new.sample(1) -def test_sample_simple(three_var_approx): - trace = three_var_approx.sample(100, model=three_var_approx._model, return_inferencedata=False) +def test_sample_simple(three_var_approx, three_var_model): + with three_var_model: + trace = three_var_approx.sample(100, return_inferencedata=False) assert set(trace.varnames) == {"one", "one_log__", "three", "two"} assert len(trace) == 100 assert trace[0]["one"].shape == (10, 2) @@ -246,39 +253,42 @@ def parametric_grouped_approxes(request): def test_logq_mini_1_sample_1_var(parametric_grouped_approxes, three_var_model): cls, kw = parametric_grouped_approxes - approx = cls([three_var_model.one], model=three_var_model, **kw) - logq = approx.logq - logq = approx.set_size_and_deterministic(logq, 1, 0) - logq.eval() + with three_var_model: + approx = cls([three_var_model.one], model=three_var_model, **kw) + logq = approx.logq + logq = approx.set_size_and_deterministic(logq, 1, 0) + logq.eval() def test_logq_mini_2_sample_2_var(parametric_grouped_approxes, three_var_model): cls, kw = parametric_grouped_approxes - approx = cls([three_var_model.one, three_var_model.two], model=three_var_model, **kw) - logq = approx.logq - logq = approx.set_size_and_deterministic(logq, 2, 0) - logq.eval() + with three_var_model: + approx = cls([three_var_model.one, three_var_model.two], model=three_var_model, **kw) + logq = approx.logq + logq = approx.set_size_and_deterministic(logq, 2, 0) + logq.eval() -def test_logq_globals(three_var_approx): +def test_logq_globals(three_var_approx, three_var_model): if not three_var_approx.has_logq: pytest.skip(f"{three_var_approx} does not implement logq") - approx = three_var_approx - logq, symbolic_logq = approx.set_size_and_deterministic( - [approx.logq, approx.symbolic_logq], 1, 0 - ) - e = logq.eval() - es = symbolic_logq.eval() - assert e.shape == () - assert es.shape == (1,) - - logq, symbolic_logq = approx.set_size_and_deterministic( - [approx.logq, approx.symbolic_logq], 2, 0 - ) - e = logq.eval() - es = symbolic_logq.eval() - assert e.shape == () - assert es.shape == (2,) + with three_var_model: + approx = three_var_approx + logq, symbolic_logq = approx.set_size_and_deterministic( + [approx.logq, approx.symbolic_logq], 1, 0 + ) + e = logq.eval() + es = symbolic_logq.eval() + assert e.shape == () + assert es.shape == (1,) + + logq, symbolic_logq = approx.set_size_and_deterministic( + [approx.logq, approx.symbolic_logq], 2, 0 + ) + e = logq.eval() + es = symbolic_logq.eval() + assert e.shape == () + assert es.shape == (2,) def test_symbolic_normalizing_constant_no_rvs(): @@ -292,5 +302,8 @@ def test_symbolic_normalizing_constant_no_rvs(): y_hat = pm.Flat("y_hat", observed=obs_batch, total_size=1000) step = pm.ADVI() + # Access property within model context + symbolic_normalizing = step.approx.symbolic_normalizing_constant - assert_no_rvs(step.approx.symbolic_normalizing_constant) + # Access the property again to test it doesn't require model context after first access + assert_no_rvs(symbolic_normalizing) From 41cfc20c0a209e436c831ae8ad502a4962bb9dfc Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Fri, 31 Oct 2025 17:42:54 +0100 Subject: [PATCH 05/11] Cleanup --- pymc/variational/opvi.py | 511 ++++++++++++++------------------------- 1 file changed, 179 
insertions(+), 332 deletions(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 78c6b3ee10..74a94cc61a 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -52,6 +52,7 @@ import itertools import warnings +from dataclasses import dataclass from typing import Any, overload import numpy as np @@ -1377,6 +1378,16 @@ def _rebuild_group_mappings(group, model): group.ordering = ordering +@dataclass +class TraceSpec: + sample_vars: list + test_point: collections.OrderedDict + computed_var_names: list[str] + input_vars: list + compute_fn: Any + value_var_names: list[str] + + class Approximation(WithMemoization): """**Wrapper for grouped approximations**. @@ -1467,6 +1478,151 @@ def collect(self, item): self._ensure_groups_ready() return [getattr(g, item) for g in self.groups] + def _variational_orderings(self, model): + orderings = collections.OrderedDict() + for g in self.groups: + mapped = _refresh_group_for_model(g, model) + if mapped: + orderings.update(g.ordering) + return orderings + + def _draw_variational_samples(self, model, names, draws, size_sym, random_seed): + if not names: + return {} + tensors = [self.rslice(name, model) for name in names] + tensors = self.set_size_and_deterministic(tensors, size_sym, 0) + sample_fn = compile([size_sym], tensors) + rng_nodes = find_rng_nodes(tensors) + if random_seed is not None: + reseed_rngs(rng_nodes, random_seed) + outputs = sample_fn(draws) + if not isinstance(outputs, list | tuple): + outputs = [outputs] + return dict(zip(names, outputs)) + + def _draw_forward_samples(self, model, approx_samples, approx_names, draws, random_seed): + from pymc.sampling.forward import compile_forward_sampling_function + + model_names = {model.rvs_to_values[v].name: v for v in model.free_RVs} + forward_names = sorted(name for name in model_names if name not in approx_names) + if not forward_names: + return {} + + forward_vars = [model_names[name] for name in forward_names] + approx_vars = [model_names[name] for name in approx_names if name in model_names] + sampler_fn, _ = compile_forward_sampling_function( + outputs=forward_vars, + vars_in_trace=approx_vars, + basic_rvs=model.basic_RVs, + givens_dict=None, + random_seed=random_seed, + ) + approx_value_vars = [model.rvs_to_values[var] for var in approx_vars] + stacked = {name: [] for name in forward_names} + for i in range(draws): + inputs = { + value_var.name: approx_samples[value_var.name][i] for value_var in approx_value_vars + } + raw = sampler_fn(**inputs) + if not isinstance(raw, list | tuple): + raw = [raw] + for name, value in zip(forward_names, raw): + stacked[name].append(value) + return {name: np.stack(values) for name, values in stacked.items()} + + def _collect_sample_vars(self, model, sample_names): + lookup = {} + for var in model.unobserved_value_vars: + lookup.setdefault(var.name, var) + for name, var in model.named_vars.items(): + lookup.setdefault(name, var) + sample_vars = [lookup[name] for name in sample_names if name in lookup] + seen = {var.name for var in sample_vars} + for var in model.unobserved_value_vars: + if var.name not in seen: + sample_vars.append(var) + return sample_vars, lookup + + def _compute_missing_trace_values(self, model, samples, missing_vars): + if not missing_vars: + return {}, [], [], None + input_vars = model.value_vars + base_point = model.initial_point() + point = { + var.name: np.asarray(samples[var.name][0]) + if var.name in samples + else base_point[var.name] + for var in input_vars + if var.name in samples or var.name in 
base_point + } + compute_fn = model.compile_fn( + missing_vars, + inputs=input_vars, + on_unused_input="ignore", + point_fn=True, + ) + raw_values = compute_fn(point) + if not isinstance(raw_values, list | tuple): + raw_values = [raw_values] + values = {var.name: np.asarray(value) for var, value in zip(missing_vars, raw_values)} + return values, [var.name for var in missing_vars], list(input_vars), compute_fn + + def _build_trace_spec(self, model, samples): + sample_names = sorted(samples.keys()) + sample_vars, _ = self._collect_sample_vars(model, sample_names) + initial_point = model.initial_point() + test_point = collections.OrderedDict() + missing_vars = [] + + for var in sample_vars: + trace_name = var.name + if trace_name in samples: + first_sample = np.asarray(samples[trace_name][0]) + test_point[trace_name] = first_sample + continue + if trace_name in initial_point: + value = np.asarray(initial_point[trace_name]) + test_point[trace_name] = value + continue + missing_vars.append(var) + + values, computed_var_names, input_vars, compute_fn = self._compute_missing_trace_values( + model, samples, missing_vars + ) + for name, value in values.items(): + test_point[name] = value + + return TraceSpec( + sample_vars=sample_vars, + test_point=test_point, + computed_var_names=computed_var_names, + input_vars=input_vars, + compute_fn=compute_fn, + value_var_names=[var.name for var in model.value_vars], + ) + + def _augment_samples_with_computed(self, model, samples, spec, draws): + if not spec.computed_var_names: + return + + computed = {name: [] for name in spec.computed_var_names} + input_names = [var.name for var in spec.input_vars] + for i in range(draws): + inputs = {} + for name in input_names: + if name in samples: + inputs[name] = samples[name][i] + else: + inputs[name] = spec.test_point[name] + outputs = spec.compute_fn(inputs) + if not isinstance(outputs, list | tuple): + outputs = [outputs] + for name, value in zip(spec.computed_var_names, outputs): + computed[name].append(np.asarray(value)) + + for name, values in computed.items(): + samples[name] = np.stack(values) + inputs = property(lambda self: self.collect("input")) symbolic_randoms = property(lambda self: self.collect("symbolic_random")) @@ -1754,80 +1910,16 @@ def sample_dict_fn(self): s = pt.iscalar() def inner(draws=100, *, model=None, random_seed: SeedSequenceSeed = None): - from pymc.sampling.forward import compile_forward_sampling_function - model = modelcontext(model) - - # Get all variable names that exist in the approximation - approx_names = set() - for ordering in self.collect("ordering"): - approx_names.update(ordering.keys()) - - # Get all variable names from the model - model_names = {model.rvs_to_values[v].name: v for v in model.free_RVs} - - # Separate variables into those in approximation and those not - approx_var_names = sorted(approx_names & set(model_names.keys())) - forward_var_names = sorted(set(model_names.keys()) - approx_names) - - # Sample variables from approximation - all_samples = {} - if approx_var_names: - sampled = [self.rslice(name, model) for name in approx_var_names] - sampled = self.set_size_and_deterministic(sampled, s, 0) - sample_fn = compile([s], sampled) - rng_nodes = find_rng_nodes(sampled) - - if random_seed is not None: - reseed_rngs(rng_nodes, random_seed) - _samples = sample_fn(draws) - all_samples.update(dict(zip(approx_var_names, _samples))) - - # Forward sample variables not in approximation, conditioned on approximation variables - if forward_var_names: - forward_vars = 
[model_names[name] for name in forward_var_names] - approx_vars = [model_names[name] for name in approx_var_names] - - # Compile forward sampling function - # Variables in vars_in_trace become inputs to the function - # so we can pass the sampled approximation values as inputs - sampler_fn, _ = compile_forward_sampling_function( - outputs=forward_vars, - vars_in_trace=approx_vars, - basic_rvs=model.basic_RVs, - givens_dict=None, - random_seed=random_seed, - ) - - # Get the value variables that will be inputs - approx_value_vars = [model.rvs_to_values[var] for var in approx_vars] - - # Forward sample for each draw, conditioned on approximation samples - forward_samples_list = [] - for i in range(draws): - # Create inputs dict with sampled approximation values for this draw - # Use variable names as keys since the compiled function expects string keywords - inputs_dict = {} - for var, value_var in zip(approx_vars, approx_value_vars): - sampled_value = all_samples[value_var.name][i] - inputs_dict[value_var.name] = sampled_value - - # Forward sample with these fixed values - forward_samples = sampler_fn(**inputs_dict) - if isinstance(forward_samples, list): - forward_samples_list.append(forward_samples) - else: - forward_samples_list.append([forward_samples]) - - # Stack results - if forward_samples_list: - if isinstance(forward_samples_list[0], list): - for j, name in enumerate(forward_var_names): - all_samples[name] = np.stack([draw[j] for draw in forward_samples_list]) - else: - all_samples[forward_var_names[0]] = np.stack(forward_samples_list) - - return all_samples + orderings = self._variational_orderings(model) + approx_var_names = sorted(orderings.keys()) + approx_samples = self._draw_variational_samples( + model, approx_var_names, draws, s, random_seed + ) + forward_samples = self._draw_forward_samples( + model, approx_samples, approx_var_names, draws, random_seed + ) + return {**approx_samples, **forward_samples} return inner @@ -1858,7 +1950,6 @@ def sample( trace: :class:`pymc.backends.base.MultiTrace` Samples drawn from variational posterior. 
""" - # TODO: add tests for include_transformed case kwargs["log_likelihood"] = False model = modelcontext(model) @@ -1866,273 +1957,29 @@ def sample( if random_seed is not None: (random_seed,) = _get_seeds_per_chain(random_seed, 1) samples: dict = self.sample_dict_fn(draws, model=model, random_seed=random_seed) + spec = self._build_trace_spec(model, samples) + self._augment_samples_with_computed(model, samples, spec, draws) + if spec.computed_var_names: + spec = self._build_trace_spec(model, samples) - # Get the variables that correspond to our samples - # We need to find all variables in the model that match our sample names - # This includes both transformed (e.g., 'one_log__') and untransformed (e.g., 'one') variables - sample_names = sorted(samples.keys()) # Use sorted order for consistency - - # Build lookup from unobserved_value_vars first (prioritize these) - # This includes both transformed and untransformed variables - var_name_to_var = {} - for var in model.unobserved_value_vars: - if var.name not in var_name_to_var: - var_name_to_var[var.name] = var - - # Also include named_vars (for variables like 'one' that might not be in unobserved_value_vars) - for name, var in model.named_vars.items(): - if name not in var_name_to_var: - var_name_to_var[name] = var - - # Collect variables in the order matching sample_names, plus untransformed versions - # Use model.unobserved_value_vars as source of truth - it includes both transformed - # and untransformed variables - sample_vars = [] - sample_var_names = set() # Track which var names we've added - - # First, add variables that are in samples (in sample_names order) - for name in sample_names: - if name in var_name_to_var: - var = var_name_to_var[name] - sample_vars.append(var) - sample_var_names.add(var.name) - - # Then, add any variables from model.unobserved_value_vars that aren't already included - # This ensures we include both transformed (e.g., 'one_log__') and untransformed (e.g., 'one') - for var in model.unobserved_value_vars: - if var.name not in sample_var_names: - sample_vars.append(var) - sample_var_names.add(var.name) - - # Create points as OrderedDict to preserve order matching sample_names - # This ensures fn(*point.values()) gets values in the correct order from collections import OrderedDict + default_point = model.initial_point() points = ( - OrderedDict((name, np.asarray(samples[name][i])) for name in sample_names) + OrderedDict( + ( + name, + np.asarray(samples[name][i]) + if name in samples and len(samples[name]) > i + else np.asarray(spec.test_point.get(name, default_point[name])), + ) + for name in spec.value_var_names + ) for i in range(draws) ) - # Create test_point and var_shapes using the actual sample shapes to ensure trace setup matches - # Key var_shapes by trace variable names (var.name from sample_vars), not sample_names - test_point = OrderedDict() - var_shapes = {} - var_dtypes = {} - trace_varnames = [var.name for var in sample_vars] - - # Create mapping from trace variable names to sample names - # This handles cases where sample_names might differ from trace variable names - trace_to_sample = {} - sample_to_trace = {} # Reverse mapping - for i, var in enumerate(sample_vars): - trace_name = var.name - # Try to find matching sample by name first - if trace_name in sample_names and trace_name in samples: - sample_name = trace_name - elif i < len(sample_names): - # Fall back to index alignment - sample_name = sample_names[i] - else: - continue - trace_to_sample[trace_name] = sample_name - 
sample_to_trace[sample_name] = trace_name - - # Build test_point, var_shapes, and var_dtypes using actual sample shapes - # Use trace variable names as keys (what trace expects) - # CRITICAL: var_shapes must use trace variable names (var.name) as keys - # and shapes must match the actual samples, not model variable shapes - # For untransformed variables not in samples, we'll compute them from transformed ones - initial_point = model.initial_point() # Get once to compute untransformed vars - - for var in sample_vars: - trace_name = var.name - if trace_name in trace_to_sample: - # Variable is in samples, use the sample value - sample_name = trace_to_sample[trace_name] - if sample_name in samples: - first_sample = np.asarray(samples[sample_name][0]) - test_point[trace_name] = first_sample - var_shapes[trace_name] = first_sample.shape - var_dtypes[trace_name] = first_sample.dtype - else: - # Shouldn't happen, but use initial_point as fallback - if trace_name in initial_point: - test_point[trace_name] = initial_point[trace_name] - var_shapes[trace_name] = initial_point[trace_name].shape - var_dtypes[trace_name] = initial_point[trace_name].dtype - elif trace_name in samples: - # Direct match in samples (shouldn't happen if trace_to_sample is correct) - first_sample = np.asarray(samples[trace_name][0]) - test_point[trace_name] = first_sample - var_shapes[trace_name] = first_sample.shape - var_dtypes[trace_name] = first_sample.dtype - else: - # Variable not in samples - it's an untransformed variable we need to compute - # Use model.initial_point to get the shape (it includes both transformed and untransformed) - if trace_name in initial_point: - test_point[trace_name] = initial_point[trace_name] - var_shapes[trace_name] = initial_point[trace_name].shape - var_dtypes[trace_name] = initial_point[trace_name].dtype - else: - # Variable not in initial_point either - compute shape from model - # We need to compute this variable from transformed variables - # Use a test computation to get the shape - try: - # Try to compute the variable using model.compile_fn to get its shape - # Build test point with transformed variables from samples - test_point_dict = {} - for v in model.value_vars: - if v.name in samples: - test_point_dict[v.name] = samples[v.name][0] - - test_compute_fn = model.compile_fn( - [var], inputs=model.value_vars, on_unused_input="ignore", point_fn=True - ) - # Get shape from the computed value - test_result = test_compute_fn(test_point_dict) - if isinstance(test_result, list | tuple): - test_value = test_result[0] if len(test_result) > 0 else None - else: - test_value = test_result - - if test_value is not None: - test_point[trace_name] = test_value - var_shapes[trace_name] = test_value.shape - var_dtypes[trace_name] = test_value.dtype - else: - raise ValueError(f"Could not compute shape for {trace_name}") - except Exception as e: - # If computation fails, try to get shape from initial_point computation - # Compute initial_point again with all variables - try: - full_initial_point = model.initial_point() - if trace_name in full_initial_point: - test_point[trace_name] = full_initial_point[trace_name] - var_shapes[trace_name] = full_initial_point[trace_name].shape - var_dtypes[trace_name] = full_initial_point[trace_name].dtype - else: - # Last resort: skip if we truly can't determine shape - # But this shouldn't happen for variables in model.unobserved_value_vars - raise ValueError( - f"Could not determine shape for {trace_name}. " - f"Variable should be in model.unobserved_value_vars. 
" - f"Initial point keys: {list(full_initial_point.keys())}" - ) from e - except Exception as e2: - raise ValueError( - f"Could not determine shape for {trace_name}. " - f"Variable is in sample_vars but not in samples or initial_point." - ) from e2 - - # Create a custom fn that returns values in the order matching trace.varnames - # trace.record calls fn(*point.values()), so we need to return values in trace.varnames order - # point.values() is in sample_names order, but trace.varnames is trace_varnames - # We need to reorder values from sample_names order to trace_varnames order - # For untransformed variables not in samples, we need to compute them from transformed ones - vars_to_compute = [ - var - for var in sample_vars - if var.name not in samples and var.name not in trace_to_sample - ] - - if vars_to_compute: - # We have untransformed variables to compute - need a proper fn that transforms them - # Use model's compile_fn to compute untransformed variables from transformed ones - computed_var_names = [var.name for var in vars_to_compute] - # Get the transformed variables that are inputs (value_vars) - input_vars = model.value_vars - # Compile a function to compute untransformed variables - # Use point_fn=True to get a PointFunc that accepts a dict - compute_fn = model.compile_fn( - vars_to_compute, - inputs=input_vars, - on_unused_input="ignore", - point_fn=True, - ) - else: - compute_fn = None - computed_var_names = [] - input_vars = [] - - def identity_fn(*args): - """Return values reordered to match trace.varnames.""" - # args comes from point.values() in sample_names order - # Build mapping from sample_names to values - value_dict = dict(zip(sample_names, args)) - - # Build input dict for compute_fn if needed - if compute_fn is not None: - # Map sample names to value_vars for compute_fn - # compute_fn expects a dict (point_fn=True) - compute_input_dict = {} - for var in input_vars: - if var.name in value_dict: - compute_input_dict[var.name] = value_dict[var.name] - - # Call with dict (PointFunc expects a dict) - # PointFunc returns a list/tuple of values in the order of outputs - computed_values = compute_fn(compute_input_dict) - if isinstance(computed_values, list | tuple): - # computed_values is in the order of vars_to_compute - if len(computed_values) != len(computed_var_names): - raise ValueError( - f"Mismatch: compute_fn returned {len(computed_values)} values, " - f"expected {len(computed_var_names)}. " - f"computed_var_names: {computed_var_names}" - ) - computed_dict = dict(zip(computed_var_names, computed_values)) - else: - # Single output case - if len(computed_var_names) != 1: - raise ValueError( - f"compute_fn returned single value but expected {len(computed_var_names)} values. " - f"computed_var_names: {computed_var_names}" - ) - computed_dict = {computed_var_names[0]: computed_values} - value_dict.update(computed_dict) - - # Return values in trace_varnames order, mapping trace names to sample names - # CRITICAL: We must return exactly len(trace_varnames) values, one for each variable - result = [] - for trace_name in trace_varnames: - if trace_name in trace_to_sample: - # Variable is mapped to a sample name - sample_name = trace_to_sample[trace_name] - if sample_name in value_dict: - result.append(value_dict[sample_name]) - else: - # Sample name not in value_dict - shouldn't happen - raise ValueError( - f"Sample name '{sample_name}' for trace variable '{trace_name}' not found in value_dict. 
" - f"Available keys: {list(value_dict.keys())}" - ) - elif trace_name in value_dict: - # Direct match (e.g., computed untransformed variable) - result.append(value_dict[trace_name]) - else: - # Variable not in value_dict - missing value! - raise ValueError( - f"Trace variable '{trace_name}' not found in value_dict. " - f"Available keys: {list(value_dict.keys())}, " - f"trace_varnames: {trace_varnames}, " - f"vars_to_compute: {computed_var_names if compute_fn is not None else []}" - ) - - # Ensure we return exactly len(trace_varnames) values - if len(result) != len(trace_varnames): - raise ValueError( - f"Mismatch in result length: got {len(result)}, expected {len(trace_varnames)}. " - f"trace_varnames: {trace_varnames}, result length: {len(result)}" - ) - return tuple(result) - trace = NDArray( model=model, - vars=sample_vars, - test_point=test_point, - fn=identity_fn, - var_shapes=var_shapes, - var_dtypes=var_dtypes, ) try: trace.setup(draws=draws, chain=0) From 2d052036b10078c7a5fd139ea8ca7c6a0a6677db Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Fri, 31 Oct 2025 18:10:57 +0100 Subject: [PATCH 06/11] Cleanup --- pymc/variational/opvi.py | 44 +++++----------------------------------- 1 file changed, 5 insertions(+), 39 deletions(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 74a94cc61a..f2d71684b0 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1382,10 +1382,6 @@ def _rebuild_group_mappings(group, model): class TraceSpec: sample_vars: list test_point: collections.OrderedDict - computed_var_names: list[str] - input_vars: list - compute_fn: Any - value_var_names: list[str] class Approximation(WithMemoization): @@ -1545,7 +1541,7 @@ def _collect_sample_vars(self, model, sample_names): def _compute_missing_trace_values(self, model, samples, missing_vars): if not missing_vars: - return {}, [], [], None + return {} input_vars = model.value_vars base_point = model.initial_point() point = { @@ -1565,7 +1561,7 @@ def _compute_missing_trace_values(self, model, samples, missing_vars): if not isinstance(raw_values, list | tuple): raw_values = [raw_values] values = {var.name: np.asarray(value) for var, value in zip(missing_vars, raw_values)} - return values, [var.name for var in missing_vars], list(input_vars), compute_fn + return values def _build_trace_spec(self, model, samples): sample_names = sorted(samples.keys()) @@ -1586,43 +1582,15 @@ def _build_trace_spec(self, model, samples): continue missing_vars.append(var) - values, computed_var_names, input_vars, compute_fn = self._compute_missing_trace_values( - model, samples, missing_vars - ) + values = self._compute_missing_trace_values(model, samples, missing_vars) for name, value in values.items(): test_point[name] = value return TraceSpec( sample_vars=sample_vars, test_point=test_point, - computed_var_names=computed_var_names, - input_vars=input_vars, - compute_fn=compute_fn, - value_var_names=[var.name for var in model.value_vars], ) - def _augment_samples_with_computed(self, model, samples, spec, draws): - if not spec.computed_var_names: - return - - computed = {name: [] for name in spec.computed_var_names} - input_names = [var.name for var in spec.input_vars] - for i in range(draws): - inputs = {} - for name in input_names: - if name in samples: - inputs[name] = samples[name][i] - else: - inputs[name] = spec.test_point[name] - outputs = spec.compute_fn(inputs) - if not isinstance(outputs, list | tuple): - outputs = [outputs] - for name, value in zip(spec.computed_var_names, outputs): - 
computed[name].append(np.asarray(value)) - - for name, values in computed.items(): - samples[name] = np.stack(values) - inputs = property(lambda self: self.collect("input")) symbolic_randoms = property(lambda self: self.collect("symbolic_random")) @@ -1958,13 +1926,11 @@ def sample( (random_seed,) = _get_seeds_per_chain(random_seed, 1) samples: dict = self.sample_dict_fn(draws, model=model, random_seed=random_seed) spec = self._build_trace_spec(model, samples) - self._augment_samples_with_computed(model, samples, spec, draws) - if spec.computed_var_names: - spec = self._build_trace_spec(model, samples) from collections import OrderedDict default_point = model.initial_point() + value_var_names = [var.name for var in model.value_vars] points = ( OrderedDict( ( @@ -1973,7 +1939,7 @@ def sample( if name in samples and len(samples[name]) > i else np.asarray(spec.test_point.get(name, default_point[name])), ) - for name in spec.value_var_names + for name in value_var_names ) for i in range(draws) ) From 5c15eb59953a29fba302155b329a7f532ee72bb2 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Fri, 31 Oct 2025 18:55:28 +0100 Subject: [PATCH 07/11] Use model context everywhere --- pymc/variational/opvi.py | 298 ++++++++++++----------- tests/variational/test_approximations.py | 10 +- tests/variational/test_opvi.py | 4 +- 3 files changed, 161 insertions(+), 151 deletions(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index f2d71684b0..7d88f7aaa5 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1428,26 +1428,27 @@ def __init__(self, groups, model=None): self.groups = [] seen = set() rest = None - for g in groups: - if g.group is None: - if rest is not None: - raise GroupError("More than one group is specified for the rest variables") + with model: + for g in groups: + if g.group is None: + if rest is not None: + raise GroupError("More than one group is specified for the rest variables") + else: + rest = g else: - rest = g - else: - final_group = _refresh_group_for_model(g, model) - if set(final_group) & seen: - raise GroupError("Found duplicates in groups") - seen.update(final_group) - self.groups.append(g) - # List iteration to preserve order for reproducibility between runs - unseen_free_RVs = [var for var in model.free_RVs if var not in seen] - if unseen_free_RVs: - if rest is None: - raise GroupError("No approximation is specified for the rest variables") - else: - _refresh_group_for_model(rest, model, unseen_free_RVs) - self.groups.append(rest) + final_group = _refresh_group_for_model(g, model) + if set(final_group) & seen: + raise GroupError("Found duplicates in groups") + seen.update(final_group) + self.groups.append(g) + # List iteration to preserve order for reproducibility between runs + unseen_free_RVs = [var for var in model.free_RVs if var not in seen] + if unseen_free_RVs: + if rest is None: + raise GroupError("No approximation is specified for the rest variables") + else: + rest.__init_group__(unseen_free_RVs) + self.groups.append(rest) @property def has_logq(self): @@ -1467,11 +1468,13 @@ def _ensure_groups_ready(self, model=None): model = modelcontext(model) except TypeError: return - for g in self.groups: - _refresh_group_for_model(g, model) + with model: + for g in self.groups: + _refresh_group_for_model(g, model) def collect(self, item): - self._ensure_groups_ready() + model = modelcontext(None) + self._ensure_groups_ready(model=model) return [getattr(g, item) for g in self.groups] def _variational_orderings(self, model): @@ -1483,48 
+1486,51 @@ def _variational_orderings(self, model): return orderings def _draw_variational_samples(self, model, names, draws, size_sym, random_seed): - if not names: - return {} - tensors = [self.rslice(name, model) for name in names] - tensors = self.set_size_and_deterministic(tensors, size_sym, 0) - sample_fn = compile([size_sym], tensors) - rng_nodes = find_rng_nodes(tensors) - if random_seed is not None: - reseed_rngs(rng_nodes, random_seed) - outputs = sample_fn(draws) - if not isinstance(outputs, list | tuple): - outputs = [outputs] - return dict(zip(names, outputs)) + with model: + if not names: + return {} + tensors = [self.rslice(name, model) for name in names] + tensors = self.set_size_and_deterministic(tensors, size_sym, 0) + sample_fn = compile([size_sym], tensors) + rng_nodes = find_rng_nodes(tensors) + if random_seed is not None: + reseed_rngs(rng_nodes, random_seed) + outputs = sample_fn(draws) + if not isinstance(outputs, list | tuple): + outputs = [outputs] + return dict(zip(names, outputs)) def _draw_forward_samples(self, model, approx_samples, approx_names, draws, random_seed): from pymc.sampling.forward import compile_forward_sampling_function - model_names = {model.rvs_to_values[v].name: v for v in model.free_RVs} - forward_names = sorted(name for name in model_names if name not in approx_names) - if not forward_names: - return {} - - forward_vars = [model_names[name] for name in forward_names] - approx_vars = [model_names[name] for name in approx_names if name in model_names] - sampler_fn, _ = compile_forward_sampling_function( - outputs=forward_vars, - vars_in_trace=approx_vars, - basic_rvs=model.basic_RVs, - givens_dict=None, - random_seed=random_seed, - ) - approx_value_vars = [model.rvs_to_values[var] for var in approx_vars] - stacked = {name: [] for name in forward_names} - for i in range(draws): - inputs = { - value_var.name: approx_samples[value_var.name][i] for value_var in approx_value_vars - } - raw = sampler_fn(**inputs) - if not isinstance(raw, list | tuple): - raw = [raw] - for name, value in zip(forward_names, raw): - stacked[name].append(value) - return {name: np.stack(values) for name, values in stacked.items()} + with model: + model_names = {model.rvs_to_values[v].name: v for v in model.free_RVs} + forward_names = sorted(name for name in model_names if name not in approx_names) + if not forward_names: + return {} + + forward_vars = [model_names[name] for name in forward_names] + approx_vars = [model_names[name] for name in approx_names if name in model_names] + sampler_fn, _ = compile_forward_sampling_function( + outputs=forward_vars, + vars_in_trace=approx_vars, + basic_rvs=model.basic_RVs, + givens_dict=None, + random_seed=random_seed, + ) + approx_value_vars = [model.rvs_to_values[var] for var in approx_vars] + stacked = {name: [] for name in forward_names} + for i in range(draws): + inputs = { + value_var.name: approx_samples[value_var.name][i] + for value_var in approx_value_vars + } + raw = sampler_fn(**inputs) + if not isinstance(raw, list | tuple): + raw = [raw] + for name, value in zip(forward_names, raw): + stacked[name].append(value) + return {name: np.stack(values) for name, values in stacked.items()} def _collect_sample_vars(self, model, sample_names): lookup = {} @@ -1540,28 +1546,29 @@ def _collect_sample_vars(self, model, sample_names): return sample_vars, lookup def _compute_missing_trace_values(self, model, samples, missing_vars): - if not missing_vars: - return {} - input_vars = model.value_vars - base_point = model.initial_point() 
- point = { - var.name: np.asarray(samples[var.name][0]) - if var.name in samples - else base_point[var.name] - for var in input_vars - if var.name in samples or var.name in base_point - } - compute_fn = model.compile_fn( - missing_vars, - inputs=input_vars, - on_unused_input="ignore", - point_fn=True, - ) - raw_values = compute_fn(point) - if not isinstance(raw_values, list | tuple): - raw_values = [raw_values] - values = {var.name: np.asarray(value) for var, value in zip(missing_vars, raw_values)} - return values + with model: + if not missing_vars: + return {} + input_vars = model.value_vars + base_point = model.initial_point() + point = { + var.name: np.asarray(samples[var.name][0]) + if var.name in samples + else base_point[var.name] + for var in input_vars + if var.name in samples or var.name in base_point + } + compute_fn = model.compile_fn( + missing_vars, + inputs=input_vars, + on_unused_input="ignore", + point_fn=True, + ) + raw_values = compute_fn(point) + if not isinstance(raw_values, list | tuple): + raw_values = [raw_values] + values = {var.name: np.asarray(value) for var, value in zip(missing_vars, raw_values)} + return values def _build_trace_spec(self, model, samples): sample_names = sorted(samples.keys()) @@ -1819,7 +1826,7 @@ def get_optimization_replacements(self, s, d): return repl @pytensor.config.change_flags(compute_test_value="off") - def sample_node(self, node, size=None, deterministic=False, more_replacements=None): + def sample_node(self, node, size=None, deterministic=False, more_replacements=None, model=None): """Sample given node or nodes over shared posterior. Parameters @@ -1839,22 +1846,22 @@ def sample_node(self, node, size=None, deterministic=False, more_replacements=No """ node_in = node - model = modelcontext(None) - - if more_replacements: - node = graph_replace(node, more_replacements, strict=False) - if not isinstance(node, list | tuple): - node = [node] - node = model.replace_rvs_by_values(node) - if not isinstance(node_in, list | tuple): - node = node[0] - if size is None: - node_out = self.symbolic_single_sample(node) - else: - node_out = self.symbolic_sample_over_posterior(node) - node_out = self.set_size_and_deterministic(node_out, size, deterministic) - try_to_set_test_value(node_in, node_out, size) - return node_out + model = modelcontext(model) + with model: + if more_replacements: + node = graph_replace(node, more_replacements, strict=False) + if not isinstance(node, list | tuple): + node = [node] + node = model.replace_rvs_by_values(node) + if not isinstance(node_in, list | tuple): + node = node[0] + if size is None: + node_out = self.symbolic_single_sample(node) + else: + node_out = self.symbolic_sample_over_posterior(node) + node_out = self.set_size_and_deterministic(node_out, size, deterministic) + try_to_set_test_value(node_in, node_out, size) + return node_out def rslice(self, name, model=None): """*Dev* - vectorized sampling for named random variable without call to `pytensor.scan`. 
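Editor's note (usage sketch, not part of the patch series): with `sample_node` now resolving its model through `modelcontext`, callers wrap the call in a `with model:` block (or pass `model=` explicitly) rather than relying on the deprecated `approx.model` attribute. The model, variable names, and fitting call below are illustrative assumptions that mirror the updated tests, not code from these patches.

    import pymc as pm

    with pm.Model() as model:
        p = pm.Beta("p", 1.0, 1.0)
        pm.Binomial("y", n=10, p=p, observed=5)
        approx = pm.fit(n=100, method="advi")  # mean-field approximation

    with model:
        # Vectorized posterior draws of a derived quantity; the model is taken
        # from the context stack instead of from approx.model.
        draws = approx.sample_node(p**3, size=100).eval()

    assert draws.shape == (100,)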
@@ -1863,14 +1870,15 @@ def rslice(self, name, model=None): """ model = modelcontext(model) - for random, ordering in zip(self.symbolic_randoms, self.collect("ordering")): - if name in ordering: - _name, slc, shape, dtype = ordering[name] - found = random[..., slc].reshape((random.shape[0], *shape)).astype(dtype) - found.name = name + "_vi_random_slice" - break - else: - raise KeyError(f"{name!r} not found") + with model: + for random, ordering in zip(self.symbolic_randoms, self.collect("ordering")): + if name in ordering: + _name, slc, shape, dtype = ordering[name] + found = random[..., slc].reshape((random.shape[0], *shape)).astype(dtype) + found.name = name + "_vi_random_slice" + break + else: + raise KeyError(f"{name!r} not found") return found @node_property @@ -1879,15 +1887,16 @@ def sample_dict_fn(self): def inner(draws=100, *, model=None, random_seed: SeedSequenceSeed = None): model = modelcontext(model) - orderings = self._variational_orderings(model) - approx_var_names = sorted(orderings.keys()) - approx_samples = self._draw_variational_samples( - model, approx_var_names, draws, s, random_seed - ) - forward_samples = self._draw_forward_samples( - model, approx_samples, approx_var_names, draws, random_seed - ) - return {**approx_samples, **forward_samples} + with model: + orderings = self._variational_orderings(model) + approx_var_names = sorted(orderings.keys()) + approx_samples = self._draw_variational_samples( + model, approx_var_names, draws, s, random_seed + ) + forward_samples = self._draw_forward_samples( + model, approx_samples, approx_var_names, draws, random_seed + ) + return {**approx_samples, **forward_samples} return inner @@ -1922,37 +1931,38 @@ def sample( model = modelcontext(model) - if random_seed is not None: - (random_seed,) = _get_seeds_per_chain(random_seed, 1) - samples: dict = self.sample_dict_fn(draws, model=model, random_seed=random_seed) - spec = self._build_trace_spec(model, samples) - - from collections import OrderedDict - - default_point = model.initial_point() - value_var_names = [var.name for var in model.value_vars] - points = ( - OrderedDict( - ( - name, - np.asarray(samples[name][i]) - if name in samples and len(samples[name]) > i - else np.asarray(spec.test_point.get(name, default_point[name])), + with model: + if random_seed is not None: + (random_seed,) = _get_seeds_per_chain(random_seed, 1) + samples: dict = self.sample_dict_fn(draws, model=model, random_seed=random_seed) + spec = self._build_trace_spec(model, samples) + + from collections import OrderedDict + + default_point = model.initial_point() + value_var_names = [var.name for var in model.value_vars] + points = ( + OrderedDict( + ( + name, + np.asarray(samples[name][i]) + if name in samples and len(samples[name]) > i + else np.asarray(spec.test_point.get(name, default_point[name])), + ) + for name in value_var_names ) - for name in value_var_names + for i in range(draws) ) - for i in range(draws) - ) - trace = NDArray( - model=model, - ) - try: - trace.setup(draws=draws, chain=0) - for point in points: - trace.record(point) - finally: - trace.close() + trace = NDArray( + model=model, + ) + try: + trace.setup(draws=draws, chain=0) + for point in points: + trace.record(point) + finally: + trace.close() multi_trace = MultiTrace([trace]) if not return_inferencedata: diff --git a/tests/variational/test_approximations.py b/tests/variational/test_approximations.py index ab30e9bbe3..1bb983ba6b 100644 --- a/tests/variational/test_approximations.py +++ 
b/tests/variational/test_approximations.py @@ -55,8 +55,9 @@ def test_elbo(): # Create variational gradient tensor mean_field = MeanField(model=model) - with pytensor.config.change_flags(compute_test_value="off"): - elbo = -pm.operators.KL(mean_field)()(10000) + with model: + with pytensor.config.change_flags(compute_test_value="off"): + elbo = -pm.operators.KL(mean_field)()(10000) mean_field.shared_params["mu"].set_value(post_mu) mean_field.shared_params["rho"].set_value(np.log(np.exp(post_sigma) - 1)) @@ -113,9 +114,8 @@ def test_scale_cost_to_minibatch_works(aux_total_size): assert not mean_field_2.scale_cost_to_minibatch mean_field_2.shared_params["mu"].set_value(post_mu) mean_field_2.shared_params["rho"].set_value(np.log(np.exp(post_sigma) - 1)) - - with pytensor.config.change_flags(compute_test_value="off"): - elbo_via_total_size_unscaled = -pm.operators.KL(mean_field_2)()(10000) + with pytensor.config.change_flags(compute_test_value="off"): + elbo_via_total_size_unscaled = -pm.operators.KL(mean_field_2)()(10000) np.testing.assert_allclose( elbo_via_total_size_unscaled.eval(), diff --git a/tests/variational/test_opvi.py b/tests/variational/test_opvi.py index 41460f14d5..e4f84d665f 100644 --- a/tests/variational/test_opvi.py +++ b/tests/variational/test_opvi.py @@ -270,9 +270,9 @@ def test_logq_mini_2_sample_2_var(parametric_grouped_approxes, three_var_model): def test_logq_globals(three_var_approx, three_var_model): - if not three_var_approx.has_logq: - pytest.skip(f"{three_var_approx} does not implement logq") with three_var_model: + if not three_var_approx.has_logq: + pytest.skip(f"{three_var_approx} does not implement logq") approx = three_var_approx logq, symbolic_logq = approx.set_size_and_deterministic( [approx.logq, approx.symbolic_logq], 1, 0 From 135d47234cb6df20c8ea1b75bc04aaabadbe2d9f Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sun, 2 Nov 2025 04:13:04 +0100 Subject: [PATCH 08/11] Remove caching logic --- pymc/variational/approximations.py | 37 +++--- pymc/variational/opvi.py | 175 ++++++++-------------------- pymc/variational/stein.py | 19 +-- tests/variational/test_inference.py | 24 ++-- tests/variational/test_opvi.py | 6 + 5 files changed, 94 insertions(+), 167 deletions(-) diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py index 6f012333a9..2943db0544 100644 --- a/pymc/variational/approximations.py +++ b/pymc/variational/approximations.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from functools import cached_property + import numpy as np import pytensor @@ -32,7 +34,6 @@ Group, NotImplementedInference, _known_scan_ignored_inputs, - node_property, ) __all__ = ["Empirical", "FullRank", "MeanField", "sample_approx"] @@ -52,20 +53,20 @@ class MeanFieldGroup(Group): short_name = "mean_field" alias_names = frozenset(["mf"]) - @node_property + @cached_property def mean(self): return self.params_dict["mu"] - @node_property + @cached_property def rho(self): return self.params_dict["rho"] - @node_property + @cached_property def cov(self): var = rho2sigma(self.rho) ** 2 return pt.diag(var) - @node_property + @cached_property def std(self): return rho2sigma(self.rho) @@ -106,14 +107,14 @@ def create_shared_params(self, start=None, start_sigma=None): "rho": pytensor.shared(pm.floatX(rho), "rho"), } - @node_property + @cached_property def symbolic_random(self): initial = self.symbolic_initial sigma = self.std mu = self.mean return sigma * initial + mu - @node_property + @cached_property def symbolic_logq_not_scaled(self): z0 = self.symbolic_initial std = rho2sigma(self.rho) @@ -157,7 +158,7 @@ def create_shared_params(self, start=None): L_tril = np.eye(n)[np.tril_indices(n)].astype(pytensor.config.floatX) return {"mu": pytensor.shared(start, "mu"), "L_tril": pytensor.shared(L_tril, "L_tril")} - @node_property + @cached_property def L(self): L = pt.zeros((self.ddim, self.ddim)) L = pt.set_subtensor(L[self.tril_indices], self.params_dict["L_tril"]) @@ -165,16 +166,16 @@ def L(self): L = pt.set_subtensor(Ld, rho2sigma(Ld)) return L - @node_property + @cached_property def mean(self): return self.params_dict["mu"] - @node_property + @cached_property def cov(self): L = self.L return L.dot(L.T) - @node_property + @cached_property def std(self): return pt.sqrt(pt.diag(self.cov)) @@ -187,7 +188,7 @@ def num_tril_entries(self): def tril_indices(self): return np.tril_indices(self.ddim) - @node_property + @cached_property def symbolic_logq_not_scaled(self): z0 = self.symbolic_initial diag = pt.diagonal(self.L, 0, self.L.ndim - 2, self.L.ndim - 1) @@ -196,7 +197,7 @@ def symbolic_logq_not_scaled(self): logq = quaddist - logdet return logq.sum(range(1, logq.ndim)) - @node_property + @cached_property def symbolic_random(self): initial = self.symbolic_initial L = self.L @@ -300,24 +301,24 @@ def _new_initial(self, size, deterministic, more_replacements=None): else: return self.histogram[self.randidx(size)] - @property + @cached_property def symbolic_random(self): return self.symbolic_initial - @property + @cached_property def histogram(self): return self.params_dict["histogram"] - @node_property + @cached_property def mean(self): return self.histogram.mean(0) - @node_property + @cached_property def cov(self): x = self.histogram - self.mean return x.T.dot(x) / pm.floatX(self.histogram.shape[0]) - @node_property + @cached_property def std(self): return pt.sqrt(pt.diag(self.cov)) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 7d88f7aaa5..4314a1b1de 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -53,6 +53,7 @@ import warnings from dataclasses import dataclass +from functools import cached_property from typing import Any, overload import numpy as np @@ -79,12 +80,7 @@ find_rng_nodes, reseed_rngs, ) -from pymc.util import ( - RandomState, - WithMemoization, - _get_seeds_per_chain, - makeiter, -) +from pymc.util import RandomState, _get_seeds_per_chain, makeiter from pymc.variational.minibatch_rv import MinibatchRandomVariable, get_scaling from 
pymc.variational.updates import adagrad_window from pymc.vartypes import discrete_types @@ -143,47 +139,6 @@ def inner(*args, **kwargs): return wrap -def node_property(f): - """Wrap method to accessible tensor.""" - from collections import defaultdict - - from cachetools import LRUCache, cachedmethod - - def self_cache_fn(f_name): - def cf(self): - return self.__dict__.setdefault("_cache", defaultdict(lambda: LRUCache(128)))[f_name] - - return cf - - def cache_key_with_model(*args, **kwargs): - """Cache key that includes the current model context.""" - from pymc.util import hash_key - - try: - model = modelcontext(None) - # Include the model in the cache key to avoid using cached values - # from a different model context - return ( - *tuple(hash_key(*args, **kwargs)), - id(model), - ) - except TypeError: - # If no model context, just use regular hash_key - return tuple(hash_key(*args, **kwargs)) - - if isinstance(f, str): - - def wrapper(fn): - ff = append_name(f)(fn) - f_ = pytensor.config.change_flags(compute_test_value="off")(ff) - return property(cachedmethod(self_cache_fn(f_.__name__), key=cache_key_with_model)(f_)) - - return wrapper - else: - f_ = pytensor.config.change_flags(compute_test_value="off")(f) - return property(cachedmethod(self_cache_fn(f_.__name__), key=cache_key_with_model)(f_)) - - @pytensor.config.change_flags(compute_test_value="ignore") def try_to_set_test_value(node_in, node_out, s): _s = s @@ -613,7 +568,7 @@ def from_function(cls, f): return obj -class Group(WithMemoization): +class Group: R"""**Base class for grouping variables in VI**. Grouped Approximation is used for modelling mutual dependencies @@ -984,32 +939,10 @@ def __init_group__(self, group): ) start_idx += size - def __setstate__(self, state): - """Restore state after unpickling and clear cache.""" - super().__setstate__(state) - # Clear cached state after unpickling since cached values may reference - # variables from a different model context - self._clear_cached_state() - # Recreate _kwargs if it was deleted by _finalize_init, needed for rebuild - if not hasattr(self, "_kwargs"): - self._kwargs = {} - def _finalize_init(self): """*Dev* - clean up after init.""" del self._kwargs - def _clear_cached_state(self, *, reset_shared=False): - """Reset cached structures that depend on the current model context.""" - if hasattr(self, "_cache"): - del self._cache - self.replacements = collections.OrderedDict() - self.ordering = collections.OrderedDict() - for attr in ("symbolic_initial", "input"): - if attr in self.__dict__: - delattr(self, attr) - if reset_shared: - self.shared_params = None - @property def params_dict(self): # prefixed are correctly reshaped @@ -1046,14 +979,6 @@ def _new_initial_shape(self, size, dim, more_replacements=None): """ return pt.stack([size, dim]) - @node_property - def ndim(self): - return self.ddim - - @property - def ddim(self): - return sum(s.stop - s.start for _, s, _, _ in self.ordering.values()) - def _new_initial(self, size, deterministic, more_replacements=None): """*Dev* - allocates new initial random generator. 
@@ -1099,7 +1024,15 @@ def _new_initial(self, size, deterministic, more_replacements=None): initial = pt.switch(deterministic, pt.ones(shape, dtype) * dist_map, sample) return initial - @node_property + @property + def ndim(self): + return self.ddim + + @property + def ddim(self): + return sum(s.stop - s.start for _, s, _, _ in self.ordering.values()) + + @cached_property def symbolic_random(self): """*Dev* - abstract node that takes `self.symbolic_initial` and creates approximate posterior that is parametrized with `self.params_dict`. @@ -1206,7 +1139,7 @@ def make_size_and_deterministic_replacements(self, s, d, more_replacements=None) initial = graph_replace(initial, more_replacements, strict=False) return {self.symbolic_initial: initial} - @node_property + @cached_property def symbolic_normalizing_constant(self): """*Dev* - normalizing constant for `self.logq`, scales it to `minibatch_size` instead of `total_size`.""" t = self.to_flat_input( @@ -1225,45 +1158,37 @@ def symbolic_normalizing_constant(self): t = self.symbolic_single_sample(t) return pm.floatX(t) - @node_property + @property def symbolic_logq_not_scaled(self): """*Dev* - symbolically computed logq for `self.symbolic_random` computations can be more efficient since all is known beforehand including `self.symbolic_random`.""" raise NotImplementedError # shape (s,) - @node_property + @cached_property def symbolic_logq(self): """*Dev* - correctly scaled `self.symbolic_logq_not_scaled`.""" return self.symbolic_logq_not_scaled - @node_property + @cached_property def logq(self): """*Dev* - Monte Carlo estimate for group `logQ`.""" return self.symbolic_logq.mean(0) - @node_property + @cached_property def logq_norm(self): """*Dev* - Monte Carlo estimate for group `logQ` normalized.""" return self.logq / self.symbolic_normalizing_constant - def __str__(self): - """Return a string representation for the object.""" - if self.group is None: - shp = "undefined" - else: - shp = str(self.ddim) - return f"{self.__class__.__name__}[{shp}]" - - @node_property + @property def std(self) -> pt.TensorVariable: """Return the standard deviation of the latent variables as an unstructured 1-dimensional tensor variable.""" raise NotImplementedError() - @node_property + @property def cov(self) -> pt.TensorVariable: """Return the covariance between the latent variables as an unstructured 2-dimensional tensor variable.""" raise NotImplementedError() - @node_property + @property def mean(self) -> pt.TensorVariable: """Return the mean of the latent variables as an unstructured 1-dimensional tensor variable.""" raise NotImplementedError() @@ -1328,10 +1253,14 @@ def _refresh_group_for_model(group, model, group_vars=None): if not hasattr(group, "_kwargs"): group._kwargs = {} original_user_params = group.user_params - group._clear_cached_state(reset_shared=True) group.user_params = None group._user_params = None group.group = None + group.replacements = collections.OrderedDict() + group.ordering = collections.OrderedDict() + group.__dict__.pop("symbolic_initial", None) + group.__dict__.pop("input", None) + group.shared_params = None group.__init_group__(list(group_vars)) if original_user_params is not None: group.user_params = original_user_params @@ -1384,7 +1313,7 @@ class TraceSpec: test_point: collections.OrderedDict -class Approximation(WithMemoization): +class Approximation: """**Wrapper for grouped approximations**. 
Wraps list of groups, creates an Approximation instance that collects @@ -1411,14 +1340,8 @@ class Approximation(WithMemoization): """ def __setstate__(self, state): - """Restore state after unpickling and clear cache.""" - super().__setstate__(state) - # Clear cache after unpickling since cached values may reference - # variables from a different model context - # _cache is removed during pickling by WithMemoization.__getstate__, - # so it shouldn't exist after unpickling, but ensure it's deleted if it does - if hasattr(self, "_cache"): - del self._cache + """Restore state after unpickling.""" + self.__dict__.update(state) def __init__(self, groups, model=None): self._scale_cost_to_minibatch = pytensor.shared(np.int8(1)) @@ -1610,7 +1533,7 @@ def scale_cost_to_minibatch(self): def scale_cost_to_minibatch(self, value): self._scale_cost_to_minibatch.set_value(np.int8(bool(value))) - @node_property + @cached_property def symbolic_normalizing_constant(self): """*Dev* - normalizing constant for `self.logq`, scales it to `minibatch_size` instead of `total_size`. @@ -1631,91 +1554,91 @@ def symbolic_normalizing_constant(self): t = pt.switch(self._scale_cost_to_minibatch, t, pt.constant(1, dtype=t.dtype)) return pm.floatX(t) - @node_property + @cached_property def symbolic_logq(self): """*Dev* - collects `symbolic_logq` for all groups.""" return pt.add(*self.collect("symbolic_logq")) - @node_property + @cached_property def logq(self): """*Dev* - collects `logQ` for all groups.""" return pt.add(*self.collect("logq")) - @node_property + @cached_property def logq_norm(self): """*Dev* - collects `logQ` for all groups and normalizes it.""" return self.logq / self.symbolic_normalizing_constant - @node_property + @cached_property def _sized_symbolic_varlogp_and_datalogp(self): """*Dev* - computes sampled prior term from model via `pytensor.scan`.""" model = modelcontext(None) varlogp_s, datalogp_s = self.symbolic_sample_over_posterior([model.varlogp, model.datalogp]) return varlogp_s, datalogp_s # both shape (s,) - @node_property + @cached_property def sized_symbolic_varlogp(self): """*Dev* - computes sampled prior term from model via `pytensor.scan`.""" return self._sized_symbolic_varlogp_and_datalogp[0] # shape (s,) - @node_property + @cached_property def sized_symbolic_datalogp(self): """*Dev* - computes sampled data term from model via `pytensor.scan`.""" return self._sized_symbolic_varlogp_and_datalogp[1] # shape (s,) - @node_property + @cached_property def sized_symbolic_logp(self): """*Dev* - computes sampled logP from model via `pytensor.scan`.""" return self.sized_symbolic_varlogp + self.sized_symbolic_datalogp # shape (s,) - @node_property + @cached_property def logp(self): """*Dev* - computes :math:`E_{q}(logP)` from model via `pytensor.scan` that can be optimized later.""" return self.varlogp + self.datalogp - @node_property + @cached_property def varlogp(self): """*Dev* - computes :math:`E_{q}(prior term)` from model via `pytensor.scan` that can be optimized later.""" return self.sized_symbolic_varlogp.mean(0) - @node_property + @cached_property def datalogp(self): """*Dev* - computes :math:`E_{q}(data term)` from model via `pytensor.scan` that can be optimized later.""" return self.sized_symbolic_datalogp.mean(0) - @node_property + @cached_property def _single_symbolic_varlogp_and_datalogp(self): """*Dev* - computes sampled prior term from model via `pytensor.scan`.""" model = modelcontext(None) varlogp, datalogp = self.symbolic_single_sample([model.varlogp, model.datalogp]) return 
varlogp, datalogp - @node_property + @cached_property def single_symbolic_varlogp(self): """*Dev* - for single MC sample estimate of :math:`E_{q}(prior term)` `pytensor.scan` is not needed and code can be optimized.""" return self._single_symbolic_varlogp_and_datalogp[0] - @node_property + @cached_property def single_symbolic_datalogp(self): """*Dev* - for single MC sample estimate of :math:`E_{q}(data term)` `pytensor.scan` is not needed and code can be optimized.""" return self._single_symbolic_varlogp_and_datalogp[1] - @node_property + @cached_property def single_symbolic_logp(self): """*Dev* - for single MC sample estimate of :math:`E_{q}(logP)` `pytensor.scan` is not needed and code can be optimized.""" return self.single_symbolic_datalogp + self.single_symbolic_varlogp - @node_property + @cached_property def logp_norm(self): """*Dev* - normalized :math:`E_{q}(logP)`.""" return self.logp / self.symbolic_normalizing_constant - @node_property + @cached_property def varlogp_norm(self): """*Dev* - normalized :math:`E_{q}(prior term)`.""" return self.varlogp / self.symbolic_normalizing_constant - @node_property + @cached_property def datalogp_norm(self): """*Dev* - normalized :math:`E_{q}(data term)`.""" return self.datalogp / self.symbolic_normalizing_constant @@ -1881,7 +1804,7 @@ def rslice(self, name, model=None): raise KeyError(f"{name!r} not found") return found - @node_property + @property def sample_dict_fn(self): s = pt.iscalar() @@ -1978,7 +1901,7 @@ def ndim(self): def ddim(self): return sum(self.collect("ddim")) - @node_property + @cached_property def symbolic_random(self): return pt.concatenate(self.collect("symbolic_random"), axis=-1) @@ -1998,7 +1921,7 @@ def all_histograms(self): def any_histograms(self): return any(isinstance(g, pm.approximations.EmpiricalGroup) for g in self.groups) - @node_property + @property def joint_histogram(self): if not self.all_histograms: raise VariationalInferenceError("%s does not consist of all Empirical approximations") diff --git a/pymc/variational/stein.py b/pymc/variational/stein.py index 0534bb6fa4..1bc9360c0f 100644 --- a/pymc/variational/stein.py +++ b/pymc/variational/stein.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from functools import cached_property + import pytensor.tensor as pt from pytensor.graph.replace import graph_replace from pymc.pytensorf import floatX from pymc.util import WithMemoization, locally_cachedmethod -from pymc.variational.opvi import node_property from pymc.variational.test_functions import rbf __all__ = ["Stein"] @@ -38,14 +39,14 @@ def input_joint_matrix(self): else: return self.approx.symbolic_random - @node_property + @cached_property def approx_symbolic_matrices(self): if self.use_histogram: return self.approx.collect("histogram") else: return self.approx.symbolic_randoms - @node_property + @cached_property def dlogp(self): logp = self.logp_norm.sum() grad = pt.grad(logp, self.approx_symbolic_matrices) @@ -55,34 +56,34 @@ def flatten2(tensor): return pt.concatenate(list(map(flatten2, grad)), -1) - @node_property + @cached_property def grad(self): n = floatX(self.input_joint_matrix.shape[0]) temperature = self.temperature svgd_grad = self.density_part_grad / temperature + self.repulsive_part_grad return svgd_grad / n - @node_property + @cached_property def density_part_grad(self): Kxy = self.Kxy dlogpdx = self.dlogp return pt.dot(Kxy, dlogpdx) - @node_property + @cached_property def repulsive_part_grad(self): t = self.approx.symbolic_normalizing_constant dxkxy = self.dxkxy return dxkxy / t - @property + @cached_property def Kxy(self): return self._kernel()[0] - @property + @cached_property def dxkxy(self): return self._kernel()[1] - @node_property + @cached_property def logp_norm(self): sized_symbolic_logp = self.approx.sized_symbolic_logp if self.use_histogram: diff --git a/tests/variational/test_inference.py b/tests/variational/test_inference.py index bdd7a5ed79..ee442e30ab 100644 --- a/tests/variational/test_inference.py +++ b/tests/variational/test_inference.py @@ -358,23 +358,19 @@ def test_var_replacement(): assert advi.sample_node(mean, more_replacements={inp: x_new}).eval().shape == (11,) -def test_clear_cache(): +@pytest.mark.parametrize( + "inference_cls", + [ADVI, FullRankADVI], +) +def test_advi_pickle(inference_cls): with pm.Model() as model: pm.Normal("n", 0, 1) - inference = ADVI() + inference = inference_cls() inference.fit(n=10) - assert any(len(c) != 0 for c in inference.approx._cache.values()) - inference.approx._cache.clear() - # should not be cleared at this call - assert all(len(c) == 0 for c in inference.approx._cache.values()) - new_a = cloudpickle.loads(cloudpickle.dumps(inference.approx)) - assert not hasattr(new_a, "_cache") - with model: - inference_new = pm.KLqp(new_a) - inference_new.fit(n=10) - assert any(len(c) != 0 for c in inference_new.approx._cache.values()) - inference_new.approx._cache.clear() - assert all(len(c) == 0 for c in inference_new.approx._cache.values()) + serialized = cloudpickle.dumps(inference.approx) + new_approx = cloudpickle.loads(serialized) + inference_new = pm.KLqp(new_approx) + inference_new.fit(n=10) def test_fit_data(inference, fit_kwargs, simple_model_data, simple_model): diff --git a/tests/variational/test_opvi.py b/tests/variational/test_opvi.py index e4f84d665f..e4200b1424 100644 --- a/tests/variational/test_opvi.py +++ b/tests/variational/test_opvi.py @@ -291,6 +291,12 @@ def test_logq_globals(three_var_approx, three_var_model): assert es.shape == (2,) +def test_model_property_emits_deprecation(three_var_approx, three_var_model): + with three_var_model: + with pytest.warns(DeprecationWarning, match="`model` field is deprecated"): + _ = three_var_approx.model + + def 
test_symbolic_normalizing_constant_no_rvs(): # Test that RVs aren't included in the graph of symbolic_normalizing_constant rng = np.random.default_rng() From ae6b9e14c6258bcf18a3e863373a40881fea7b0d Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sun, 2 Nov 2025 05:11:27 +0100 Subject: [PATCH 09/11] Cleanup --- pymc/variational/opvi.py | 122 +++++---------------------------------- 1 file changed, 14 insertions(+), 108 deletions(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 4314a1b1de..790d1f29da 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -948,10 +948,9 @@ def params_dict(self): # prefixed are correctly reshaped if self._user_params is not None: return self._user_params - else: - if self.shared_params is None and self.group is not None: - _refresh_group_for_model(self, modelcontext(None)) - return self.shared_params + if self.shared_params is None: + raise ParametrizationError("Group parameters have not been initialized") + return self.shared_params @property def params(self): @@ -1225,88 +1224,6 @@ def std_data(self) -> xarray.Dataset: group_for_short_name = Group.group_for_short_name -def _map_group_vars_to_model(group_vars, model): - if not group_vars: - return [] - var_name_map = {var.name: var for var in model.free_RVs} - mapped = [] - for var in group_vars: - if var in model.free_RVs: - mapped.append(var) - else: - mapped_var = var_name_map.get(var.name) - if mapped_var is not None: - mapped.append(mapped_var) - return mapped - - -def _refresh_group_for_model(group, model, group_vars=None): - if group_vars is None: - group_vars = group.group or [] - mapped_group = _map_group_vars_to_model(group_vars, model) - if mapped_group: - group_vars = mapped_group - if not group_vars: - group.group = group_vars - return group.group - if group.shared_params is None: - if not hasattr(group, "_kwargs"): - group._kwargs = {} - original_user_params = group.user_params - group.user_params = None - group._user_params = None - group.group = None - group.replacements = collections.OrderedDict() - group.ordering = collections.OrderedDict() - group.__dict__.pop("symbolic_initial", None) - group.__dict__.pop("input", None) - group.shared_params = None - group.__init_group__(list(group_vars)) - if original_user_params is not None: - group.user_params = original_user_params - else: - group.group = list(group_vars) - if "symbolic_initial" not in group.__dict__: - group.symbolic_initial = group._initial_type( - group.__class__.__name__ + "_symbolic_initial_tensor" - ) - if "input" not in group.__dict__: - group.input = group._input_type(group.__class__.__name__ + "_symbolic_input") - _rebuild_group_mappings(group, model) - return group.group - - -def _rebuild_group_mappings(group, model): - if not group.group: - group.replacements = collections.OrderedDict() - group.ordering = collections.OrderedDict() - return - model_initial_point = model.initial_point(0) - replacements = collections.OrderedDict() - ordering = collections.OrderedDict() - start_idx = 0 - for var in group.group: - if var.type.numpy_dtype.name in discrete_types: - raise ParametrizationError(f"Discrete variables are not supported by VI: {var}") - value_var = model.rvs_to_values[var] - test_var = model_initial_point[value_var.name] - shape = test_var.shape - size = test_var.size - dtype = test_var.dtype - vr = group.input[..., start_idx : start_idx + size].reshape(shape).astype(dtype) - vr.name = value_var.name + "_vi_replacement" - replacements[value_var] = vr - ordering[value_var.name] = ( 
- value_var.name, - slice(start_idx, start_idx + size), - shape, - dtype, - ) - start_idx += size - group.replacements = replacements - group.ordering = ordering - - @dataclass class TraceSpec: sample_vars: list @@ -1356,22 +1273,24 @@ def __init__(self, groups, model=None): if g.group is None: if rest is not None: raise GroupError("More than one group is specified for the rest variables") - else: - rest = g + rest = g else: - final_group = _refresh_group_for_model(g, model) - if set(final_group) & seen: + group_vars = list(g.group) + missing = [var for var in group_vars if var not in model.free_RVs] + if missing: + names = ", ".join(var.name for var in missing) + raise GroupError(f"Variables [{names}] are not part of the provided model") + if set(group_vars) & seen: raise GroupError("Found duplicates in groups") - seen.update(final_group) + seen.update(group_vars) self.groups.append(g) # List iteration to preserve order for reproducibility between runs unseen_free_RVs = [var for var in model.free_RVs if var not in seen] if unseen_free_RVs: if rest is None: raise GroupError("No approximation is specified for the rest variables") - else: - rest.__init_group__(unseen_free_RVs) - self.groups.append(rest) + rest.__init_group__(unseen_free_RVs) + self.groups.append(rest) @property def has_logq(self): @@ -1386,26 +1305,13 @@ def model(self): ) return modelcontext(None) - def _ensure_groups_ready(self, model=None): - try: - model = modelcontext(model) - except TypeError: - return - with model: - for g in self.groups: - _refresh_group_for_model(g, model) - def collect(self, item): - model = modelcontext(None) - self._ensure_groups_ready(model=model) return [getattr(g, item) for g in self.groups] def _variational_orderings(self, model): orderings = collections.OrderedDict() for g in self.groups: - mapped = _refresh_group_for_model(g, model) - if mapped: - orderings.update(g.ordering) + orderings.update(g.ordering) return orderings def _draw_variational_samples(self, model, names, draws, size_sym, random_seed): From adeca863cd1b1bad8650646d59faa2212c7d8fb4 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sun, 2 Nov 2025 05:19:55 +0100 Subject: [PATCH 10/11] Cleanup --- pymc/variational/opvi.py | 29 +---------------------------- 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 790d1f29da..1c19908479 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -874,35 +874,8 @@ def __init_group__(self, group): raise GroupError("Got empty group") model = modelcontext(None) - # If self.group is already set (from unpickling), we might need to rebuild it - # to map old variables to new ones in the current model context - if self.group is not None: - # Check if any variables in self.group don't belong to the current model - # If so, rebuild the group by matching variable names - needs_rebuild = False - for var in self.group: - # Check if variable is in the current model's free_RVs - if var not in model.free_RVs: - needs_rebuild = True - break - - if needs_rebuild: - # Rebuild group by matching variable names - var_name_map = {var.name: var for var in model.free_RVs} - new_group = [] - for old_var in self.group: - if old_var.name in var_name_map: - new_group.append(var_name_map[old_var.name]) - else: - raise ValueError( - f"Variable '{old_var.name}' from unpickled group not found in current model. 
" - f"Available variables: {list(var_name_map.keys())}" - ) - self.group = new_group - if self.group is None: - # delayed init - self.group = group + self.group = list(group) self.symbolic_initial = self._initial_type( self.__class__.__name__ + "_symbolic_initial_tensor" From b4495b7f60a47756ec96c5045df09a5bf1b7150d Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sun, 2 Nov 2025 06:04:59 +0100 Subject: [PATCH 11/11] Add test that sampling with a different model works --- pymc/variational/opvi.py | 12 ++++++------ tests/variational/test_opvi.py | 24 ++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 1c19908479..7d480b0fdc 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -80,7 +80,7 @@ find_rng_nodes, reseed_rngs, ) -from pymc.util import RandomState, _get_seeds_per_chain, makeiter +from pymc.util import RandomState, _get_seeds_per_chain, makeiter, point_wrapper from pymc.variational.minibatch_rv import MinibatchRandomVariable, get_scaling from pymc.variational.updates import adagrad_window from pymc.vartypes import discrete_types @@ -1321,13 +1321,13 @@ def _draw_forward_samples(self, model, approx_samples, approx_names, draws, rand random_seed=random_seed, ) approx_value_vars = [model.rvs_to_values[var] for var in approx_vars] + input_values = {var.name: approx_samples[var.name] for var in approx_value_vars} + wrapped_sampler = point_wrapper(sampler_fn) + stacked = {name: [] for name in forward_names} for i in range(draws): - inputs = { - value_var.name: approx_samples[value_var.name][i] - for value_var in approx_value_vars - } - raw = sampler_fn(**inputs) + inputs = {name: values[i] for name, values in input_values.items()} + raw = wrapped_sampler(**inputs) if not isinstance(raw, list | tuple): raw = [raw] for name, value in zip(forward_names, raw): diff --git a/tests/variational/test_opvi.py b/tests/variational/test_opvi.py index e4200b1424..6ecb68de17 100644 --- a/tests/variational/test_opvi.py +++ b/tests/variational/test_opvi.py @@ -313,3 +313,27 @@ def test_symbolic_normalizing_constant_no_rvs(): # Access the property again to test it doesn't require model context after first access assert_no_rvs(symbolic_normalizing) + + +def test_sample_additional_vars(three_var_approx, three_var_model): + with pm.Model() as extended_model: + one = pm.HalfNormal("one", size=(10, 2)) + two = pm.Normal("two", size=(10,)) + three = pm.Normal("three", size=(10, 1, 2)) + four = pm.Normal("four", mu=two, sigma=1, size=(10,)) + five = pm.Deterministic("five", four.sum()) + pm.Normal("six", mu=five, sigma=1) + + with extended_model: + idata = three_var_approx.sample(20) + + posterior = idata.posterior + + varnames = set(posterior.data_vars) + assert {"one", "two", "three"}.issubset(varnames) + assert {"four", "five", "six"}.issubset(varnames) + assert posterior.sizes["draw"] == 20 + assert posterior.sizes["chain"] == 1 + assert posterior["four"].shape == (1, 20, 10) + assert posterior["five"].shape == (1, 20) + assert posterior["six"].shape == (1, 20)