File tree Expand file tree Collapse file tree 1 file changed +22
-1
lines changed Expand file tree Collapse file tree 1 file changed +22
-1
lines changed Original file line number Diff line number Diff line change @@ -360,7 +360,11 @@ def plot_prior_predictive(self, **plt_kwargs: Any) -> plt.Figure:
360360 return fig
361361
362362 def plot_posterior_predictive (
363- self , original_scale : bool = False , ax : plt .Axes = None , ** plt_kwargs : Any
363+ self ,
364+ original_scale : bool = False ,
365+ add_mean : bool = True ,
366+ ax : plt .Axes = None ,
367+ ** plt_kwargs : Any ,
364368 ) -> plt .Figure :
365369 """Plot posterior distribution from the model fit.
366370
@@ -425,6 +429,23 @@ def plot_posterior_predictive(
425429 label = f"{ hdi_prob :.0%} HDI" ,
426430 )
427431
432+ if add_mean :
433+ mean_prediction = posterior_predictive_data [self .output_var ].mean (
434+ dim = ["chain" , "draw" ]
435+ )
436+
437+ if original_scale :
438+ mean_prediction = transform_1d_array (
439+ self .get_target_transformer ().inverse_transform , mean_prediction
440+ )
441+
442+ ax .plot (
443+ np .asarray (posterior_predictive_data .date ),
444+ mean_prediction ,
445+ color = "C0" ,
446+ label = "Mean Prediction" ,
447+ )
448+
428449 ax .plot (
429450 np .asarray (posterior_predictive_data .date ),
430451 target_to_plot ,
You can’t perform that action at this time.
0 commit comments