@@ -322,6 +322,92 @@ def _add_median_and_hdi(
322322 ax .fill_between (dates , hdi [var ][..., 0 ], hdi [var ][..., 1 ], alpha = 0.2 )
323323 return ax
324324
325+ def _add_gradient_to_axes (
326+ self ,
327+ ax : Axes ,
328+ data : xr .DataArray ,
329+ n_percentiles : int = 30 ,
330+ palette : str = "Blues" ,
331+ ** kwargs ,
332+ ) -> Axes :
333+ """Add a gradient representation of the distribution to the axes.
334+
335+ Creates a shaded area plot where color intensity represents
336+ the density of the distribution. Uses layered percentile ranges
337+ with varying opacity to create a smooth gradient effect.
338+
339+ Parameters
340+ ----------
341+ ax : matplotlib.axes.Axes
342+ The axes object to add the gradient to.
343+ data : xarray.DataArray
344+ The data array containing samples. Must have a 'sample' dimension
345+ and a dimension with coordinate values (typically 'date').
346+ n_percentiles : int, optional
347+ Number of percentile ranges to use for the gradient. More percentiles
348+ create a smoother gradient but increase rendering time. Default is 30.
349+ palette : str, optional
350+ Name of the matplotlib colormap to use. Default is "Blues".
351+ **kwargs
352+ Additional keyword arguments passed to ax.fill_between().
353+
354+ Returns
355+ -------
356+ matplotlib.axes.Axes
357+ The axes object with the gradient added.
358+
359+ Raises
360+ ------
361+ ValueError
362+ If data does not have a 'sample' dimension or lacks coordinate dimensions.
363+ """
364+ # Validate data has required dimensions
365+ if "sample" not in data .dims :
366+ raise ValueError ("Data must have a 'sample' dimension for gradient plotting." )
367+
368+ # Find the coordinate dimension (typically 'date')
369+ coord_dims = [d for d in data .dims if d != "sample" ]
370+ if not coord_dims :
371+ raise ValueError ("Data must have at least one coordinate dimension besides 'sample'." )
372+ coord_dim = coord_dims [0 ] # Use first coordinate dimension
373+ x_values = data .coords [coord_dim ].values
374+
375+ # Set up color map and ranges
376+ cmap = plt .get_cmap (palette )
377+ color_range = np .linspace (0.3 , 1.0 , n_percentiles // 2 )
378+ percentile_ranges = np .linspace (3 , 97 , n_percentiles )
379+
380+ # Create gradient by filling between percentile ranges
381+ for i in range (len (percentile_ranges ) - 1 ):
382+ # Compute percentiles along the sample dimension
383+ lower_percentile = np .percentile (
384+ data .values , percentile_ranges [i ], axis = data .dims .index ("sample" )
385+ )
386+ upper_percentile = np .percentile (
387+ data .values , percentile_ranges [i + 1 ], axis = data .dims .index ("sample" )
388+ )
389+
390+ # Map percentile index to color intensity
391+ # Middle percentiles get darker colors and higher alpha
392+ if i < n_percentiles // 2 :
393+ color_val = color_range [i ]
394+ else :
395+ color_val = color_range [n_percentiles - i - 2 ]
396+
397+ # Alpha increases toward middle (50th percentile)
398+ alpha_val = 0.2 + 0.8 * (1 - abs (2 * i / n_percentiles - 1 ))
399+
400+ ax .fill_between (
401+ x = x_values ,
402+ y1 = lower_percentile ,
403+ y2 = upper_percentile ,
404+ color = cmap (color_val ),
405+ alpha = alpha_val ,
406+ ** kwargs ,
407+ )
408+
409+ return ax
410+
325411 def _validate_dims (
326412 self ,
327413 dims : dict [str , str | int | list ],
@@ -377,6 +463,9 @@ def posterior_predictive(
377463 var : list [str ] | None = None ,
378464 idata : xr .Dataset | None = None ,
379465 hdi_prob : float = 0.85 ,
466+ add_gradient : bool = False ,
467+ n_percentiles : int = 30 ,
468+ palette : str = "Blues" ,
380469 ) -> tuple [Figure , NDArray [Axes ]]:
381470 """Plot time series from the posterior predictive distribution.
382471
@@ -392,6 +481,18 @@ def posterior_predictive(
392481 use `self.idata.posterior_predictive`.
393482 hdi_prob: float, optional
394483 The probability mass of the highest density interval to be displayed. Default is 0.85.
484+ add_gradient : bool, optional
485+ If True, add a gradient representation of the full distribution
486+ as a background layer. The gradient shows distribution density
487+ with color intensity. Default is False.
488+ n_percentiles : int, optional
489+ Number of percentile ranges to use for the gradient visualization.
490+ Only used when add_gradient=True. More percentiles create smoother
491+ gradients but increase rendering time. Default is 30.
492+ palette : str, optional
493+ Matplotlib colormap name for the gradient visualization.
494+ Only used when add_gradient=True. Common options: "Blues", "Reds",
495+ "Greens", "viridis", "plasma". Default is "Blues".
395496
396497 Returns
397498 -------
@@ -406,6 +507,44 @@ def posterior_predictive(
406507 If no `idata` is provided and `self.idata.posterior_predictive` does
407508 not exist, instructing the user to run `MMM.sample_posterior_predictive()`.
408509 If `hdi_prob` is not between 0 and 1, instructing the user to provide a valid value.
510+
511+ Examples
512+ --------
513+ Basic usage with gradient:
514+
515+ >>> fig, axes = mmm.plot.posterior_predictive(add_gradient=True)
516+
517+ Customize gradient appearance:
518+
519+ >>> fig, axes = mmm.plot.posterior_predictive(
520+ ... add_gradient=True,
521+ ... n_percentiles=40,
522+ ... palette="viridis",
523+ ... hdi_prob=0.90
524+ ... )
525+
526+ Combine gradient with HDI bands:
527+
528+ >>> fig, axes = mmm.plot.posterior_predictive(
529+ ... add_gradient=True,
530+ ... hdi_prob=0.85
531+ ... )
532+
533+ The gradient visualization shows distribution density where darker/more
534+ opaque colors indicate higher probability density (near the median) and
535+ lighter/more transparent colors indicate lower density (in the tails).
536+
537+ Notes
538+ -----
539+ The gradient visualization uses a layered percentile approach where multiple
540+ percentile ranges are drawn as semi-transparent fills. The default uses 30
541+ percentile ranges from the 3rd to 97th percentile, creating a smooth gradient
542+ effect. Performance considerations:
543+
544+ - More percentiles (higher n_percentiles) create smoother gradients but increase
545+ rendering time, especially with many subplots
546+ - The gradient is drawn as a background layer, with median and HDI overlaid on top
547+ - For multi-dimensional models, gradients are drawn independently for each subplot
409548 """
410549 if not 0 < hdi_prob < 1 :
411550 raise ValueError ("HDI probability must be between 0 and 1." )
@@ -447,6 +586,17 @@ def posterior_predictive(
447586 data = pp_data [v ].sel (** indexers )
448587 # Sum leftover dims, stack chain+draw if needed
449588 data = self ._reduce_and_stack (data , ignored_dims )
589+
590+ # Add gradient visualization if requested (background layer)
591+ if add_gradient :
592+ ax = self ._add_gradient_to_axes (
593+ ax = ax ,
594+ data = data ,
595+ n_percentiles = n_percentiles ,
596+ palette = palette ,
597+ )
598+
599+ # Add median and HDI (foreground layer)
450600 ax = self ._add_median_and_hdi (ax , data , v , hdi_prob = hdi_prob )
451601
452602 # 7. Subplot title & labels
0 commit comments