@@ -405,17 +405,8 @@ def plot_posterior_predictive(
405405 fig = ax .figure
406406
407407 for hdi_prob , alpha in zip ((0.94 , 0.50 ), (0.2 , 0.4 ), strict = True ):
408- likelihood_hdi : DataArray = az .hdi (
409- ary = posterior_predictive_data , hdi_prob = hdi_prob
410- )[self .output_var ]
411-
412- ax .fill_between (
413- x = posterior_predictive_data .date ,
414- y1 = likelihood_hdi [:, 0 ],
415- y2 = likelihood_hdi [:, 1 ],
416- color = "C0" ,
417- alpha = alpha ,
418- label = f"{ hdi_prob :.0%} HDI" ,
408+ ax = self ._add_hdi_to_plot (
409+ ax = ax , original_scale = original_scale , hdi_prob = hdi_prob , alpha = alpha
419410 )
420411
421412 if add_mean :
@@ -477,6 +468,35 @@ def _add_mean_to_plot(
477468 )
478469 return ax
479470
471+ def _add_hdi_to_plot (
472+ self ,
473+ ax : plt .Axes ,
474+ original_scale : bool = False ,
475+ hdi_prob : float = 0.94 ,
476+ color : str = "C0" ,
477+ alpha : float = 0.2 ,
478+ ** kwargs ,
479+ ) -> plt .Axes :
480+ """Add HDI to existing plot."""
481+ posterior_predictive_data : Dataset = self ._get_posterior_predictive_data (
482+ original_scale = original_scale
483+ )
484+
485+ likelihood_hdi : DataArray = az .hdi (
486+ ary = posterior_predictive_data , hdi_prob = hdi_prob
487+ )[self .output_var ]
488+
489+ ax .fill_between (
490+ x = posterior_predictive_data .date ,
491+ y1 = likelihood_hdi [:, 0 ],
492+ y2 = likelihood_hdi [:, 1 ],
493+ color = color ,
494+ alpha = alpha ,
495+ label = f"{ hdi_prob :.0%} HDI" ,
496+ ** kwargs ,
497+ )
498+ return ax
499+
480500 def get_errors (self , original_scale : bool = False ) -> DataArray :
481501 """Get model errors posterior distribution.
482502
0 commit comments