Skip to content

Commit afd1a99

Browse files
committed
implemented 'add_mean' in plot_posterior_predictive
1 parent bc4b248 commit afd1a99

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

pymc_marketing/mmm/base.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,11 @@ def plot_prior_predictive(self, **plt_kwargs: Any) -> plt.Figure:
360360
return fig
361361

362362
def plot_posterior_predictive(
363-
self, original_scale: bool = False, ax: plt.Axes = None, **plt_kwargs: Any
363+
self,
364+
original_scale: bool = False,
365+
add_mean: bool = True,
366+
ax: plt.Axes = None,
367+
**plt_kwargs: Any,
364368
) -> plt.Figure:
365369
"""Plot posterior distribution from the model fit.
366370
@@ -425,6 +429,23 @@ def plot_posterior_predictive(
425429
label=f"{hdi_prob:.0%} HDI",
426430
)
427431

432+
if add_mean:
433+
mean_prediction = posterior_predictive_data[self.output_var].mean(
434+
dim=["chain", "draw"]
435+
)
436+
437+
if original_scale:
438+
mean_prediction = transform_1d_array(
439+
self.get_target_transformer().inverse_transform, mean_prediction
440+
)
441+
442+
ax.plot(
443+
np.asarray(posterior_predictive_data.date),
444+
mean_prediction,
445+
color="C0",
446+
label="Mean Prediction",
447+
)
448+
428449
ax.plot(
429450
np.asarray(posterior_predictive_data.date),
430451
target_to_plot,

0 commit comments

Comments
 (0)