@@ -382,13 +382,9 @@ def plot_posterior_predictive(
382382 plt.Figure
383383
384384 """
385- try :
386- posterior_predictive_data : Dataset = self .posterior_predictive
387-
388- except Exception as e :
389- raise RuntimeError (
390- "Make sure the model has bin fitted and the posterior predictive has been sampled!"
391- ) from e
385+ posterior_predictive_data : Dataset = self ._get_posterior_predictive_data (
386+ original_scale = original_scale
387+ )
392388
393389 target_to_plot = np .asarray (
394390 self .y
@@ -408,13 +404,6 @@ def plot_posterior_predictive(
408404 else :
409405 fig = ax .figure
410406
411- if original_scale :
412- posterior_predictive_data = apply_sklearn_transformer_across_dim (
413- data = posterior_predictive_data ,
414- func = self .get_target_transformer ().inverse_transform ,
415- dim_name = "date" ,
416- )
417-
418407 for hdi_prob , alpha in zip ((0.94 , 0.50 ), (0.2 , 0.4 ), strict = True ):
419408 likelihood_hdi : DataArray = az .hdi (
420409 ary = posterior_predictive_data , hdi_prob = hdi_prob
@@ -430,15 +419,8 @@ def plot_posterior_predictive(
430419 )
431420
432421 if add_mean :
433- mean_prediction = posterior_predictive_data [self .output_var ].mean (
434- dim = ["chain" , "draw" ]
435- )
436-
437- ax .plot (
438- np .asarray (posterior_predictive_data .date ),
439- mean_prediction ,
440- color = "C0" ,
441- label = "Mean Prediction" ,
422+ ax = self ._add_mean_to_plot (
423+ ax = ax , original_scale = original_scale , color = "red"
442424 )
443425
444426 ax .plot (
@@ -456,6 +438,45 @@ def plot_posterior_predictive(
456438
457439 return fig
458440
441+ def _get_posterior_predictive_data (self , original_scale : bool = False ) -> Dataset :
442+ """Get the posterior predictive data."""
443+ try :
444+ posterior_predictive_data : Dataset = self .posterior_predictive
445+
446+ except Exception as e :
447+ raise RuntimeError (
448+ "Make sure the model has bin fitted and the posterior predictive has been sampled!"
449+ ) from e
450+
451+ if original_scale :
452+ posterior_predictive_data = apply_sklearn_transformer_across_dim (
453+ data = posterior_predictive_data ,
454+ func = self .get_target_transformer ().inverse_transform ,
455+ dim_name = "date" ,
456+ )
457+ return posterior_predictive_data
458+
459+ def _add_mean_to_plot (
460+ self , ax , original_scale : bool = False , color = "blue" , linestyle = "-" , ** kwargs
461+ ) -> plt .Axes :
462+ """Add mean prediction to existing plot."""
463+ posterior_predictive_data : Dataset = self ._get_posterior_predictive_data (
464+ original_scale = original_scale
465+ )
466+
467+ mean_prediction = posterior_predictive_data [self .output_var ].mean (
468+ dim = ["chain" , "draw" ]
469+ )
470+
471+ ax .plot (
472+ np .asarray (posterior_predictive_data .date ),
473+ mean_prediction ,
474+ color = color ,
475+ linestyle = linestyle ,
476+ label = "Mean Prediction" ,
477+ )
478+ return ax
479+
459480 def get_errors (self , original_scale : bool = False ) -> DataArray :
460481 """Get model errors posterior distribution.
461482
0 commit comments