Skip to content

Commit fcd8aed

Browse files
Add research for issue #2054
1 parent 9537b9a commit fcd8aed

File tree

1 file changed

+342
-0
lines changed

1 file changed

+342
-0
lines changed
Lines changed: 342 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,342 @@
1+
---
2+
date: 2025-11-04T22:26:50Z
3+
researcher: Claude Code
4+
git_commit: 9537b9a08837a3c5dabcdee6244a0cd1c4688ea0
5+
branch: work-issue-2054
6+
repository: pymc-labs/pymc-marketing
7+
topic: "Add Gradient to plot posterior predictive in Plot Suite"
8+
tags: [research, codebase, plot-suite, gradient, visualization, mmm, issue-2054]
9+
status: complete
10+
last_updated: 2025-11-04
11+
last_updated_by: Claude Code
12+
issue_number: 2054
13+
---
14+
15+
# Research: Add Gradient to plot posterior predictive in Plot Suite
16+
17+
**Date**: 2025-11-04T22:26:50Z
18+
**Researcher**: Claude Code
19+
**Git Commit**: 9537b9a08837a3c5dabcdee6244a0cd1c4688ea0
20+
**Branch**: work-issue-2054
21+
**Repository**: pymc-labs/pymc-marketing
22+
**Issue**: #2054
23+
24+
## Research Question
25+
26+
How to add Gradient visualization functionality to the `plot_posterior_predictive` method in the Plot Suite to enable full migration from base model plotting methods?
27+
28+
## Summary
29+
30+
The gradient visualization feature currently exists in the **BaseValidateMMM** class but is **not available in the MMMPlotSuite**. This creates an incomplete migration path because:
31+
32+
1. **BaseValidateMMM.plot_posterior_predictive()** (`pymc_marketing/mmm/base.py:625`) supports `add_gradient` parameter
33+
2. **MMMPlotSuite.posterior_predictive()** (`pymc_marketing/mmm/plot.py:375`) does NOT support gradient visualization
34+
3. The gradient implementation in BaseValidateMMM uses `_add_gradient_to_plot()` (`pymc_marketing/mmm/base.py:362-433`)
35+
36+
To complete the migration, the gradient functionality needs to be added to the Plot Suite's `posterior_predictive()` method.
37+
38+
## Detailed Findings
39+
40+
### 1. Current Plot Suite Architecture
41+
42+
**Location**: `pymc_marketing/mmm/plot.py:187-1923`
43+
44+
The `MMMPlotSuite` class provides a comprehensive plotting API for Media Mix Models:
45+
46+
```python
47+
class MMMPlotSuite:
48+
"""Media Mix Model Plot Suite."""
49+
50+
def __init__(self, idata: xr.Dataset | az.InferenceData):
51+
self.idata = idata
52+
```
53+
54+
**Integration**: Exposed via property in `pymc_marketing/mmm/multidimensional.py:618-623`:
55+
56+
```python
57+
@property
58+
def plot(self) -> MMMPlotSuite:
59+
"""Use the MMMPlotSuite to plot the results."""
60+
return MMMPlotSuite(idata=self.idata)
61+
```
62+
63+
**Access Pattern**: Users call `mmm.plot.method_name()` on fitted models.
64+
65+
### 2. Posterior Predictive Plotting - Two APIs
66+
67+
#### API 1: MMMPlotSuite.posterior_predictive() (Plot Suite)
68+
**Location**: `pymc_marketing/mmm/plot.py:375-463`
69+
70+
**Current Signature**:
71+
```python
72+
def posterior_predictive(
73+
self,
74+
var: list[str] | None = None,
75+
idata: xr.Dataset | None = None,
76+
hdi_prob: float = 0.85,
77+
) -> tuple[Figure, NDArray[Axes]]:
78+
```
79+
80+
**Features**:
81+
- Multi-variable plotting
82+
- Multi-dimensional subplot support
83+
- HDI bands at configurable probability
84+
- Median line visualization
85+
- **Missing: Gradient visualization**
86+
87+
**Test Coverage**: `tests/mmm/test_plot.py:185`
88+
89+
#### API 2: BaseValidateMMM.plot_posterior_predictive() (Base Model)
90+
**Location**: `pymc_marketing/mmm/base.py:625-682`
91+
92+
**Current Signature**:
93+
```python
94+
def plot_posterior_predictive(
95+
self,
96+
original_scale: bool = False,
97+
hdi_list: list[float] | None = None,
98+
add_mean: bool = True,
99+
add_gradient: bool = False,
100+
ax: plt.Axes | None = None,
101+
**plt_kwargs,
102+
) -> plt.Figure:
103+
```
104+
105+
**Features**:
106+
- Scale transformation (`original_scale`)
107+
- Multiple HDI levels (`hdi_list`)
108+
- Mean line (`add_mean`)
109+
- **Gradient visualization (`add_gradient`)** ← This is what needs to be migrated
110+
- Custom axes support
111+
112+
**Test Coverage**: `tests/mmm/test_plotting.py:206-263` (extensive parametrized tests)
113+
114+
### 3. Gradient Implementation Details
115+
116+
**Core Implementation**: `pymc_marketing/mmm/base.py:362-433`
117+
118+
```python
119+
def _add_gradient_to_plot(
120+
self,
121+
ax: plt.Axes,
122+
group: Literal["prior_predictive", "posterior_predictive"],
123+
original_scale: bool = False,
124+
n_percentiles: int = 30,
125+
palette: str = "Blues",
126+
**kwargs,
127+
) -> plt.Axes:
128+
"""
129+
Add a gradient representation of the prior or posterior predictive distribution.
130+
131+
Creates a shaded area plot where color intensity represents
132+
the density of the posterior predictive distribution.
133+
"""
134+
```
135+
136+
**Algorithm**:
137+
1. Retrieves posterior_predictive data and flattens samples
138+
2. Computes percentile ranges (default: 30 ranges from 3rd to 97th percentile)
139+
3. Creates layered `fill_between()` calls with varying colors and alpha
140+
4. Middle percentiles use higher alpha (denser distribution)
141+
5. Outer percentiles use lower alpha (sparser distribution)
142+
6. Color mapping via matplotlib colormap (default "Blues")
143+
144+
**Visual Effect**: Creates a smooth gradient visualization showing full distribution density.
145+
146+
**Usage in Base Model** (`pymc_marketing/mmm/base.py:534-541`):
147+
```python
148+
if add_gradient:
149+
ax = self._add_gradient_to_plot(
150+
ax=ax,
151+
group=group,
152+
original_scale=original_scale,
153+
n_percentiles=30,
154+
palette="Blues",
155+
)
156+
```
157+
158+
### 4. Test Coverage for Gradient Feature
159+
160+
**Location**: `tests/mmm/test_plotting.py:206-263`
161+
162+
Tests include combinations of:
163+
- `add_gradient: True` with various other parameters
164+
- Prior predictive plots (lines 160, 181, 189, 197)
165+
- Posterior predictive plots (lines 219, 240, 248, 256)
166+
- Combinations with `add_mean`, `original_scale`, `hdi_list`
167+
168+
Example test cases:
169+
```python
170+
("plot_posterior_predictive", {"add_gradient": True}),
171+
("plot_posterior_predictive", {"add_gradient": True, "original_scale": True}),
172+
("plot_posterior_predictive", {"add_gradient": True, "add_mean": False}),
173+
```
174+
175+
### 5. Migration Context
176+
177+
**Current State**:
178+
- BaseValidateMMM methods: Full-featured but older API pattern
179+
- MMMPlotSuite: Modern API with better multi-dimensional support but missing gradient
180+
181+
**Migration Goal**: Enable users to get all functionality through the Plot Suite API:
182+
- Before: `mmm.plot_posterior_predictive(add_gradient=True)`
183+
- After: `mmm.plot.posterior_predictive(add_gradient=True)` or similar
184+
185+
**Blockers for Full Migration**:
186+
1. Gradient visualization not available in Plot Suite
187+
2. `original_scale` parameter not in Plot Suite
188+
3. Multiple HDI levels (`hdi_list`) not in Plot Suite (currently single `hdi_prob`)
189+
190+
## Code References
191+
192+
### Key Implementation Files
193+
- `pymc_marketing/mmm/plot.py:187` - MMMPlotSuite class
194+
- `pymc_marketing/mmm/plot.py:375` - MMMPlotSuite.posterior_predictive() method
195+
- `pymc_marketing/mmm/base.py:362` - _add_gradient_to_plot() implementation
196+
- `pymc_marketing/mmm/base.py:625` - BaseValidateMMM.plot_posterior_predictive()
197+
- `pymc_marketing/mmm/multidimensional.py:618` - Plot Suite property accessor
198+
199+
### Helper Methods (Reusable in Implementation)
200+
- `pymc_marketing/mmm/plot.py:200` - `_init_subplots()` - Subplot grid initialization
201+
- `pymc_marketing/mmm/plot.py:247` - `_get_additional_dim_combinations()` - Dimension handling
202+
- `pymc_marketing/mmm/plot.py:269` - `_reduce_and_stack()` - Data reduction
203+
- `pymc_marketing/mmm/plot.py:286` - `_get_posterior_predictive_data()` - Data retrieval
204+
- `pymc_marketing/mmm/plot.py:306` - `_add_median_and_hdi()` - Add median/HDI to plot
205+
206+
### Test Files
207+
- `tests/mmm/test_plot.py:185` - Basic posterior_predictive test
208+
- `tests/mmm/test_plotting.py:206-263` - Parametrized tests with gradient
209+
- `tests/mmm/test_base.py:358` - Error handling test
210+
211+
## Architecture Insights
212+
213+
### Design Patterns in Plot Suite
214+
215+
1. **Dimension-Aware Subplots**: Plot methods automatically create subplots for each combination of non-ignored dimensions
216+
2. **Helper Method Composition**: Complex plotting logic decomposed into reusable helpers
217+
3. **xarray Integration**: Heavy use of xarray for multi-dimensional data manipulation
218+
4. **Tuple Returns**: Methods return `(Figure, NDArray[Axes])` for flexibility
219+
5. **No State Mutation**: Plot Suite is stateless, only operates on InferenceData
220+
221+
### Gradient Implementation Pattern
222+
223+
The gradient visualization follows a **layered percentile** approach:
224+
- **Conceptual**: Stack many thin HDI bands with varying opacity
225+
- **Visual**: Creates smooth density gradient from sparse (edges) to dense (center)
226+
- **Technical**: Uses `np.percentile()` + `ax.fill_between()` in loop
227+
- **Customization**: Configurable via `n_percentiles` and `palette`
228+
229+
### Integration Approach for Plot Suite
230+
231+
**Recommended Pattern**: Add `add_gradient` parameter to `MMMPlotSuite.posterior_predictive()`
232+
233+
```python
234+
def posterior_predictive(
235+
self,
236+
var: list[str] | None = None,
237+
idata: xr.Dataset | None = None,
238+
hdi_prob: float = 0.85,
239+
add_gradient: bool = False, # NEW
240+
n_percentiles: int = 30, # NEW
241+
palette: str = "Blues", # NEW
242+
) -> tuple[Figure, NDArray[Axes]]:
243+
```
244+
245+
**Implementation Strategy**:
246+
1. Extract gradient logic from `BaseValidateMMM._add_gradient_to_plot()` into standalone function
247+
2. Adapt to work with xarray DataArrays (Plot Suite uses xarray, base uses Dataset)
248+
3. Integrate gradient plotting into dimension loop in `posterior_predictive()`
249+
4. Place gradient layer BEFORE median/HDI visualization (background layer)
250+
5. Add conditional logic: `if add_gradient:` section before `_add_median_and_hdi()`
251+
252+
**Key Adaptation**: The base model gradient method works with dates directly from Dataset, while Plot Suite works with dimension-sliced DataArrays. Need to handle this in the adaptation.
253+
254+
## Implementation Checklist
255+
256+
To add gradient to Plot Suite's `posterior_predictive()`:
257+
258+
1. **Create Helper Method**: Add `_add_gradient_to_axes()` in MMMPlotSuite
259+
- Adapt `BaseValidateMMM._add_gradient_to_plot()` logic
260+
- Accept xarray DataArray instead of Dataset
261+
- Work with generic dimensions (not just "date")
262+
263+
2. **Modify `posterior_predictive()` Method**:
264+
- Add parameters: `add_gradient`, `n_percentiles`, `palette`
265+
- Add gradient rendering in dimension loop (before median/HDI)
266+
- Ensure gradient is background layer (drawn first)
267+
268+
3. **Update Tests**:
269+
- Add test cases in `tests/mmm/test_plot.py`
270+
- Test with single dimension
271+
- Test with multiple dimensions
272+
- Test with various `n_percentiles` and `palette` values
273+
274+
4. **Documentation**:
275+
- Update method docstring
276+
- Add parameter descriptions
277+
- Include example in docstring or documentation
278+
279+
5. **Validation**:
280+
- Visual comparison with base model gradient output
281+
- Ensure color/alpha mapping matches
282+
- Test with real-world model fits
283+
284+
## Related Research
285+
286+
No previous research documents found (no `thoughts/` directory in repository).
287+
288+
## Open Questions
289+
290+
1. **Should gradient replace or complement HDI bands?**
291+
- Current base model: Users choose either HDI OR gradient
292+
- Plot Suite option: Allow both simultaneously (gradient + HDI overlay)?
293+
294+
2. **Dimension-specific gradients?**
295+
- Should gradient settings be customizable per dimension?
296+
- Or single global setting for all subplots?
297+
298+
3. **Parameter naming consistency?**
299+
- Should we match base model parameter names exactly?
300+
- Or adapt to Plot Suite conventions?
301+
302+
4. **original_scale support?**
303+
- Should we also add `original_scale` parameter to Plot Suite?
304+
- This would require access to model's scale transformation logic
305+
306+
5. **hdi_list vs hdi_prob?**
307+
- Base model supports multiple HDI levels via `hdi_list`
308+
- Plot Suite currently has single `hdi_prob`
309+
- Should we unify these?
310+
311+
## Next Steps
312+
313+
1. **Implement `_add_gradient_to_axes()` helper method** in MMMPlotSuite
314+
2. **Modify `posterior_predictive()` method** to accept gradient parameters
315+
3. **Add comprehensive tests** covering gradient visualization
316+
4. **Visual validation** with example models
317+
5. **Consider migrating other base model features** (`original_scale`, `hdi_list`)
318+
319+
## Additional Context
320+
321+
### Similar Patterns in Codebase
322+
323+
The saturation curves method (`pymc_marketing/mmm/plot.py:744-996`) demonstrates a similar pattern of complex layered visualization:
324+
- Uses `plot_samples()` for sample curves
325+
- Uses `plot_hdi()` for HDI bands
326+
- Layers visualization elements in specific order
327+
- Could serve as reference for gradient implementation
328+
329+
### Color Palette Options
330+
331+
The gradient implementation uses matplotlib colormaps. Common options:
332+
- "Blues" (default) - Blue gradient
333+
- "Reds" - Red gradient
334+
- "Greens" - Green gradient
335+
- "viridis", "plasma", "inferno" - Perceptually uniform colormaps
336+
337+
### Performance Considerations
338+
339+
Gradient rendering with 30 percentile ranges creates 29 `fill_between()` calls per subplot. For models with many dimensions, this could impact rendering performance. Consider:
340+
- Caching computed percentiles
341+
- Reducing default `n_percentiles` if needed
342+
- Providing `add_gradient=False` as default for backward compatibility

0 commit comments

Comments
 (0)