     apply_function_over_dataset,
     coords_and_dims_for_inferencedata,
 )
+from pymc.blocking import RaveledVars
 from pymc.util import RandomSeed, get_default_varnames
 from pytensor.tensor.variable import TensorVariable

+from pymc_extras.inference.laplace_approx.idata import (
+    add_data_to_inference_data,
+    add_optimizer_result_to_inference_data,
+)
 from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
 from pymc_extras.inference.laplace_approx.scipy_interface import (
-    _compile_functions_for_scipy_optimize,
+    scipy_optimize_funcs_from_loss,
+    set_optimizer_function_defaults,
 )


@@ -29,64 +35,63 @@ def fit_dadvi(
     n_draws: int = 1000,
     keep_untransformed: bool = False,
     optimizer_method: minimize_method = "trust-ncg",
-    use_grad: bool = True,
-    use_hessp: bool = True,
-    use_hess: bool = False,
+    use_grad: bool | None = None,
+    use_hessp: bool | None = None,
+    use_hess: bool | None = None,
+    gradient_backend: str = "pytensor",
+    compile_kwargs: dict | None = None,
     **minimize_kwargs,
 ) -> az.InferenceData:
     """
-    Does inference using deterministic ADVI (automatic differentiation
-    variational inference), DADVI for short.
+    Does inference using Deterministic ADVI (Automatic Differentiation Variational Inference), DADVI for short.

-    For full details see the paper cited in the references:
-    https://www.jmlr.org/papers/v25/23-1015.html
+    For full details see the paper cited in the references: https://www.jmlr.org/papers/v25/23-1015.html

     Parameters
     ----------
     model : pm.Model
         The PyMC model to be fit. If None, the current model context is used.

     n_fixed_draws : int
-        The number of fixed draws to use for the optimisation. More
-        draws will result in more accurate estimates, but also
-        increase inference time. Usually, the default of 30 is a good
-        tradeoff.between speed and accuracy.
+        The number of fixed draws to use for the optimisation. More draws will result in more accurate estimates, but
+        also increase inference time. Usually, the default of 30 is a good tradeoff between speed and accuracy.

     random_seed: int
-        The random seed to use for the fixed draws. Running the optimisation
-        twice with the same seed should arrive at the same result.
+        The random seed to use for the fixed draws. Running the optimisation twice with the same seed should arrive at
+        the same result.

     n_draws: int
         The number of draws to return from the variational approximation.

     keep_untransformed: bool
-        Whether or not to keep the unconstrained variables (such as
-        logs of positive-constrained parameters) in the output.
+        Whether or not to keep the unconstrained variables (such as logs of positive-constrained parameters) in the
+        output.

     optimizer_method: str
-        Which optimization method to use. The function calls
-        ``scipy.optimize.minimize``, so any of the methods there can
-        be used. The default is trust-ncg, which uses second-order
-        information and is generally very reliable. Other methods such
-        as L-BFGS-B might be faster but potentially more brittle and
-        may not converge exactly to the optimum.
+        Which optimization method to use. The function calls ``scipy.optimize.minimize``, so any of the methods there
+        can be used. The default is trust-ncg, which uses second-order information and is generally very reliable.
+        Other methods such as L-BFGS-B might be faster but potentially more brittle and may not converge exactly to
+        the optimum.
+
+    gradient_backend: str
+        Which backend to use to compute gradients. Must be one of "jax" or "pytensor". Default is "pytensor".
+
+    compile_kwargs: dict, optional
+        Additional keyword arguments to pass to `pytensor.function`

     minimize_kwargs:
-        Additional keyword arguments to pass to the
-        ``scipy.optimize.minimize`` function. See the documentation of
+        Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function. See the documentation of
         that function for details.

-    use_grad:
-        If True, pass the gradient function to
-        `scipy.optimize.minimize` (where it is referred to as `jac`).
+    use_grad: bool, optional
+        If True, pass the gradient function to `scipy.optimize.minimize` (where it is referred to as `jac`).

-    use_hessp:
+    use_hessp: bool, optional
         If True, pass the hessian vector product to `scipy.optimize.minimize`.

-    use_hess:
-        If True, pass the hessian to `scipy.optimize.minimize`. Note that
-        this is generally not recommended since its computation can be slow
-        and memory-intensive if there are many parameters.
+    use_hess: bool, optional
+        If True, pass the hessian to `scipy.optimize.minimize`. Note that this is generally not recommended since its
+        computation can be slow and memory-intensive if there are many parameters.

     Returns
     -------
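# --- Reviewer note (not part of the diff): background sketch of the DADVI objective ---
# The docstring above refers to fixed draws and a deterministic objective. As described in
# the cited Giordano, Ingram & Broderick (2024) paper, DADVI fixes the standard-normal draws
# used to estimate the negative ELBO of a mean-field Gaussian and then optimizes that fixed
# estimate. The helper below is an illustrative toy, not code from pymc-extras; `log_p`,
# `mu`, `log_sd`, and `zs` are made-up names.
import numpy as np


def dadvi_objective(mu, log_sd, zs, log_p):
    # Reparameterised samples: one parameter vector per fixed draw z_m.
    thetas = mu + np.exp(log_sd) * zs
    # Monte Carlo estimate of E_q[-log p(theta)] over the *fixed* draws.
    expected_neg_log_p = -np.mean([log_p(t) for t in thetas])
    # Negative entropy of the diagonal Gaussian, up to an additive constant.
    neg_entropy = -np.sum(log_sd)
    return expected_neg_log_p + neg_entropy


rng = np.random.default_rng(0)
zs = rng.standard_normal((30, 2))  # 30 fixed draws for a toy 2-parameter model
log_p = lambda theta: -0.5 * np.sum((theta - 1.0) ** 2)  # toy Gaussian-like log density
print(dadvi_objective(np.zeros(2), np.zeros(2), zs, log_p))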
@@ -95,16 +100,15 @@ def fit_dadvi(

     References
     ----------
-    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box
-    Variational Inference with a Deterministic Objective: Faster, More
-    Accurate, and Even More Black Box. Journal of Machine Learning
-    Research, 25(18), 1–39.
+    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box Variational Inference with a Deterministic Objective:
+    Faster, More Accurate, and Even More Black Box. Journal of Machine Learning Research, 25(18), 1–39.
     """

     model = pymc.modelcontext(model) if model is None else model

     initial_point_dict = model.initial_point()
-    n_params = DictToArrayBijection.map(initial_point_dict).data.shape[0]
+    initial_point = DictToArrayBijection.map(initial_point_dict)
+    n_params = initial_point.data.shape[0]

     var_params, objective = create_dadvi_graph(
         model,
@@ -113,31 +117,45 @@ def fit_dadvi(
         n_params=n_params,
     )

-    f_fused, f_hessp = _compile_functions_for_scipy_optimize(
-        objective,
-        [var_params],
-        compute_grad=use_grad,
-        compute_hessp=use_hessp,
-        compute_hess=use_hess,
+    use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
+        optimizer_method, use_grad, use_hess, use_hessp
+    )
+
+    f_fused, f_hessp = scipy_optimize_funcs_from_loss(
+        loss=objective,
+        inputs=[var_params],
+        initial_point_dict=None,
+        use_grad=use_grad,
+        use_hessp=use_hessp,
+        use_hess=use_hess,
+        gradient_backend=gradient_backend,
+        compile_kwargs=compile_kwargs,
+        inputs_are_flat=True,
     )

-    derivative_kwargs = {}
+    dadvi_initial_point = {
+        f"{var_name}_mu": np.zeros_like(value).ravel()
+        for var_name, value in initial_point_dict.items()
+    }
+    dadvi_initial_point.update(
+        {
+            f"{var_name}_sigma__log": np.zeros_like(value).ravel()
+            for var_name, value in initial_point_dict.items()
+        }
+    )

-    if use_grad:
-        derivative_kwargs["jac"] = True
-    if use_hessp:
-        derivative_kwargs["hessp"] = f_hessp
-    if use_hess:
-        derivative_kwargs["hess"] = True
+    dadvi_initial_point = DictToArrayBijection.map(dadvi_initial_point)

     result = minimize(
-        f_fused,
-        np.zeros(2 * n_params),
+        f=f_fused,
+        x0=dadvi_initial_point.data,
         method=optimizer_method,
-        **derivative_kwargs,
+        hessp=f_hessp,
         **minimize_kwargs,
     )

+    raveled_optimized = RaveledVars(result.x, dadvi_initial_point.point_map_info)
+
     opt_var_params = result.x
     opt_means, opt_log_sds = np.split(opt_var_params, 2)

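# --- Reviewer note (not part of the diff): how the flat optimizer vector is built ---
# The hunk above packs per-variable "<name>_mu" and "<name>_sigma__log" arrays into one
# dict, then flattens it with DictToArrayBijection so the optimizer receives a single x0
# vector; point_map_info records the names and shapes needed to map the optimum back.
# Because all "_mu" entries are inserted before all "_sigma__log" entries, np.split(x, 2)
# later recovers the means and the log standard deviations. The shapes below are made up;
# RaveledVars internals may vary slightly across PyMC versions.
import numpy as np
from pymc.blocking import DictToArrayBijection

toy_point = {
    "a_mu": np.zeros(3),
    "b_mu": np.zeros(2),
    "a_sigma__log": np.zeros(3),
    "b_sigma__log": np.zeros(2),
}
raveled = DictToArrayBijection.map(toy_point)
print(raveled.data.shape)  # (10,) -- the flat vector handed to the optimizer as x0
means, log_sds = np.split(raveled.data, 2)  # same split as in the hunk above
round_trip = DictToArrayBijection.rmap(raveled)  # back to a {"a_mu": ..., ...} dict
print(sorted(round_trip.keys()))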
@@ -148,9 +166,29 @@ def fit_dadvi(
     draws = opt_means + draws_raw * np.exp(opt_log_sds)
     draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)

-    transformed_draws = transform_draws(draws_arviz, model, keep_untransformed=keep_untransformed)
+    idata = az.InferenceData(
+        posterior=transform_draws(draws_arviz, model, keep_untransformed=keep_untransformed)
+    )
+
+    var_name_to_model_var = {f"{var_name}_mu": var_name for var_name in initial_point_dict.keys()}
+    var_name_to_model_var.update(
+        {f"{var_name}_sigma__log": var_name for var_name in initial_point_dict.keys()}
+    )
+
+    idata = add_optimizer_result_to_inference_data(
+        idata=idata,
+        result=result,
+        method=optimizer_method,
+        mu=raveled_optimized,
+        model=model,
+        var_name_to_model_var=var_name_to_model_var,
+    )
+
+    idata = add_data_to_inference_data(
+        idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
+    )

-    return transformed_draws
+    return idata


 def create_dadvi_graph(
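# --- Reviewer note (not part of the diff): hypothetical end-to-end usage ---
# A sketch of calling fit_dadvi with the updated signature. The import path, the toy
# model, and the compile_kwargs value are illustrative assumptions, not taken from this
# PR; adjust the import to wherever pymc-extras actually exposes fit_dadvi.
import numpy as np
import pymc as pm

from pymc_extras.inference import fit_dadvi  # assumed import path

rng = np.random.default_rng(42)
observed = rng.normal(loc=1.0, scale=2.0, size=200)

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 5.0)
    pm.Normal("y", mu=mu, sigma=sigma, observed=observed)

idata = fit_dadvi(
    model=model,
    n_draws=1000,
    optimizer_method="trust-ncg",
    # Leaving use_grad / use_hessp / use_hess as None defers to
    # set_optimizer_function_defaults, per the hunk above.
    gradient_backend="pytensor",
    compile_kwargs={"mode": "FAST_RUN"},  # assumed example of a pytensor.function kwarg
)
print(idata.posterior)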