 import pytensor.tensor as pt
 import xarray

-from better_optimize import minimize
+from better_optimize import basinhopping, minimize
 from better_optimize.constants import minimize_method
 from pymc import DictToArrayBijection, Model, join_nonshared_inputs
 from pymc.backends.arviz import (
@@ -31,7 +31,6 @@
 def fit_dadvi(
     model: Model | None = None,
     n_fixed_draws: int = 30,
-    random_seed: RandomSeed = None,
     n_draws: int = 1000,
     include_transformed: bool = False,
     optimizer_method: minimize_method = "trust-ncg",
@@ -40,7 +39,9 @@ def fit_dadvi(
     use_hess: bool | None = None,
     gradient_backend: str = "pytensor",
     compile_kwargs: dict | None = None,
-    **minimize_kwargs,
+    random_seed: RandomSeed = None,
+    progressbar: bool = True,
+    **optimizer_kwargs,
 ) -> az.InferenceData:
     """
     Does inference using Deterministic ADVI (Automatic Differentiation Variational Inference), DADVI for short.
@@ -79,10 +80,6 @@ def fit_dadvi(
     compile_kwargs: dict, optional
         Additional keyword arguments to pass to `pytensor.function`

-    minimize_kwargs:
-        Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function. See the documentation of
-        that function for details.
-
     use_grad: bool, optional
         If True, pass the gradient function to `scipy.optimize.minimize` (where it is referred to as `jac`).

@@ -93,6 +90,13 @@ def fit_dadvi(
         If True, pass the hessian to `scipy.optimize.minimize`. Note that this is generally not recommended since its
         computation can be slow and memory-intensive if there are many parameters.

+    progressbar: bool
+        Whether or not to show a progress bar during optimization. Default is True.
+
+    optimizer_kwargs:
+        Additional keyword arguments to pass to ``scipy.optimize.minimize``, or to ``scipy.optimize.basinhopping``
+        when ``optimizer_method="basinhopping"``. See the documentation of those functions for details.
+
     Returns
     -------
     :class:`~arviz.InferenceData`
@@ -105,6 +109,16 @@ def fit_dadvi(
     """

     model = pymc.modelcontext(model) if model is None else model
+    do_basinhopping = optimizer_method == "basinhopping"
+    minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
+
+    if do_basinhopping:
+        # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
+        # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
+        # if one isn't provided.
+
+        optimizer_method = minimizer_kwargs.pop("method", "L-BFGS-B")
+        minimizer_kwargs["method"] = optimizer_method

     initial_point_dict = model.initial_point()
     initial_point = DictToArrayBijection.map(initial_point_dict)
@@ -145,14 +159,34 @@ def fit_dadvi(
     )

     dadvi_initial_point = DictToArrayBijection.map(dadvi_initial_point)
-
-    result = minimize(
-        f=f_fused,
-        x0=dadvi_initial_point.data,
-        method=optimizer_method,
-        hessp=f_hessp,
-        **minimize_kwargs,
-    )
+    args = optimizer_kwargs.pop("args", ())
+
+    if do_basinhopping:
+        if "args" not in minimizer_kwargs:
+            minimizer_kwargs["args"] = args
+        if "hessp" not in minimizer_kwargs:
+            minimizer_kwargs["hessp"] = f_hessp
+        if "method" not in minimizer_kwargs:
+            minimizer_kwargs["method"] = optimizer_method
+
+        result = basinhopping(
+            func=f_fused,
+            x0=dadvi_initial_point.data,
+            progressbar=progressbar,
+            minimizer_kwargs=minimizer_kwargs,
+            **optimizer_kwargs,
+        )
+
+    else:
+        result = minimize(
+            f=f_fused,
+            x0=dadvi_initial_point.data,
+            args=args,
+            method=optimizer_method,
+            hessp=f_hessp,
+            progressbar=progressbar,
+            **optimizer_kwargs,
+        )

     raveled_optimized = RaveledVars(result.x, dadvi_initial_point.point_map_info)

@@ -166,7 +200,9 @@ def fit_dadvi(
     draws = opt_means + draws_raw * np.exp(opt_log_sds)
     draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)

-    idata = dadvi_result_to_idata(draws_arviz, model, include_transformed=include_transformed)
+    idata = dadvi_result_to_idata(
+        draws_arviz, model, include_transformed=include_transformed, progressbar=progressbar
+    )

     var_name_to_model_var = {f"{var_name}_mu": var_name for var_name in initial_point_dict.keys()}
     var_name_to_model_var.update(
@@ -253,6 +289,7 @@ def dadvi_result_to_idata(
     unstacked_draws: xarray.Dataset,
     model: Model,
     include_transformed: bool = False,
+    progressbar: bool = True,
 ):
     """
     Transforms the unconstrained draws back into the constrained space.
@@ -271,6 +308,9 @@ def dadvi_result_to_idata(
     include_transformed: bool
         Whether or not to keep the unconstrained variables in the output.

+    progressbar: bool
+        Whether or not to show a progress bar during the transformation. Default is True.
+
     Returns
     -------
     :class:`~arviz.InferenceData`
@@ -292,6 +332,7 @@ def dadvi_result_to_idata(
         output_var_names=[x.name for x in vars_to_sample],
         coords=coords,
         dims=dims,
+        progressbar=progressbar,
     )

     constrained_names = [
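
For context, here is a minimal usage sketch of the two optimizer paths this diff enables. The model, the import path, and the basinhopping options below are illustrative assumptions, not part of the patch:

import numpy as np
import pymc as pm

# Assumed import path; adjust to wherever fit_dadvi is exposed in your install.
from pymc_extras.inference import fit_dadvi

rng = np.random.default_rng(0)
observed = rng.normal(loc=1.0, scale=2.0, size=200)

with pm.Model():
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 5.0)
    pm.Normal("y", mu=mu, sigma=sigma, observed=observed)

    # Default path: optimizer_kwargs are forwarded to minimize().
    idata = fit_dadvi(optimizer_method="trust-ncg", n_draws=500)

    # New path: optimizer_method="basinhopping" routes through basinhopping(),
    # with the inner optimizer configured via minimizer_kwargs ("L-BFGS-B" by default).
    idata_bh = fit_dadvi(
        optimizer_method="basinhopping",
        minimizer_kwargs={"method": "L-BFGS-B"},
        niter=5,  # forwarded to basinhopping via **optimizer_kwargs
    )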