@@ -33,7 +33,7 @@ def fit_dadvi(
3333 n_fixed_draws : int = 30 ,
3434 random_seed : RandomSeed = None ,
3535 n_draws : int = 1000 ,
36- keep_untransformed : bool = False ,
36+ include_transformed : bool = False ,
3737 optimizer_method : minimize_method = "trust-ncg" ,
3838 use_grad : bool | None = None ,
3939 use_hessp : bool | None = None ,
@@ -63,7 +63,7 @@ def fit_dadvi(
6363 n_draws: int
6464 The number of draws to return from the variational approximation.
6565
66- keep_untransformed : bool
66+ include_transformed : bool
6767 Whether or not to keep the unconstrained variables (such as logs of positive-constrained parameters) in the
6868 output.
6969
@@ -166,9 +166,7 @@ def fit_dadvi(
166166 draws = opt_means + draws_raw * np .exp (opt_log_sds )
167167 draws_arviz = unstack_laplace_draws (draws , model , chains = 1 , draws = n_draws )
168168
169- idata = az .InferenceData (
170- posterior = transform_draws (draws_arviz , model , keep_untransformed = keep_untransformed )
171- )
169+ idata = dadvi_result_to_idata (draws_arviz , model , include_transformed = include_transformed )
172170
173171 var_name_to_model_var = {f"{ var_name } _mu" : var_name for var_name in initial_point_dict .keys ()}
174172 var_name_to_model_var .update (
@@ -251,10 +249,10 @@ def create_dadvi_graph(
251249 return var_params , objective
252250
253251
254- def transform_draws (
252+ def dadvi_result_to_idata (
255253 unstacked_draws : xarray .Dataset ,
256254 model : Model ,
257- keep_untransformed : bool = False ,
255+ include_transformed : bool = False ,
258256):
259257 """
260258 Transforms the unconstrained draws back into the constrained space.
@@ -270,7 +268,7 @@ def transform_draws(
270268 n_draws: int
271269 The number of draws to return from the variational approximation.
272270
273- keep_untransformed : bool
271+ include_transformed : bool
274272 Whether or not to keep the unconstrained variables in the output.
275273
276274 Returns
@@ -281,7 +279,7 @@ def transform_draws(
281279
282280 filtered_var_names = model .unobserved_value_vars
283281 vars_to_sample = list (
284- get_default_varnames (filtered_var_names , include_transformed = keep_untransformed )
282+ get_default_varnames (filtered_var_names , include_transformed = include_transformed )
285283 )
286284 fn = pytensor .function (model .value_vars , vars_to_sample )
287285 point_func = PointFunc (fn )
@@ -296,4 +294,17 @@ def transform_draws(
296294 dims = dims ,
297295 )
298296
299- return transformed_result
297+ constrained_names = [
298+ x .name for x in get_default_varnames (model .unobserved_value_vars , include_transformed = False )
299+ ]
300+ all_varnames = [
301+ x .name for x in get_default_varnames (model .unobserved_value_vars , include_transformed = True )
302+ ]
303+ unconstrained_names = set (all_varnames ) - set (constrained_names )
304+
305+ idata = az .InferenceData (posterior = transformed_result [constrained_names ])
306+
307+ if unconstrained_names and include_transformed :
308+ idata ["unconstrained_posterior" ] = transformed_result [unconstrained_names ]
309+
310+ return idata
0 commit comments