Skip to content

Commit 6c1bb08

Browse files
clsandoval and claude committed
Add research for issue #2054
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent f1b929f commit 6c1bb08

File tree

4 files changed

+1400
-0
lines changed

4 files changed

+1400
-0
lines changed

pymc_marketing/mmm/plot.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,96 @@ def _add_median_and_hdi(
322322
ax.fill_between(dates, hdi[var][..., 0], hdi[var][..., 1], alpha=0.2)
323323
return ax
324324

325+
def _add_gradient_to_axes(
    self,
    ax: Axes,
    data: xr.DataArray,
    n_percentiles: int = 30,
    palette: str = "Blues",
    **kwargs,
) -> Axes:
    """Add a gradient representation of the distribution to the axes.

    Draws a stack of adjacent percentile bands with ``ax.fill_between``.
    ``n_percentiles`` evenly spaced percentiles between the 3rd and the 97th
    are computed along the 'sample' dimension, and each pair of neighbouring
    percentiles delimits one band (``n_percentiles - 1`` bands in total).
    Bands closer to the middle of the stack get a higher colormap value and
    a higher opacity; both ramps are symmetric around the central band.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes object to add the gradient to.
    data : xarray.DataArray
        The data array containing samples. Must have a 'sample' dimension
        and a dimension with coordinate values (typically 'date'). The first
        non-'sample' dimension supplies the x values.
    n_percentiles : int, optional
        Number of percentile boundaries used for the gradient. More
        percentiles create a smoother gradient but increase rendering time.
        Fewer than two boundaries define no band, so nothing is drawn.
        Default is 30.
    palette : str, optional
        Name of the matplotlib colormap to use. Default is "Blues".
    **kwargs
        Additional keyword arguments passed to ax.fill_between().
        ``color`` and ``alpha`` are set per band by this method.

    Returns
    -------
    matplotlib.axes.Axes
        The axes object with the gradient added.

    Raises
    ------
    ValueError
        If data does not have a 'sample' dimension or lacks coordinate dimensions.
    """
    # Validate data has required dimensions
    if "sample" not in data.dims:
        raise ValueError(
            "Data must have a 'sample' dimension for gradient plotting."
        )

    # Find the coordinate dimension (typically 'date')
    coord_dims = [d for d in data.dims if d != "sample"]
    if not coord_dims:
        raise ValueError(
            "Data must have at least one coordinate dimension besides 'sample'."
        )
    # NOTE(review): only the first coordinate dimension is used for x; data with
    # further non-'sample' dimensions is presumably reduced by the caller.
    x_values = data.coords[coord_dims[0]].values

    # Set up color map and ranges
    cmap = plt.get_cmap(palette)
    color_range = np.linspace(0.3, 1.0, n_percentiles // 2)
    percentile_ranges = np.linspace(3, 97, n_percentiles)

    n_bands = len(percentile_ranges) - 1
    if n_bands < 1:
        # A band needs two boundaries; nothing to draw.
        return ax

    # Compute every band boundary in one vectorised pass instead of calling
    # np.percentile twice per band. The result's leading axis indexes the
    # requested percentiles.
    boundaries = np.percentile(
        data.values, percentile_ranges, axis=data.dims.index("sample")
    )

    for i in range(n_bands):
        # Mirror the band index around the centre so that the middle bands
        # receive the darkest colours.
        color_val = color_range[min(i, n_bands - 1 - i)]

        # Opacity peaks at the centre of the stack. Using the band midpoint
        # (i + 0.5) over the number of bands keeps the ramp symmetric, so the
        # lowest and highest bands are equally transparent.
        alpha_val = 0.2 + 0.8 * (1 - abs(2 * (i + 0.5) / n_bands - 1))

        ax.fill_between(
            x=x_values,
            y1=boundaries[i],
            y2=boundaries[i + 1],
            color=cmap(color_val),
            alpha=alpha_val,
            **kwargs,
        )

    return ax
414+
325415
def _validate_dims(
326416
self,
327417
dims: dict[str, str | int | list],
@@ -377,6 +467,9 @@ def posterior_predictive(
377467
var: list[str] | None = None,
378468
idata: xr.Dataset | None = None,
379469
hdi_prob: float = 0.85,
470+
add_gradient: bool = False,
471+
n_percentiles: int = 30,
472+
palette: str = "Blues",
380473
) -> tuple[Figure, NDArray[Axes]]:
381474
"""Plot time series from the posterior predictive distribution.
382475
@@ -392,6 +485,18 @@ def posterior_predictive(
392485
use `self.idata.posterior_predictive`.
393486
hdi_prob: float, optional
394487
The probability mass of the highest density interval to be displayed. Default is 0.85.
488+
add_gradient : bool, optional
489+
If True, add a gradient representation of the full distribution
490+
as a background layer. The gradient shows distribution density
491+
with color intensity. Default is False.
492+
n_percentiles : int, optional
493+
Number of percentile ranges to use for the gradient visualization.
494+
Only used when add_gradient=True. More percentiles create smoother
495+
gradients but increase rendering time. Default is 30.
496+
palette : str, optional
497+
Matplotlib colormap name for the gradient visualization.
498+
Only used when add_gradient=True. Common options: "Blues", "Reds",
499+
"Greens", "viridis", "plasma". Default is "Blues".
395500
396501
Returns
397502
-------
@@ -406,6 +511,38 @@ def posterior_predictive(
406511
If no `idata` is provided and `self.idata.posterior_predictive` does
407512
not exist, instructing the user to run `MMM.sample_posterior_predictive()`.
408513
If `hdi_prob` is not between 0 and 1, instructing the user to provide a valid value.
514+
515+
Examples
516+
--------
517+
Basic usage with gradient:
518+
519+
>>> fig, axes = mmm.plot.posterior_predictive(add_gradient=True)
520+
521+
Customize gradient appearance:
522+
523+
>>> fig, axes = mmm.plot.posterior_predictive(
524+
... add_gradient=True, n_percentiles=40, palette="viridis", hdi_prob=0.90
525+
... )
526+
527+
Combine gradient with HDI bands:
528+
529+
>>> fig, axes = mmm.plot.posterior_predictive(add_gradient=True, hdi_prob=0.85)
530+
531+
The gradient visualization shows distribution density where darker/more
532+
opaque colors indicate higher probability density (near the median) and
533+
lighter/more transparent colors indicate lower density (in the tails).
534+
535+
Notes
536+
-----
537+
The gradient visualization uses a layered percentile approach where multiple
538+
percentile ranges are drawn as semi-transparent fills. The default uses 30
539+
percentile ranges from the 3rd to 97th percentile, creating a smooth gradient
540+
effect. Performance considerations:
541+
542+
- More percentiles (higher n_percentiles) create smoother gradients but increase
543+
rendering time, especially with many subplots
544+
- The gradient is drawn as a background layer, with median and HDI overlaid on top
545+
- For multi-dimensional models, gradients are drawn independently for each subplot
409546
"""
410547
if not 0 < hdi_prob < 1:
411548
raise ValueError("HDI probability must be between 0 and 1.")
@@ -447,6 +584,17 @@ def posterior_predictive(
447584
data = pp_data[v].sel(**indexers)
448585
# Sum leftover dims, stack chain+draw if needed
449586
data = self._reduce_and_stack(data, ignored_dims)
587+
588+
# Add gradient visualization if requested (background layer)
589+
if add_gradient:
590+
ax = self._add_gradient_to_axes(
591+
ax=ax,
592+
data=data,
593+
n_percentiles=n_percentiles,
594+
palette=palette,
595+
)
596+
597+
# Add median and HDI (foreground layer)
450598
ax = self._add_median_and_hdi(ax, data, v, hdi_prob=hdi_prob)
451599

452600
# 7. Subplot title & labels

tests/mmm/test_plot.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,65 @@ def test_posterior_predictive(fit_mmm_with_channel_original_scale, df):
194194
assert all(isinstance(a, Axes) for a in ax.flat)
195195

196196

197+
def test_posterior_predictive_with_gradient(fit_mmm_with_channel_original_scale, df):
    """Test posterior_predictive with gradient visualization.

    Checks the returned figure/axes types and that every axes carries at
    least one filled region after plotting with ``add_gradient=True``.
    """
    fit_mmm_with_channel_original_scale.sample_posterior_predictive(
        df.drop(columns=["y"])
    )
    fig, ax = fit_mmm_with_channel_original_scale.plot.posterior_predictive(
        add_gradient=True,
        hdi_prob=0.95,
    )
    assert isinstance(fig, Figure)
    assert isinstance(ax, np.ndarray)
    assert all(isinstance(a, Axes) for a in ax.flat)
    # fill_between registers its artists in Axes.collections (as poly
    # collections), not in Axes.patches, so the fills must be looked up there.
    for a in ax.flat:
        fills = [c for c in a.collections if hasattr(c, "get_paths")]
        assert len(fills) > 0, "Expected gradient patches on axes"
213+
214+
215+
def test_posterior_predictive_gradient_parameters(
    fit_mmm_with_channel_original_scale, df
):
    """Test gradient with custom parameters.

    Plots once with a non-default ``n_percentiles`` and once with a
    non-default ``palette``; each call must return a Figure.
    """
    mmm = fit_mmm_with_channel_original_scale
    mmm.sample_posterior_predictive(df.drop(columns=["y"]))

    # One plotting call per customised option, in the same order as before:
    # first the band count, then the colormap.
    for overrides in ({"n_percentiles": 20}, {"palette": "Reds"}):
        figure, _ = mmm.plot.posterior_predictive(add_gradient=True, **overrides)
        assert isinstance(figure, Figure)
235+
236+
237+
def test_posterior_predictive_gradient_with_hdi(
    fit_mmm_with_channel_original_scale, df
):
    """Test that gradient and HDI can be displayed together.

    With ``add_gradient=True`` each axes should hold more than one filled
    region: the gradient bands plus the HDI band.
    """
    fit_mmm_with_channel_original_scale.sample_posterior_predictive(
        df.drop(columns=["y"])
    )
    fig, ax = fit_mmm_with_channel_original_scale.plot.posterior_predictive(
        add_gradient=True,
        hdi_prob=0.85,
    )
    assert isinstance(fig, Figure)
    # Verify both gradient fills and the HDI fill exist. fill_between stores
    # its artists in Axes.collections, not Axes.patches.
    for a in ax.flat:
        fills = [c for c in a.collections if hasattr(c, "get_paths")]
        assert len(fills) > 1, "Expected both gradient and HDI patches"
254+
255+
197256
@pytest.fixture(scope="module")
198257
def mock_idata() -> az.InferenceData:
199258
seed = sum(map(ord, "Fake posterior"))

0 commit comments

Comments
 (0)