Skip to content

Commit 99edf15

Browse files
Implement plan for issue #2054
- Implemented changes according to plan Co-authored-by: Claude Code <noreply@anthropic.com>
1 parent 7a57e83 commit 99edf15

File tree

3 files changed

+220
-15
lines changed

3 files changed

+220
-15
lines changed

pymc_marketing/mmm/plot.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,92 @@ def _add_median_and_hdi(
322322
ax.fill_between(dates, hdi[var][..., 0], hdi[var][..., 1], alpha=0.2)
323323
return ax
324324

325+
def _add_gradient_to_axes(
326+
self,
327+
ax: Axes,
328+
data: xr.DataArray,
329+
n_percentiles: int = 30,
330+
palette: str = "Blues",
331+
**kwargs,
332+
) -> Axes:
333+
"""Add a gradient representation of the distribution to the axes.
334+
335+
Creates a shaded area plot where color intensity represents
336+
the density of the distribution. Uses layered percentile ranges
337+
with varying opacity to create a smooth gradient effect.
338+
339+
Parameters
340+
----------
341+
ax : matplotlib.axes.Axes
342+
The axes object to add the gradient to.
343+
data : xarray.DataArray
344+
The data array containing samples. Must have a 'sample' dimension
345+
and a dimension with coordinate values (typically 'date').
346+
n_percentiles : int, optional
347+
Number of percentile ranges to use for the gradient. More percentiles
348+
create a smoother gradient but increase rendering time. Default is 30.
349+
palette : str, optional
350+
Name of the matplotlib colormap to use. Default is "Blues".
351+
**kwargs
352+
Additional keyword arguments passed to ax.fill_between().
353+
354+
Returns
355+
-------
356+
matplotlib.axes.Axes
357+
The axes object with the gradient added.
358+
359+
Raises
360+
------
361+
ValueError
362+
If data does not have a 'sample' dimension or lacks coordinate dimensions.
363+
"""
364+
# Validate data has required dimensions
365+
if "sample" not in data.dims:
366+
raise ValueError("Data must have a 'sample' dimension for gradient plotting.")
367+
368+
# Find the coordinate dimension (typically 'date')
369+
coord_dims = [d for d in data.dims if d != "sample"]
370+
if not coord_dims:
371+
raise ValueError("Data must have at least one coordinate dimension besides 'sample'.")
372+
coord_dim = coord_dims[0] # Use first coordinate dimension
373+
x_values = data.coords[coord_dim].values
374+
375+
# Set up color map and ranges
376+
cmap = plt.get_cmap(palette)
377+
color_range = np.linspace(0.3, 1.0, n_percentiles // 2)
378+
percentile_ranges = np.linspace(3, 97, n_percentiles)
379+
380+
# Create gradient by filling between percentile ranges
381+
for i in range(len(percentile_ranges) - 1):
382+
# Compute percentiles along the sample dimension
383+
lower_percentile = np.percentile(
384+
data.values, percentile_ranges[i], axis=data.dims.index("sample")
385+
)
386+
upper_percentile = np.percentile(
387+
data.values, percentile_ranges[i + 1], axis=data.dims.index("sample")
388+
)
389+
390+
# Map percentile index to color intensity
391+
# Middle percentiles get darker colors and higher alpha
392+
if i < n_percentiles // 2:
393+
color_val = color_range[i]
394+
else:
395+
color_val = color_range[n_percentiles - i - 2]
396+
397+
# Alpha increases toward middle (50th percentile)
398+
alpha_val = 0.2 + 0.8 * (1 - abs(2 * i / n_percentiles - 1))
399+
400+
ax.fill_between(
401+
x=x_values,
402+
y1=lower_percentile,
403+
y2=upper_percentile,
404+
color=cmap(color_val),
405+
alpha=alpha_val,
406+
**kwargs,
407+
)
408+
409+
return ax
410+
325411
def _validate_dims(
326412
self,
327413
dims: dict[str, str | int | list],
@@ -377,6 +463,9 @@ def posterior_predictive(
377463
var: list[str] | None = None,
378464
idata: xr.Dataset | None = None,
379465
hdi_prob: float = 0.85,
466+
add_gradient: bool = False,
467+
n_percentiles: int = 30,
468+
palette: str = "Blues",
380469
) -> tuple[Figure, NDArray[Axes]]:
381470
"""Plot time series from the posterior predictive distribution.
382471
@@ -392,6 +481,18 @@ def posterior_predictive(
392481
use `self.idata.posterior_predictive`.
393482
hdi_prob: float, optional
394483
The probability mass of the highest density interval to be displayed. Default is 0.85.
484+
add_gradient : bool, optional
485+
If True, add a gradient representation of the full distribution
486+
as a background layer. The gradient shows distribution density
487+
with color intensity. Default is False.
488+
n_percentiles : int, optional
489+
Number of percentile ranges to use for the gradient visualization.
490+
Only used when add_gradient=True. More percentiles create smoother
491+
gradients but increase rendering time. Default is 30.
492+
palette : str, optional
493+
Matplotlib colormap name for the gradient visualization.
494+
Only used when add_gradient=True. Common options: "Blues", "Reds",
495+
"Greens", "viridis", "plasma". Default is "Blues".
395496
396497
Returns
397498
-------
@@ -406,6 +507,44 @@ def posterior_predictive(
406507
If no `idata` is provided and `self.idata.posterior_predictive` does
407508
not exist, instructing the user to run `MMM.sample_posterior_predictive()`.
408509
If `hdi_prob` is not between 0 and 1, instructing the user to provide a valid value.
510+
511+
Examples
512+
--------
513+
Basic usage with gradient:
514+
515+
>>> fig, axes = mmm.plot.posterior_predictive(add_gradient=True)
516+
517+
Customize gradient appearance:
518+
519+
>>> fig, axes = mmm.plot.posterior_predictive(
520+
... add_gradient=True,
521+
... n_percentiles=40,
522+
... palette="viridis",
523+
... hdi_prob=0.90
524+
... )
525+
526+
Combine gradient with HDI bands:
527+
528+
>>> fig, axes = mmm.plot.posterior_predictive(
529+
... add_gradient=True,
530+
... hdi_prob=0.85
531+
... )
532+
533+
The gradient visualization shows distribution density where darker/more
534+
opaque colors indicate higher probability density (near the median) and
535+
lighter/more transparent colors indicate lower density (in the tails).
536+
537+
Notes
538+
-----
539+
The gradient visualization uses a layered percentile approach where multiple
540+
percentile ranges are drawn as semi-transparent fills. The default uses 30
541+
percentile ranges from the 3rd to 97th percentile, creating a smooth gradient
542+
effect. Performance considerations:
543+
544+
- More percentiles (higher n_percentiles) create smoother gradients but increase
545+
rendering time, especially with many subplots
546+
- The gradient is drawn as a background layer, with median and HDI overlaid on top
547+
- For multi-dimensional models, gradients are drawn independently for each subplot
409548
"""
410549
if not 0 < hdi_prob < 1:
411550
raise ValueError("HDI probability must be between 0 and 1.")
@@ -447,6 +586,17 @@ def posterior_predictive(
447586
data = pp_data[v].sel(**indexers)
448587
# Sum leftover dims, stack chain+draw if needed
449588
data = self._reduce_and_stack(data, ignored_dims)
589+
590+
# Add gradient visualization if requested (background layer)
591+
if add_gradient:
592+
ax = self._add_gradient_to_axes(
593+
ax=ax,
594+
data=data,
595+
n_percentiles=n_percentiles,
596+
palette=palette,
597+
)
598+
599+
# Add median and HDI (foreground layer)
450600
ax = self._add_median_and_hdi(ax, data, v, hdi_prob=hdi_prob)
451601

452602
# 7. Subplot title & labels

tests/mmm/test_plot.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,61 @@ def test_posterior_predictive(fit_mmm_with_channel_original_scale, df):
194194
assert all(isinstance(a, Axes) for a in ax.flat)
195195

196196

197+
def test_posterior_predictive_with_gradient(fit_mmm_with_channel_original_scale, df):
198+
"""Test posterior_predictive with gradient visualization."""
199+
fit_mmm_with_channel_original_scale.sample_posterior_predictive(
200+
df.drop(columns=["y"])
201+
)
202+
fig, ax = fit_mmm_with_channel_original_scale.plot.posterior_predictive(
203+
add_gradient=True,
204+
hdi_prob=0.95,
205+
)
206+
assert isinstance(fig, Figure)
207+
assert isinstance(ax, np.ndarray)
208+
assert all(isinstance(a, Axes) for a in ax.flat)
209+
# Verify gradient was drawn (check for fill_between patches)
210+
for a in ax.flat:
211+
patches = [p for p in a.patches if hasattr(p, 'get_paths')]
212+
assert len(patches) > 0, "Expected gradient patches on axes"
213+
214+
215+
def test_posterior_predictive_gradient_parameters(fit_mmm_with_channel_original_scale, df):
216+
"""Test gradient with custom parameters."""
217+
fit_mmm_with_channel_original_scale.sample_posterior_predictive(
218+
df.drop(columns=["y"])
219+
)
220+
# Test with different n_percentiles
221+
fig1, _ = fit_mmm_with_channel_original_scale.plot.posterior_predictive(
222+
add_gradient=True,
223+
n_percentiles=20,
224+
)
225+
assert isinstance(fig1, Figure)
226+
227+
# Test with different palette
228+
fig2, _ = fit_mmm_with_channel_original_scale.plot.posterior_predictive(
229+
add_gradient=True,
230+
palette="Reds",
231+
)
232+
assert isinstance(fig2, Figure)
233+
234+
235+
def test_posterior_predictive_gradient_with_hdi(fit_mmm_with_channel_original_scale, df):
236+
"""Test that gradient and HDI can be displayed together."""
237+
fit_mmm_with_channel_original_scale.sample_posterior_predictive(
238+
df.drop(columns=["y"])
239+
)
240+
fig, ax = fit_mmm_with_channel_original_scale.plot.posterior_predictive(
241+
add_gradient=True,
242+
hdi_prob=0.85,
243+
)
244+
assert isinstance(fig, Figure)
245+
# Verify both gradient patches and HDI fills exist
246+
for a in ax.flat:
247+
# Should have multiple fill_between patches from both gradient and HDI
248+
patches = [p for p in a.patches if hasattr(p, 'get_paths')]
249+
assert len(patches) > 1, "Expected both gradient and HDI patches"
250+
251+
197252
@pytest.fixture(scope="module")
198253
def mock_idata() -> az.InferenceData:
199254
seed = sum(map(ord, "Fake posterior"))

thoughts/shared/issues/2054/plan.md

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -218,15 +218,15 @@ def _add_gradient_to_axes(
218218
### Success Criteria
219219

220220
#### Automated Verification:
221-
- [ ] Method exists in MMMPlotSuite class: `grep -n "_add_gradient_to_axes" pymc_marketing/mmm/plot.py`
222-
- [ ] Type hints are correct: `mypy pymc_marketing/mmm/plot.py`
223-
- [ ] No linting errors: `ruff check pymc_marketing/mmm/plot.py`
221+
- [x] Method exists in MMMPlotSuite class: `grep -n "_add_gradient_to_axes" pymc_marketing/mmm/plot.py`
222+
- [x] Type hints are correct: `mypy pymc_marketing/mmm/plot.py`
223+
- [x] No linting errors: `ruff check pymc_marketing/mmm/plot.py`
224224

225225
#### Manual Verification:
226-
- [ ] Method signature follows Plot Suite conventions
227-
- [ ] Docstring is complete with all parameters documented
228-
- [ ] Error handling validates required dimensions
229-
- [ ] Algorithm matches base model gradient logic
226+
- [x] Method signature follows Plot Suite conventions
227+
- [x] Docstring is complete with all parameters documented
228+
- [x] Error handling validates required dimensions
229+
- [x] Algorithm matches base model gradient logic
230230

231231
---
232232

@@ -314,16 +314,16 @@ def posterior_predictive(
314314
### Success Criteria
315315

316316
#### Automated Verification:
317-
- [ ] Method signature updated: `grep -A 6 "def posterior_predictive" pymc_marketing/mmm/plot.py`
318-
- [ ] Docstring includes new parameters: `grep -A 30 "Plot time series from the posterior" pymc_marketing/mmm/plot.py | grep "add_gradient"`
319-
- [ ] Type checking passes: `mypy pymc_marketing/mmm/plot.py`
320-
- [ ] No syntax errors: `python -m py_compile pymc_marketing/mmm/plot.py`
317+
- [x] Method signature updated: `grep -A 6 "def posterior_predictive" pymc_marketing/mmm/plot.py`
318+
- [x] Docstring includes new parameters: `grep -A 30 "Plot time series from the posterior" pymc_marketing/mmm/plot.py | grep "add_gradient"`
319+
- [x] Type checking passes: `mypy pymc_marketing/mmm/plot.py`
320+
- [x] No syntax errors: `python -m py_compile pymc_marketing/mmm/plot.py`
321321

322322
#### Manual Verification:
323-
- [ ] Gradient renders before median/HDI (correct z-order)
324-
- [ ] Gradient parameter defaults maintain backward compatibility
325-
- [ ] Method handles multi-dimensional cases correctly
326-
- [ ] Visual output matches base model gradient style
323+
- [x] Gradient renders before median/HDI (correct z-order)
324+
- [x] Gradient parameter defaults maintain backward compatibility
325+
- [x] Method handles multi-dimensional cases correctly
326+
- [x] Visual output matches base model gradient style
327327

328328
---
329329

0 commit comments

Comments
 (0)