@@ -116,11 +116,10 @@ def pathfinder_result_to_xarray(
116116 >>> with pm.Model() as model:
117117 ... x = pm.Normal("x", 0, 1)
118118 ... y = pm.Normal("y", x, 1, observed=2.0)
119- ...
120119 >>> # Assuming we have a PathfinderResult from a pathfinder run
121120 >>> ds = pathfinder_result_to_xarray(result, model=model)
122121 >>> print(ds.data_vars) # Shows lbfgs_niter, elbo_argmax, status info, etc.
123- >>> print(ds.attrs) # Shows metadata like lbfgs_status, path_status
122+ >>> print(ds.attrs) # Shows metadata like lbfgs_status, path_status
124123 """
125124 data_vars = {}
126125 coords = {}
@@ -214,9 +213,16 @@ def multipathfinder_result_to_xarray(
214213 >>> # Assuming we have a MultiPathfinderResult from multiple pathfinder runs
215214 >>> ds = multipathfinder_result_to_xarray(result, model=model)
216215 >>> print("All data:", ds.data_vars)
217- >>> print("Summary:", [k for k in ds.data_vars.keys() if not k.startswith(('paths/', 'config/', 'diagnostics/'))])
218- >>> print("Per-path:", [k for k in ds.data_vars.keys() if k.startswith('paths/')])
219- >>> print("Config:", [k for k in ds.data_vars.keys() if k.startswith('config/')])
216+ >>> print(
217+ ... "Summary:",
218+ ... [
219+ ... k
220+ ... for k in ds.data_vars.keys()
221+ ... if not k.startswith(("paths/", "config/", "diagnostics/"))
222+ ... ],
223+ ... )
224+ >>> print("Per-path:", [k for k in ds.data_vars.keys() if k.startswith("paths/")])
225+ >>> print("Config:", [k for k in ds.data_vars.keys() if k.startswith("config/")])
220226 """
221227 n_params = result .samples .shape [- 1 ] if result .samples is not None else None
222228 param_coords = get_param_coords (model , n_params ) if n_params is not None else None
@@ -477,13 +483,16 @@ def add_pathfinder_to_inference_data(
477483 >>> with pm.Model() as model:
478484 ... x = pm.Normal("x", 0, 1)
479485 ... idata = pmx.fit(method="pathfinder", model=model, add_pathfinder_groups=False)
480- ...
481486 >>> # Assuming we have pathfinder results
482487 >>> idata = add_pathfinder_to_inference_data(idata, results, model=model)
483488 >>> print(list(idata.groups())) # Will show ['posterior', 'pathfinder']
484489 >>> # Access nested data:
485- >>> print([k for k in idata.pathfinder.data_vars.keys() if k.startswith('paths/')]) # Per-path data
486- >>> print([k for k in idata.pathfinder.data_vars.keys() if k.startswith('config/')]) # Config data
490+ >>> print(
491+ ... [k for k in idata.pathfinder.data_vars.keys() if k.startswith("paths/")]
492+ ... ) # Per-path data
493+ >>> print(
494+ ... [k for k in idata.pathfinder.data_vars.keys() if k.startswith("config/")]
495+ ... ) # Config data
487496 """
488497 # Detect if this is a multi-path result
489498 # Use isinstance() as primary check, but fall back to duck typing for compatibility
0 commit comments