| 1 | +--- |
| 2 | +date: 2025-11-04T22:26:50Z |
| 3 | +researcher: Claude Code |
| 4 | +git_commit: 9537b9a08837a3c5dabcdee6244a0cd1c4688ea0 |
| 5 | +branch: work-issue-2054 |
| 6 | +repository: pymc-labs/pymc-marketing |
| 7 | +topic: "Add Gradient to plot posterior predictive in Plot Suite" |
| 8 | +tags: [research, codebase, plot-suite, gradient, visualization, mmm, issue-2054] |
| 9 | +status: complete |
| 10 | +last_updated: 2025-11-04 |
| 11 | +last_updated_by: Claude Code |
| 12 | +issue_number: 2054 |
| 13 | +--- |
| 14 | + |
| 15 | +# Research: Add Gradient to plot posterior predictive in Plot Suite |
| 16 | + |
| 17 | +**Date**: 2025-11-04T22:26:50Z |
| 18 | +**Researcher**: Claude Code |
| 19 | +**Git Commit**: 9537b9a08837a3c5dabcdee6244a0cd1c4688ea0 |
| 20 | +**Branch**: work-issue-2054 |
| 21 | +**Repository**: pymc-labs/pymc-marketing |
| 22 | +**Issue**: #2054 |
| 23 | + |
| 24 | +## Research Question |
| 25 | + |
| 26 | +How can gradient visualization be added to the Plot Suite's `posterior_predictive()` method so that users can migrate fully from the base model's `plot_posterior_predictive()` method? |
| 27 | + |
| 28 | +## Summary |
| 29 | + |
| 30 | +The gradient visualization feature currently exists in the **BaseValidateMMM** class but is **not available in the MMMPlotSuite**. This creates an incomplete migration path because: |
| 31 | + |
| 32 | +1. **BaseValidateMMM.plot_posterior_predictive()** (`pymc_marketing/mmm/base.py:625`) supports `add_gradient` parameter |
| 33 | +2. **MMMPlotSuite.posterior_predictive()** (`pymc_marketing/mmm/plot.py:375`) does NOT support gradient visualization |
| 34 | +3. The gradient implementation in BaseValidateMMM uses `_add_gradient_to_plot()` (`pymc_marketing/mmm/base.py:362-433`) |
| 35 | + |
| 36 | +To complete the migration, the gradient functionality needs to be added to the Plot Suite's `posterior_predictive()` method. |
| 37 | + |
| 38 | +## Detailed Findings |
| 39 | + |
| 40 | +### 1. Current Plot Suite Architecture |
| 41 | + |
| 42 | +**Location**: `pymc_marketing/mmm/plot.py:187-1923` |
| 43 | + |
| 44 | +The `MMMPlotSuite` class provides a comprehensive plotting API for Media Mix Models: |
| 45 | + |
| 46 | +```python |
| 47 | +class MMMPlotSuite: |
| 48 | +    """Media Mix Model Plot Suite.""" |
| 49 | + |
| 50 | +    def __init__(self, idata: xr.Dataset | az.InferenceData): |
| 51 | +        self.idata = idata |
| 52 | +``` |
| 53 | + |
| 54 | +**Integration**: Exposed via property in `pymc_marketing/mmm/multidimensional.py:618-623`: |
| 55 | + |
| 56 | +```python |
| 57 | +@property |
| 58 | +def plot(self) -> MMMPlotSuite: |
| 59 | +    """Use the MMMPlotSuite to plot the results.""" |
| 60 | +    return MMMPlotSuite(idata=self.idata) |
| 61 | +``` |
| 62 | + |
| 63 | +**Access Pattern**: Users call `mmm.plot.method_name()` on fitted models. |
| 64 | + |
| 65 | +### 2. Posterior Predictive Plotting - Two APIs |
| 66 | + |
| 67 | +#### API 1: MMMPlotSuite.posterior_predictive() (Plot Suite) |
| 68 | +**Location**: `pymc_marketing/mmm/plot.py:375-463` |
| 69 | + |
| 70 | +**Current Signature**: |
| 71 | +```python |
| 72 | +def posterior_predictive( |
| 73 | +    self, |
| 74 | +    var: list[str] | None = None, |
| 75 | +    idata: xr.Dataset | None = None, |
| 76 | +    hdi_prob: float = 0.85, |
| 77 | +) -> tuple[Figure, NDArray[Axes]]: |
| 78 | +``` |
| 79 | + |
| 80 | +**Features**: |
| 81 | +- Multi-variable plotting |
| 82 | +- Multi-dimensional subplot support |
| 83 | +- HDI bands at configurable probability |
| 84 | +- Median line visualization |
| 85 | +- **Missing: Gradient visualization** |
| 86 | + |
| 87 | +**Test Coverage**: `tests/mmm/test_plot.py:185` |
| 88 | + |
| 89 | +#### API 2: BaseValidateMMM.plot_posterior_predictive() (Base Model) |
| 90 | +**Location**: `pymc_marketing/mmm/base.py:625-682` |
| 91 | + |
| 92 | +**Current Signature**: |
| 93 | +```python |
| 94 | +def plot_posterior_predictive( |
| 95 | +    self, |
| 96 | +    original_scale: bool = False, |
| 97 | +    hdi_list: list[float] | None = None, |
| 98 | +    add_mean: bool = True, |
| 99 | +    add_gradient: bool = False, |
| 100 | +    ax: plt.Axes | None = None, |
| 101 | +    **plt_kwargs, |
| 102 | +) -> plt.Figure: |
| 103 | +``` |
| 104 | + |
| 105 | +**Features**: |
| 106 | +- Scale transformation (`original_scale`) |
| 107 | +- Multiple HDI levels (`hdi_list`) |
| 108 | +- Mean line (`add_mean`) |
| 109 | +- **Gradient visualization (`add_gradient`)** ← This is what needs to be migrated |
| 110 | +- Custom axes support |
| 111 | + |
| 112 | +**Test Coverage**: `tests/mmm/test_plotting.py:206-263` (extensive parametrized tests) |
| 113 | + |
| 114 | +### 3. Gradient Implementation Details |
| 115 | + |
| 116 | +**Core Implementation**: `pymc_marketing/mmm/base.py:362-433` |
| 117 | + |
| 118 | +```python |
| 119 | +def _add_gradient_to_plot( |
| 120 | +    self, |
| 121 | +    ax: plt.Axes, |
| 122 | +    group: Literal["prior_predictive", "posterior_predictive"], |
| 123 | +    original_scale: bool = False, |
| 124 | +    n_percentiles: int = 30, |
| 125 | +    palette: str = "Blues", |
| 126 | +    **kwargs, |
| 127 | +) -> plt.Axes: |
| 128 | +    """ |
| 129 | +    Add a gradient representation of the prior or posterior predictive distribution. |
| 130 | + |
| 131 | +    Creates a shaded area plot where color intensity represents |
| 132 | +    the density of the posterior predictive distribution. |
| 133 | +    """ |
| 134 | +``` |
| 135 | + |
| 136 | +**Algorithm**: |
| 137 | +1. Retrieves posterior_predictive data and flattens samples |
| 138 | +2. Computes percentile boundaries (default: 30 levels from the 3rd to the 97th percentile, giving 29 adjacent bands) |
| 139 | +3. Creates layered `fill_between()` calls with varying colors and alpha |
| 140 | +4. Middle percentiles use higher alpha (denser distribution) |
| 141 | +5. Outer percentiles use lower alpha (sparser distribution) |
| 142 | +6. Color mapping via matplotlib colormap (default "Blues") |
| 143 | + |
| 144 | +**Visual Effect**: Creates a smooth gradient visualization showing full distribution density. |
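The steps above can be sketched as follows. This is a minimal illustration and not the code in `base.py`: the function name `add_gradient`, the center-weighting formula, and the argument layout are assumptions made for the example. Only `np.percentile()`, `fill_between()`, the 3rd-to-97th range, and the defaults for `n_percentiles` and `palette` come from the findings above.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def add_gradient(ax, x, samples, n_percentiles=30, palette="Blues"):
    """Draw layered percentile bands; `samples` has shape (n_samples, len(x)).

    Hypothetical helper for illustration only.
    """
    cmap = plt.get_cmap(palette)
    levels = np.linspace(3, 97, n_percentiles)
    # One call returns every band boundary: shape (n_percentiles, len(x)).
    bounds = np.percentile(samples, levels, axis=0)
    for i in range(n_percentiles - 1):
        band_center = 0.5 * (levels[i] + levels[i + 1])
        # Assumed weighting: near 0 at the outer bands, 1 at the median.
        weight = 1.0 - abs(band_center - 50.0) / 47.0
        ax.fill_between(
            x,
            bounds[i],
            bounds[i + 1],
            color=cmap(weight),
            alpha=0.2 + 0.8 * weight,
            linewidth=0,
        )
    return ax
```

With the default of 30 levels, the loop draws 29 bands, and the bands around the median receive the darkest color and the highest alpha.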
| 145 | + |
| 146 | +**Usage in Base Model** (`pymc_marketing/mmm/base.py:534-541`): |
| 147 | +```python |
| 148 | +if add_gradient: |
| 149 | +    ax = self._add_gradient_to_plot( |
| 150 | +        ax=ax, |
| 151 | +        group=group, |
| 152 | +        original_scale=original_scale, |
| 153 | +        n_percentiles=30, |
| 154 | +        palette="Blues", |
| 155 | +    ) |
| 156 | +``` |
| 157 | + |
| 158 | +### 4. Test Coverage for Gradient Feature |
| 159 | + |
| 160 | +**Location**: `tests/mmm/test_plotting.py:206-263` |
| 161 | + |
| 162 | +Tests include combinations of: |
| 163 | +- `add_gradient: True` with various other parameters |
| 164 | +- Prior predictive plots (lines 160, 181, 189, 197) |
| 165 | +- Posterior predictive plots (lines 219, 240, 248, 256) |
| 166 | +- Combinations with `add_mean`, `original_scale`, `hdi_list` |
| 167 | + |
| 168 | +Example test cases: |
| 169 | +```python |
| 170 | +("plot_posterior_predictive", {"add_gradient": True}), |
| 171 | +("plot_posterior_predictive", {"add_gradient": True, "original_scale": True}), |
| 172 | +("plot_posterior_predictive", {"add_gradient": True, "add_mean": False}), |
| 173 | +``` |
| 174 | + |
| 175 | +### 5. Migration Context |
| 176 | + |
| 177 | +**Current State**: |
| 178 | +- BaseValidateMMM methods: Full-featured but older API pattern |
| 179 | +- MMMPlotSuite: Modern API with better multi-dimensional support but missing gradient |
| 180 | + |
| 181 | +**Migration Goal**: Enable users to get all functionality through the Plot Suite API: |
| 182 | +- Before: `mmm.plot_posterior_predictive(add_gradient=True)` |
| 183 | +- After: `mmm.plot.posterior_predictive(add_gradient=True)` or similar |
| 184 | + |
| 185 | +**Blockers for Full Migration**: |
| 186 | +1. Gradient visualization not available in Plot Suite |
| 187 | +2. `original_scale` parameter not in Plot Suite |
| 188 | +3. Multiple HDI levels (`hdi_list`) not in Plot Suite (currently single `hdi_prob`) |
| 189 | + |
| 190 | +## Code References |
| 191 | + |
| 192 | +### Key Implementation Files |
| 193 | +- `pymc_marketing/mmm/plot.py:187` - MMMPlotSuite class |
| 194 | +- `pymc_marketing/mmm/plot.py:375` - MMMPlotSuite.posterior_predictive() method |
| 195 | +- `pymc_marketing/mmm/base.py:362` - _add_gradient_to_plot() implementation |
| 196 | +- `pymc_marketing/mmm/base.py:625` - BaseValidateMMM.plot_posterior_predictive() |
| 197 | +- `pymc_marketing/mmm/multidimensional.py:618` - Plot Suite property accessor |
| 198 | + |
| 199 | +### Helper Methods (Reusable in Implementation) |
| 200 | +- `pymc_marketing/mmm/plot.py:200` - `_init_subplots()` - Subplot grid initialization |
| 201 | +- `pymc_marketing/mmm/plot.py:247` - `_get_additional_dim_combinations()` - Dimension handling |
| 202 | +- `pymc_marketing/mmm/plot.py:269` - `_reduce_and_stack()` - Data reduction |
| 203 | +- `pymc_marketing/mmm/plot.py:286` - `_get_posterior_predictive_data()` - Data retrieval |
| 204 | +- `pymc_marketing/mmm/plot.py:306` - `_add_median_and_hdi()` - Add median/HDI to plot |
| 205 | + |
| 206 | +### Test Files |
| 207 | +- `tests/mmm/test_plot.py:185` - Basic posterior_predictive test |
| 208 | +- `tests/mmm/test_plotting.py:206-263` - Parametrized tests with gradient |
| 209 | +- `tests/mmm/test_base.py:358` - Error handling test |
| 210 | + |
| 211 | +## Architecture Insights |
| 212 | + |
| 213 | +### Design Patterns in Plot Suite |
| 214 | + |
| 215 | +1. **Dimension-Aware Subplots**: Plot methods automatically create subplots for each combination of non-ignored dimensions |
| 216 | +2. **Helper Method Composition**: Complex plotting logic decomposed into reusable helpers |
| 217 | +3. **xarray Integration**: Heavy use of xarray for multi-dimensional data manipulation |
| 218 | +4. **Tuple Returns**: Methods return `(Figure, NDArray[Axes])` for flexibility |
| 219 | +5. **No State Mutation**: Plot Suite is stateless, only operates on InferenceData |
| 220 | + |
| 221 | +### Gradient Implementation Pattern |
| 222 | + |
| 223 | +The gradient visualization follows a **layered percentile** approach: |
| 224 | +- **Conceptual**: Stack many thin HDI bands with varying opacity |
| 225 | +- **Visual**: Creates smooth density gradient from sparse (edges) to dense (center) |
| 226 | +- **Technical**: Uses `np.percentile()` + `ax.fill_between()` in loop |
| 227 | +- **Customization**: Configurable via `n_percentiles` and `palette` |
| 228 | + |
| 229 | +### Integration Approach for Plot Suite |
| 230 | + |
| 231 | +**Recommended Pattern**: Add `add_gradient` parameter to `MMMPlotSuite.posterior_predictive()` |
| 232 | + |
| 233 | +```python |
| 234 | +def posterior_predictive( |
| 235 | +    self, |
| 236 | +    var: list[str] | None = None, |
| 237 | +    idata: xr.Dataset | None = None, |
| 238 | +    hdi_prob: float = 0.85, |
| 239 | +    add_gradient: bool = False,  # NEW |
| 240 | +    n_percentiles: int = 30,  # NEW |
| 241 | +    palette: str = "Blues",  # NEW |
| 242 | +) -> tuple[Figure, NDArray[Axes]]: |
| 243 | +``` |
| 244 | + |
| 245 | +**Implementation Strategy**: |
| 246 | +1. Extract gradient logic from `BaseValidateMMM._add_gradient_to_plot()` into standalone function |
| 247 | +2. Adapt it to work with dimension-sliced xarray DataArrays (the base method reads the whole predictive Dataset) |
| 248 | +3. Integrate gradient plotting into dimension loop in `posterior_predictive()` |
| 249 | +4. Place gradient layer BEFORE median/HDI visualization (background layer) |
| 250 | +5. Add conditional logic: `if add_gradient:` section before `_add_median_and_hdi()` |
| 251 | + |
| 252 | +**Key Adaptation**: The base model gradient method reads dates directly from the Dataset, while the Plot Suite works with dimension-sliced DataArrays. The adapted helper must handle this difference. |
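A rough sketch of this adaptation, under stated assumptions: the helper name `add_gradient_to_axes`, its `x_dim` argument, the weighting formula, and the synthetic `geo` dimension are all hypothetical and do not exist in the repository. It only illustrates how a dimension-sliced DataArray with `chain` and `draw` sample dimensions could feed a gradient layer that is drawn before the median line.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr


def add_gradient_to_axes(ax, data, x_dim="date", n_percentiles=30, palette="Blues"):
    """Hypothetical helper: `data` has dims (chain, draw, x_dim)."""
    levels = np.linspace(3, 97, n_percentiles)
    # Reduce over both sample dims; the result has dims ("quantile", x_dim).
    bounds = data.quantile(levels / 100, dim=["chain", "draw"])
    bounds = bounds.transpose("quantile", x_dim).values
    x = data[x_dim].values
    cmap = plt.get_cmap(palette)
    for i in range(n_percentiles - 1):
        # Assumed weighting: near 0 at the outer bands, 1 at the median.
        weight = 1.0 - abs(0.5 * (levels[i] + levels[i + 1]) - 50.0) / 47.0
        ax.fill_between(x, bounds[i], bounds[i + 1], color=cmap(weight),
                        alpha=0.2 + 0.8 * weight, linewidth=0)
    return ax


# Synthetic stand-in for posterior predictive samples with one extra dimension.
rng = np.random.default_rng(0)
da = xr.DataArray(
    rng.normal(size=(2, 50, 12, 3)),
    dims=("chain", "draw", "date", "geo"),
    coords={"date": np.arange(12), "geo": ["a", "b", "c"]},
)

# One subplot per value of the extra dimension, mimicking the dimension loop.
fig, axes = plt.subplots(nrows=3, squeeze=False)
for row, geo in enumerate(da["geo"].values):
    ax = axes[row, 0]
    subset = da.sel(geo=geo)
    add_gradient_to_axes(ax, subset)  # background layer first
    median = subset.median(dim=["chain", "draw"])
    ax.plot(subset["date"].values, median.values, color="black")  # drawn on top
```

Passing the x-axis dimension by name keeps the helper independent of a hard-coded `"date"` coordinate, which is the generalization the checklist asks for.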
| 253 | + |
| 254 | +## Implementation Checklist |
| 255 | + |
| 256 | +To add gradient to Plot Suite's `posterior_predictive()`: |
| 257 | + |
| 258 | +1. **Create Helper Method**: Add `_add_gradient_to_axes()` in MMMPlotSuite |
| 259 | + - Adapt `BaseValidateMMM._add_gradient_to_plot()` logic |
| 260 | + - Accept xarray DataArray instead of Dataset |
| 261 | + - Work with generic dimensions (not just "date") |
| 262 | + |
| 263 | +2. **Modify `posterior_predictive()` Method**: |
| 264 | + - Add parameters: `add_gradient`, `n_percentiles`, `palette` |
| 265 | + - Add gradient rendering in dimension loop (before median/HDI) |
| 266 | + - Ensure gradient is background layer (drawn first) |
| 267 | + |
| 268 | +3. **Update Tests**: |
| 269 | + - Add test cases in `tests/mmm/test_plot.py` |
| 270 | + - Test with single dimension |
| 271 | + - Test with multiple dimensions |
| 272 | + - Test with various `n_percentiles` and `palette` values |
| 273 | + |
| 274 | +4. **Documentation**: |
| 275 | + - Update method docstring |
| 276 | + - Add parameter descriptions |
| 277 | + - Include example in docstring or documentation |
| 278 | + |
| 279 | +5. **Validation**: |
| 280 | + - Visual comparison with base model gradient output |
| 281 | + - Ensure color/alpha mapping matches |
| 282 | + - Test with real-world model fits |
| 283 | + |
| 284 | +## Related Research |
| 285 | + |
| 286 | +No previous research documents found (no `thoughts/` directory in repository). |
| 287 | + |
| 288 | +## Open Questions |
| 289 | + |
| 290 | +1. **Should gradient replace or complement HDI bands?** |
| 291 | + - Current base model: Users choose either HDI OR gradient |
| 292 | + - Plot Suite option: Allow both simultaneously (gradient + HDI overlay)? |
| 293 | + |
| 294 | +2. **Dimension-specific gradients?** |
| 295 | + - Should gradient settings be customizable per dimension? |
| 296 | + - Or single global setting for all subplots? |
| 297 | + |
| 298 | +3. **Parameter naming consistency?** |
| 299 | + - Should we match base model parameter names exactly? |
| 300 | + - Or adapt to Plot Suite conventions? |
| 301 | + |
| 302 | +4. **original_scale support?** |
| 303 | + - Should we also add `original_scale` parameter to Plot Suite? |
| 304 | + - This would require access to model's scale transformation logic |
| 305 | + |
| 306 | +5. **hdi_list vs hdi_prob?** |
| 307 | + - Base model supports multiple HDI levels via `hdi_list` |
| 308 | + - Plot Suite currently has single `hdi_prob` |
| 309 | + - Should we unify these? |
| 310 | + |
| 311 | +## Next Steps |
| 312 | + |
| 313 | +1. **Implement `_add_gradient_to_axes()` helper method** in MMMPlotSuite |
| 314 | +2. **Modify `posterior_predictive()` method** to accept gradient parameters |
| 315 | +3. **Add comprehensive tests** covering gradient visualization |
| 316 | +4. **Visual validation** with example models |
| 317 | +5. **Consider migrating other base model features** (`original_scale`, `hdi_list`) |
| 318 | + |
| 319 | +## Additional Context |
| 320 | + |
| 321 | +### Similar Patterns in Codebase |
| 322 | + |
| 323 | +The saturation curves method (`pymc_marketing/mmm/plot.py:744-996`) demonstrates a similar pattern of complex layered visualization: |
| 324 | +- Uses `plot_samples()` for sample curves |
| 325 | +- Uses `plot_hdi()` for HDI bands |
| 326 | +- Layers visualization elements in specific order |
| 327 | +- Could serve as reference for gradient implementation |
| 328 | + |
| 329 | +### Color Palette Options |
| 330 | + |
| 331 | +The gradient implementation uses matplotlib colormaps. Common options: |
| 332 | +- "Blues" (default) - Blue gradient |
| 333 | +- "Reds" - Red gradient |
| 334 | +- "Greens" - Green gradient |
| 335 | +- "viridis", "plasma", "inferno" - Perceptually uniform colormaps |
| 336 | + |
| 337 | +### Performance Considerations |
| 338 | + |
| 339 | +With the default of 30 percentile levels, gradient rendering makes 29 `fill_between()` calls per subplot. For models with many dimension combinations, and therefore many subplots, this could slow down rendering. Consider: |
| 340 | +- Caching computed percentiles |
| 341 | +- Reducing default `n_percentiles` if needed |
| 342 | +- Providing `add_gradient=False` as default for backward compatibility |
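As a small illustration of the first suggestion, the sketch below compares computing percentiles separately for each band with computing all boundaries in one vectorized `np.percentile()` call and reusing them. The data are synthetic, and the per-band loop is an assumed baseline for comparison, not a quotation of the existing implementation.

```python
import numpy as np

rng = np.random.default_rng(42)
samples = rng.normal(size=(4000, 100))  # synthetic (n_samples, n_dates)
levels = np.linspace(3, 97, 30)

# Per-band approach: two percentile calls per band, 58 calls for 29 bands.
looped = [
    (
        np.percentile(samples, levels[i], axis=0),
        np.percentile(samples, levels[i + 1], axis=0),
    )
    for i in range(len(levels) - 1)
]

# Cached approach: one call returns all 30 boundaries, shape (30, 100).
bounds = np.percentile(samples, levels, axis=0)

# Both approaches give the same lower and upper edge for every band.
same = all(
    np.allclose(lower, bounds[i]) and np.allclose(upper, bounds[i + 1])
    for i, (lower, upper) in enumerate(looped)
)
```

Each interior boundary is shared by two neighboring bands, so computing the boundaries once avoids repeating the percentile work for each band.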