@@ -322,6 +322,96 @@ def _add_median_and_hdi(
322322 ax .fill_between (dates , hdi [var ][..., 0 ], hdi [var ][..., 1 ], alpha = 0.2 )
323323 return ax
324324
325+ def _add_gradient_to_axes (
326+ self ,
327+ ax : Axes ,
328+ data : xr .DataArray ,
329+ n_percentiles : int = 30 ,
330+ palette : str = "Blues" ,
331+ ** kwargs ,
332+ ) -> Axes :
333+ """Add a gradient representation of the distribution to the axes.
334+
335+ Creates a shaded area plot where color intensity represents
336+ the density of the distribution. Uses layered percentile ranges
337+ with varying opacity to create a smooth gradient effect.
338+
339+ Parameters
340+ ----------
341+ ax : matplotlib.axes.Axes
342+ The axes object to add the gradient to.
343+ data : xarray.DataArray
344+ The data array containing samples. Must have a 'sample' dimension
345+ and a dimension with coordinate values (typically 'date').
346+ n_percentiles : int, optional
347+ Number of percentile ranges to use for the gradient. More percentiles
348+ create a smoother gradient but increase rendering time. Default is 30.
349+ palette : str, optional
350+ Name of the matplotlib colormap to use. Default is "Blues".
351+ **kwargs
352+ Additional keyword arguments passed to ax.fill_between().
353+
354+ Returns
355+ -------
356+ matplotlib.axes.Axes
357+ The axes object with the gradient added.
358+
359+ Raises
360+ ------
361+ ValueError
362+ If data does not have a 'sample' dimension or lacks coordinate dimensions.
363+ """
364+ # Validate data has required dimensions
365+ if "sample" not in data .dims :
366+ raise ValueError (
367+ "Data must have a 'sample' dimension for gradient plotting."
368+ )
369+
370+ # Find the coordinate dimension (typically 'date')
371+ coord_dims = [d for d in data .dims if d != "sample" ]
372+ if not coord_dims :
373+ raise ValueError (
374+ "Data must have at least one coordinate dimension besides 'sample'."
375+ )
376+ coord_dim = coord_dims [0 ] # Use first coordinate dimension
377+ x_values = data .coords [coord_dim ].values
378+
379+ # Set up color map and ranges
380+ cmap = plt .get_cmap (palette )
381+ color_range = np .linspace (0.3 , 1.0 , n_percentiles // 2 )
382+ percentile_ranges = np .linspace (3 , 97 , n_percentiles )
383+
384+ # Create gradient by filling between percentile ranges
385+ for i in range (len (percentile_ranges ) - 1 ):
386+ # Compute percentiles along the sample dimension
387+ lower_percentile = np .percentile (
388+ data .values , percentile_ranges [i ], axis = data .dims .index ("sample" )
389+ )
390+ upper_percentile = np .percentile (
391+ data .values , percentile_ranges [i + 1 ], axis = data .dims .index ("sample" )
392+ )
393+
394+ # Map percentile index to color intensity
395+ # Middle percentiles get darker colors and higher alpha
396+ if i < n_percentiles // 2 :
397+ color_val = color_range [i ]
398+ else :
399+ color_val = color_range [n_percentiles - i - 2 ]
400+
401+ # Alpha increases toward middle (50th percentile)
402+ alpha_val = 0.2 + 0.8 * (1 - abs (2 * i / n_percentiles - 1 ))
403+
404+ ax .fill_between (
405+ x = x_values ,
406+ y1 = lower_percentile ,
407+ y2 = upper_percentile ,
408+ color = cmap (color_val ),
409+ alpha = alpha_val ,
410+ ** kwargs ,
411+ )
412+
413+ return ax
414+
325415 def _validate_dims (
326416 self ,
327417 dims : dict [str , str | int | list ],
@@ -377,6 +467,9 @@ def posterior_predictive(
377467 var : list [str ] | None = None ,
378468 idata : xr .Dataset | None = None ,
379469 hdi_prob : float = 0.85 ,
470+ add_gradient : bool = False ,
471+ n_percentiles : int = 30 ,
472+ palette : str = "Blues" ,
380473 ) -> tuple [Figure , NDArray [Axes ]]:
381474 """Plot time series from the posterior predictive distribution.
382475
@@ -392,6 +485,18 @@ def posterior_predictive(
392485 use `self.idata.posterior_predictive`.
393486 hdi_prob: float, optional
394487 The probability mass of the highest density interval to be displayed. Default is 0.85.
488+ add_gradient : bool, optional
489+ If True, add a gradient representation of the full distribution
490+ as a background layer. The gradient shows distribution density
491+ with color intensity. Default is False.
492+ n_percentiles : int, optional
493+ Number of percentile ranges to use for the gradient visualization.
494+ Only used when add_gradient=True. More percentiles create smoother
495+ gradients but increase rendering time. Default is 30.
496+ palette : str, optional
497+ Matplotlib colormap name for the gradient visualization.
498+ Only used when add_gradient=True. Common options: "Blues", "Reds",
499+ "Greens", "viridis", "plasma". Default is "Blues".
395500
396501 Returns
397502 -------
@@ -406,6 +511,38 @@ def posterior_predictive(
406511 If no `idata` is provided and `self.idata.posterior_predictive` does
407512 not exist, instructing the user to run `MMM.sample_posterior_predictive()`.
408513 If `hdi_prob` is not between 0 and 1, instructing the user to provide a valid value.
514+
515+ Examples
516+ --------
517+ Basic usage with gradient:
518+
519+ >>> fig, axes = mmm.plot.posterior_predictive(add_gradient=True)
520+
521+ Customize gradient appearance:
522+
523+ >>> fig, axes = mmm.plot.posterior_predictive(
524+ ... add_gradient=True, n_percentiles=40, palette="viridis", hdi_prob=0.90
525+ ... )
526+
527+ Combine gradient with HDI bands:
528+
529+ >>> fig, axes = mmm.plot.posterior_predictive(add_gradient=True, hdi_prob=0.85)
530+
531+ The gradient visualization shows distribution density where darker/more
532+ opaque colors indicate higher probability density (near the median) and
533+ lighter/more transparent colors indicate lower density (in the tails).
534+
535+ Notes
536+ -----
537+ The gradient visualization uses a layered percentile approach where multiple
538+ percentile ranges are drawn as semi-transparent fills. The default uses 30
539+ percentile ranges from the 3rd to 97th percentile, creating a smooth gradient
540+ effect. Performance considerations:
541+
542+ - More percentiles (higher n_percentiles) create smoother gradients but increase
543+ rendering time, especially with many subplots
544+ - The gradient is drawn as a background layer, with median and HDI overlaid on top
545+ - For multi-dimensional models, gradients are drawn independently for each subplot
409546 """
410547 if not 0 < hdi_prob < 1 :
411548 raise ValueError ("HDI probability must be between 0 and 1." )
@@ -447,6 +584,17 @@ def posterior_predictive(
447584 data = pp_data [v ].sel (** indexers )
448585 # Sum leftover dims, stack chain+draw if needed
449586 data = self ._reduce_and_stack (data , ignored_dims )
587+
588+ # Add gradient visualization if requested (background layer)
589+ if add_gradient :
590+ ax = self ._add_gradient_to_axes (
591+ ax = ax ,
592+ data = data ,
593+ n_percentiles = n_percentiles ,
594+ palette = palette ,
595+ )
596+
597+ # Add median and HDI (foreground layer)
450598 ax = self ._add_median_and_hdi (ax , data , v , hdi_prob = hdi_prob )
451599
452600 # 7. Subplot title & labels
0 commit comments