@@ -362,7 +362,9 @@ def plot_prior_predictive(self, **plt_kwargs: Any) -> plt.Figure:
362362 def plot_posterior_predictive (
363363 self ,
364364 original_scale : bool = False ,
365+ add_hdi : bool = True ,
365366 add_mean : bool = True ,
367+ add_gradient : bool = False ,
366368 ax : plt .Axes = None ,
367369 ** plt_kwargs : Any ,
368370 ) -> plt .Figure :
@@ -404,16 +406,22 @@ def plot_posterior_predictive(
404406 else :
405407 fig = ax .figure
406408
407- for hdi_prob , alpha in zip ((0.94 , 0.50 ), (0.2 , 0.4 ), strict = True ):
408- ax = self ._add_hdi_to_plot (
409- ax = ax , original_scale = original_scale , hdi_prob = hdi_prob , alpha = alpha
410- )
409+ if add_hdi :
410+ for hdi_prob , alpha in zip ((0.94 , 0.50 ), (0.2 , 0.4 ), strict = True ):
411+ ax = self ._add_hdi_to_plot (
412+ ax = ax , original_scale = original_scale , hdi_prob = hdi_prob , alpha = alpha
413+ )
411414
412415 if add_mean :
413416 ax = self ._add_mean_to_plot (
414417 ax = ax , original_scale = original_scale , color = "red"
415418 )
416419
420+ if add_gradient :
421+ ax = self ._add_gradient_to_plot (
422+ ax = ax , original_scale = original_scale , n_percentiles = 30 , palette = "Blues"
423+ )
424+
417425 ax .plot (
418426 np .asarray (posterior_predictive_data .date ),
419427 target_to_plot ,
@@ -497,6 +505,78 @@ def _add_hdi_to_plot(
497505 )
498506 return ax
499507
508+ def _add_gradient_to_plot (
509+ self ,
510+ ax : plt .Axes ,
511+ original_scale : bool = False ,
512+ n_percentiles : int = 30 ,
513+ palette : str = "Blues" ,
514+ ** kwargs ,
515+ ) -> plt .Axes :
516+ """
517+ Add a gradient representation of the posterior predictive distribution to an existing plot.
518+
519+ This method creates a shaded area plot where the color intensity represents
520+ the density of the posterior predictive distribution.
521+
522+ Parameters
523+ ----------
524+ ax : plt.Axes
525+ The matplotlib axes object to add the gradient to.
526+ original_scale : bool, optional
527+ If True, use the original scale of the data. Default is False.
528+ n_percentiles : int, optional
529+ Number of percentile ranges to use for the gradient. Default is 30.
530+ palette : str, optional
531+ Color palette to use for the gradient. Default is "Blues".
532+ **kwargs
533+ Additional keyword arguments passed to ax.fill_between().
534+
535+ Returns
536+ -------
537+ plt.Axes
538+ The matplotlib axes object with the gradient added.
539+ """
540+ # Get posterior predictive data and flatten it
541+ posterior_predictive = self ._get_posterior_predictive_data (
542+ original_scale = original_scale
543+ )
544+ posterior_predictive_flattened = posterior_predictive .stack (
545+ sample = ("chain" , "draw" )
546+ ).to_dataarray ()
547+ dates = posterior_predictive .date .values
548+
549+ # Set up color map and ranges
550+ cmap = plt .get_cmap (palette )
551+ color_range = np .linspace (0.3 , 1.0 , n_percentiles // 2 )
552+ percentile_ranges = np .linspace (3 , 97 , n_percentiles )
553+
554+ # Create gradient by filling between percentile ranges
555+ for i in range (len (percentile_ranges ) - 1 ):
556+ lower_percentile = np .percentile (
557+ posterior_predictive_flattened , percentile_ranges [i ], axis = 2
558+ ).squeeze ()
559+ upper_percentile = np .percentile (
560+ posterior_predictive_flattened , percentile_ranges [i + 1 ], axis = 2
561+ ).squeeze ()
562+ if i < n_percentiles // 2 :
563+ color_val = color_range [i ]
564+ else :
565+ color_val = color_range [n_percentiles - i - 2 ]
566+ alpha_val = 0.2 + 0.8 * (
567+ 1 - abs (2 * i / n_percentiles - 1 )
568+ ) # Higher alpha in the middle
569+ ax .fill_between (
570+ x = dates ,
571+ y1 = lower_percentile ,
572+ y2 = upper_percentile ,
573+ color = cmap (color_val ),
574+ alpha = alpha_val ,
575+ ** kwargs ,
576+ )
577+
578+ return ax
579+
500580 def get_errors (self , original_scale : bool = False ) -> DataArray :
501581 """Get model errors posterior distribution.
502582
0 commit comments