Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions pymc_marketing/mmm/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,96 @@ def _add_median_and_hdi(
ax.fill_between(dates, hdi[var][..., 0], hdi[var][..., 1], alpha=0.2)
return ax

def _add_gradient_to_axes(
    self,
    ax: Axes,
    data: xr.DataArray,
    n_percentiles: int = 30,
    palette: str = "Blues",
    **kwargs,
) -> Axes:
    """Add a gradient representation of the distribution to the axes.

    Draws ``n_percentiles - 1`` adjacent bands, each spanning two consecutive
    percentile levels between the 3rd and the 97th percentile of the samples.
    Bands closer to the median are drawn with a darker colormap value and a
    higher opacity, so the stacked fills read as a density gradient.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes object to add the gradient to.
    data : xarray.DataArray
        The data array containing samples. Must have a 'sample' dimension
        and at least one other dimension; the coordinate of the first
        non-'sample' dimension (typically 'date') is used for the x axis.
    n_percentiles : int, optional
        Number of percentile levels used to delimit the bands. Values of
        0 or 1 delimit no band, in which case nothing is drawn. More levels
        create a smoother gradient but add more artists. Default is 30.
    palette : str, optional
        Name of the matplotlib colormap to use. Default is "Blues".
    **kwargs
        Additional keyword arguments passed to ax.fill_between().

    Returns
    -------
    matplotlib.axes.Axes
        The axes object with the gradient added.

    Raises
    ------
    ValueError
        If data does not have a 'sample' dimension or lacks coordinate dimensions.
    """
    if "sample" not in data.dims:
        raise ValueError(
            "Data must have a 'sample' dimension for gradient plotting."
        )

    # The first non-sample dimension supplies the x axis (typically 'date').
    coord_dims = [d for d in data.dims if d != "sample"]
    if not coord_dims:
        raise ValueError(
            "Data must have at least one coordinate dimension besides 'sample'."
        )
    x_values = data.coords[coord_dims[0]].values

    cmap = plt.get_cmap(palette)
    color_range = np.linspace(0.3, 1.0, n_percentiles // 2)
    percentile_levels = np.linspace(3, 97, n_percentiles)

    n_bands = len(percentile_levels) - 1
    if n_bands < 1:
        # Fewer than two levels delimit no band: nothing to draw.
        return ax

    # Compute every band boundary in one pass. The first axis of the result
    # indexes the percentile level, so consecutive rows bound one band.
    boundaries = np.percentile(
        data.values, percentile_levels, axis=data.dims.index("sample")
    )

    for i, (lower, upper) in enumerate(zip(boundaries[:-1], boundaries[1:])):
        # Distance (in bands) to the nearest tail selects the color, so the
        # bands mirror each other around the median.
        color_val = color_range[min(i, n_bands - 1 - i)]

        # Opacity peaks at the central band. It is measured at the band's
        # midpoint relative to the number of bands actually drawn, which
        # keeps the lower and upper tails equally transparent.
        centrality = 1 - abs(2 * (i + 0.5) / n_bands - 1)
        alpha_val = 0.2 + 0.8 * centrality

        ax.fill_between(
            x=x_values,
            y1=lower,
            y2=upper,
            color=cmap(color_val),
            alpha=alpha_val,
            **kwargs,
        )

    return ax

def _validate_dims(
self,
dims: dict[str, str | int | list],
Expand Down Expand Up @@ -377,6 +467,9 @@ def posterior_predictive(
var: list[str] | None = None,
idata: xr.Dataset | None = None,
hdi_prob: float = 0.85,
add_gradient: bool = False,
n_percentiles: int = 30,
palette: str = "Blues",
) -> tuple[Figure, NDArray[Axes]]:
"""Plot time series from the posterior predictive distribution.

Expand All @@ -392,6 +485,18 @@ def posterior_predictive(
use `self.idata.posterior_predictive`.
hdi_prob: float, optional
The probability mass of the highest density interval to be displayed. Default is 0.85.
add_gradient : bool, optional
If True, add a gradient representation of the full distribution
as a background layer. The gradient shows distribution density
with color intensity. Default is False.
n_percentiles : int, optional
Number of percentile ranges to use for the gradient visualization.
Only used when add_gradient=True. More percentiles create smoother
gradients but increase rendering time. Default is 30.
palette : str, optional
Matplotlib colormap name for the gradient visualization.
Only used when add_gradient=True. Common options: "Blues", "Reds",
"Greens", "viridis", "plasma". Default is "Blues".

Returns
-------
Expand All @@ -406,6 +511,38 @@ def posterior_predictive(
If no `idata` is provided and `self.idata.posterior_predictive` does
not exist, instructing the user to run `MMM.sample_posterior_predictive()`.
If `hdi_prob` is not between 0 and 1, instructing the user to provide a valid value.

Examples
--------
Basic usage with gradient:

>>> fig, axes = mmm.plot.posterior_predictive(add_gradient=True)

Customize gradient appearance:

>>> fig, axes = mmm.plot.posterior_predictive(
... add_gradient=True, n_percentiles=40, palette="viridis", hdi_prob=0.90
... )

Combine gradient with HDI bands:

>>> fig, axes = mmm.plot.posterior_predictive(add_gradient=True, hdi_prob=0.85)

The gradient visualization shows distribution density where darker/more
opaque colors indicate higher probability density (near the median) and
lighter/more transparent colors indicate lower density (in the tails).

Notes
-----
The gradient visualization uses a layered percentile approach where multiple
percentile ranges are drawn as semi-transparent fills. The default uses 30
percentile ranges from the 3rd to 97th percentile, creating a smooth gradient
effect. Performance considerations:

- More percentiles (higher n_percentiles) create smoother gradients but increase
rendering time, especially with many subplots
- The gradient is drawn as a background layer, with median and HDI overlaid on top
- For multi-dimensional models, gradients are drawn independently for each subplot
"""
if not 0 < hdi_prob < 1:
raise ValueError("HDI probability must be between 0 and 1.")
Expand Down Expand Up @@ -447,6 +584,17 @@ def posterior_predictive(
data = pp_data[v].sel(**indexers)
# Sum leftover dims, stack chain+draw if needed
data = self._reduce_and_stack(data, ignored_dims)

# Add gradient visualization if requested (background layer)
if add_gradient:
ax = self._add_gradient_to_axes(
ax=ax,
data=data,
n_percentiles=n_percentiles,
palette=palette,
)

# Add median and HDI (foreground layer)
ax = self._add_median_and_hdi(ax, data, v, hdi_prob=hdi_prob)

# 7. Subplot title & labels
Expand Down
59 changes: 59 additions & 0 deletions tests/mmm/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,65 @@ def test_posterior_predictive(fit_mmm_with_channel_original_scale, df):
assert all(isinstance(a, Axes) for a in ax.flat)


def test_posterior_predictive_with_gradient(fit_mmm_with_channel_original_scale, df):
    """Test posterior_predictive with gradient visualization.

    Checks the returned figure/axes types and that every axes carries the
    filled bands produced by the gradient layer.
    """
    fit_mmm_with_channel_original_scale.sample_posterior_predictive(
        df.drop(columns=["y"])
    )
    fig, ax = fit_mmm_with_channel_original_scale.plot.posterior_predictive(
        add_gradient=True,
        hdi_prob=0.95,
    )
    assert isinstance(fig, Figure)
    assert isinstance(ax, np.ndarray)
    assert all(isinstance(a, Axes) for a in ax.flat)

    # The default n_percentiles=30 delimits 29 adjacent gradient bands.
    n_gradient_bands = 29
    for a in ax.flat:
        # fill_between registers its artist in ax.collections (a
        # PolyCollection), not in ax.patches, so that is where to look.
        fills = [c for c in a.collections if hasattr(c, "get_paths")]
        assert len(fills) >= n_gradient_bands, "Expected gradient patches on axes"


def test_posterior_predictive_gradient_parameters(
    fit_mmm_with_channel_original_scale, df
):
    """Check that the gradient accepts non-default tuning options.

    Each option set is plotted in turn (a custom ``n_percentiles`` first,
    then a custom ``palette``) and must yield a matplotlib Figure.
    """
    fit_mmm_with_channel_original_scale.sample_posterior_predictive(
        df.drop(columns=["y"])
    )

    option_sets = (
        {"n_percentiles": 20},  # coarser gradient
        {"palette": "Reds"},  # alternative colormap
    )
    for gradient_options in option_sets:
        figure, _ = fit_mmm_with_channel_original_scale.plot.posterior_predictive(
            add_gradient=True,
            **gradient_options,
        )
        assert isinstance(figure, Figure)


def test_posterior_predictive_gradient_with_hdi(
    fit_mmm_with_channel_original_scale, df
):
    """Test that gradient and HDI can be displayed together.

    The gradient contributes one fill per band and the HDI layer adds its
    own fill on top, so each axes must hold more fills than the gradient
    alone would produce.
    """
    fit_mmm_with_channel_original_scale.sample_posterior_predictive(
        df.drop(columns=["y"])
    )
    fig, ax = fit_mmm_with_channel_original_scale.plot.posterior_predictive(
        add_gradient=True,
        hdi_prob=0.85,
    )
    assert isinstance(fig, Figure)

    # The default n_percentiles=30 delimits 29 adjacent gradient bands.
    n_gradient_bands = 29
    for a in ax.flat:
        # fill_between registers its artist in ax.collections (a
        # PolyCollection), not in ax.patches, so that is where to look.
        fills = [c for c in a.collections if hasattr(c, "get_paths")]
        assert len(fills) > n_gradient_bands, "Expected both gradient and HDI patches"


@pytest.fixture(scope="module")
def mock_idata() -> az.InferenceData:
seed = sum(map(ord, "Fake posterior"))
Expand Down
Loading
Loading