@@ -174,6 +174,7 @@ def __init__(
174174 prior = None ,
175175 posterior_predictive = None ,
176176 log_likelihood = False ,
177+ log_prior = False ,
177178 predictions = None ,
178179 coords : Optional [CoordSpec ] = None ,
179180 dims : Optional [DimSpec ] = None ,
@@ -215,6 +216,7 @@ def __init__(
215216 self .prior = prior
216217 self .posterior_predictive = posterior_predictive
217218 self .log_likelihood = log_likelihood
219+ self .log_prior = log_prior
218220 self .predictions = predictions
219221
220222 if all (elem is None for elem in (trace , predictions , posterior_predictive , prior )):
@@ -446,6 +448,17 @@ def to_inference_data(self):
446448 sample_dims = self .sample_dims ,
447449 progressbar = False ,
448450 )
451+ if self .log_prior :
452+ from pymc .stats .log_density import compute_log_prior
453+
454+ idata = compute_log_prior (
455+ idata ,
456+ var_names = None if self .log_prior is True else self .log_prior ,
457+ extend_inferencedata = True ,
458+ model = self .model ,
459+ sample_dims = self .sample_dims ,
460+ progressbar = False ,
461+ )
449462 return idata
450463
451464
@@ -455,6 +468,7 @@ def to_inference_data(
455468 prior : Optional [Mapping [str , Any ]] = None ,
456469 posterior_predictive : Optional [Mapping [str , Any ]] = None ,
457470 log_likelihood : Union [bool , Iterable [str ]] = False ,
471+ log_prior : Union [bool , Iterable [str ]] = False ,
458472 coords : Optional [CoordSpec ] = None ,
459473 dims : Optional [DimSpec ] = None ,
460474 sample_dims : Optional [list ] = None ,
@@ -481,8 +495,11 @@ def to_inference_data(
481495 Dictionary with the variable names as keys, and values numpy arrays
482496 containing posterior predictive samples.
483497 log_likelihood : bool or array_like of str, optional
484- List of variables to calculate `log_likelihood`. Defaults to True which calculates
485- `log_likelihood` for all observed variables. If set to False, log_likelihood is skipped.
498+ List of variables to calculate `log_likelihood`. Defaults to False.
499+ If set to True, computes `log_likelihood` for all observed variables.
500+ log_prior : bool or array_like of str, optional
501+ List of variables to calculate `log_prior`. Defaults to False.
502+ If set to True, computes `log_prior` for all unobserved variables.
486503 coords : dict of {str: array-like}, optional
487504 Map of coordinate names to coordinate values
488505 dims : dict of {str: list of str}, optional
@@ -509,6 +526,7 @@ def to_inference_data(
509526 prior = prior ,
510527 posterior_predictive = posterior_predictive ,
511528 log_likelihood = log_likelihood ,
529+ log_prior = log_prior ,
512530 coords = coords ,
513531 dims = dims ,
514532 sample_dims = sample_dims ,
0 commit comments