44import warnings
55from typing import Any , Callable , Optional , Union
66
7- import arviz as az
87import matplotlib .pyplot as plt
98import numpy as np
109import numpy .typing as npt
1110import pytensor .tensor as pt
11+ from arviz_base import rcParams
12+ from arviz_stats .base import array_stats
1213from numba import jit
1314from pytensor .tensor .variable import Variable
1415from scipy .interpolate import griddata
1516from scipy .signal import savgol_filter
16- from scipy .stats import norm
1717
1818from .tree import Tree
1919
@@ -76,12 +76,12 @@ def _sample_posterior(
7676
7777
7878def plot_convergence (
79- idata : az . InferenceData ,
79+ idata : Any ,
8080 var_name : Optional [str ] = None ,
8181 kind : str = "ecdf" ,
8282 figsize : Optional [tuple [float , float ]] = None ,
8383 ax = None ,
84- ) -> list [ plt . Axes ] :
84+ ) -> None :
8585 """
8686 Plot convergence diagnostics.
8787
@@ -102,39 +102,12 @@ def plot_convergence(
102102 -------
103103 list[ax] : matplotlib axes
104104 """
105- ess_threshold = idata ["posterior" ]["chain" ].size * 100
106- ess = np .atleast_2d (az .ess (idata , method = "bulk" , var_names = var_name )[var_name ].values )
107- rhat = np .atleast_2d (az .rhat (idata , var_names = var_name )[var_name ].values )
108-
109- if figsize is None :
110- figsize = (10 , 3 )
111-
112- if kind == "ecdf" :
113- kind_func : Callable [..., Any ] = az .plot_ecdf
114- sharey = True
115- elif kind == "kde" :
116- kind_func = az .plot_kde
117- sharey = False
118-
119- if ax is None :
120- _ , ax = plt .subplots (1 , 2 , figsize = figsize , sharex = "col" , sharey = sharey )
121-
122- for idx , (essi , rhati ) in enumerate (zip (ess , rhat )):
123- kind_func (essi , ax = ax [0 ], plot_kwargs = {"color" : f"C{ idx } " })
124- kind_func (rhati , ax = ax [1 ], plot_kwargs = {"color" : f"C{ idx } " })
125-
126- ax [0 ].axvline (ess_threshold , color = "0.7" , ls = "--" )
127- # Assume Rhats are N(1, 0.005) iid. Then compute the 0.99 quantile
128- # scaled by the sample size and use it as a threshold.
129- ax [1 ].axvline (norm (1 , 0.005 ).ppf (0.99 ** (1 / ess .size )), color = "0.7" , ls = "--" )
130-
131- ax [0 ].set_xlabel ("ESS" )
132- ax [1 ].set_xlabel ("R-hat" )
133- if kind == "kde" :
134- ax [0 ].set_yticks ([])
135- ax [1 ].set_yticks ([])
136-
137- return ax
105+ warnings .warn (
106+ "This function has been deprecated"
107+ "Use az.plot_convergence_dist() instead."
108+ "https://arviz-plots.readthedocs.io/en/latest/api/generated/arviz_plots.plot_convergence_dist.html" ,
109+ FutureWarning ,
110+ )
138111
139112
140113def plot_ice (
@@ -408,7 +381,7 @@ def identity(x):
408381 if var in var_discrete :
409382 _ , idx_uni = np .unique (new_x , return_index = True )
410383 y_means = p_di .mean (0 )[idx_uni ]
411- hdi = az .hdi (p_di )[idx_uni ]
384+ hdi = array_stats .hdi (p_di , prob = rcParams [ "stats.ci_prob" ], axis = 0 )[idx_uni ]
412385 axes [count ].errorbar (
413386 new_x [idx_uni ],
414387 y_means ,
@@ -418,11 +391,13 @@ def identity(x):
418391 )
419392 axes [count ].set_xticks (new_x [idx_uni ])
420393 else :
421- az . plot_hdi (
394+ _plot_hdi (
422395 new_x ,
423396 p_di ,
424397 smooth = smooth ,
425- fill_kwargs = {"alpha" : alpha , "color" : color },
398+ alpha = alpha ,
399+ color = color ,
400+ smooth_kwargs = smooth_kwargs ,
426401 ax = axes [count ],
427402 )
428403 if smooth :
@@ -659,7 +634,7 @@ def _create_pdp_data(
659634def _smooth_mean (
660635 new_x : npt .NDArray ,
661636 p_di : npt .NDArray ,
662- kind : str = "pdp " ,
637+ kind : str = "neutral " ,
663638 smooth_kwargs : Optional [dict [str , Any ]] = None ,
664639) -> tuple [np .ndarray , np .ndarray ]:
665640 """
@@ -688,7 +663,10 @@ def _smooth_mean(
688663 smooth_kwargs .setdefault ("polyorder" , 2 )
689664 x_data = np .linspace (np .nanmin (new_x ), np .nanmax (new_x ), 200 )
690665 x_data [0 ] = (x_data [0 ] + x_data [1 ]) / 2
691- if kind == "pdp" :
666+
667+ if kind == "neutral" :
668+ interp = griddata (new_x , p_di , x_data )
669+ elif kind == "pdp" :
692670 interp = griddata (new_x , p_di .mean (0 ), x_data )
693671 else :
694672 interp = griddata (new_x , p_di .T , x_data )
@@ -800,7 +778,7 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non
800778
801779
802780def compute_variable_importance ( # noqa: PLR0915 PLR0912
803- idata : az . InferenceData ,
781+ idata : Any ,
804782 bartrv : Variable ,
805783 X : npt .NDArray ,
806784 method : str = "VI" ,
@@ -904,7 +882,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
904882 [pearsonr2 (predicted_all [j ], predicted_subset [j ]) for j in range (samples )]
905883 )
906884 r2_mean [idx ] = np .mean (r_2 )
907- r2_hdi [idx ] = az .hdi (r_2 )
885+ r2_hdi [idx ] = array_stats .hdi (r_2 , prob = rcParams [ "stats.ci_prob" ] )
908886 preds [idx ] = predicted_subset .squeeze ()
909887
910888 if method in ["backward" , "backward_VI" ]:
@@ -954,7 +932,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
954932
955933 # Save values for plotting later
956934 r2_mean [i_var - init ] = max_r_2
957- r2_hdi [i_var - init ] = az .hdi (r_2_without_least_important_vars )
935+ r2_hdi [i_var - init ] = array_stats .hdi (r_2_without_least_important_vars )
958936 preds [i_var - init ] = least_important_samples .squeeze ()
959937
960938 # extend current list of least important variable
@@ -1079,7 +1057,7 @@ def plot_variable_importance(
10791057 )
10801058 ax .fill_between (
10811059 [- 0.5 , n_vars - 0.5 ],
1082- * az .hdi (r_2_ref ),
1060+ * array_stats .hdi (r_2_ref , prob = rcParams [ "stats.ci_prob" ] ),
10831061 alpha = 0.1 ,
10841062 color = plot_kwargs .get ("color_ref" , "grey" ),
10851063 )
@@ -1229,3 +1207,22 @@ def pearsonr2(A, B):
12291207 am = A - np .mean (A )
12301208 bm = B - np .mean (B )
12311209 return (am @ bm ) ** 2 / (np .sum (am ** 2 ) * np .sum (bm ** 2 ))
1210+
1211+
1212+ def _plot_hdi (x , y , smooth , color , alpha , smooth_kwargs , ax ):
1213+ x = np .asarray (x )
1214+ y = np .asarray (y )
1215+ hdi_prob = rcParams ["stats.ci_prob" ]
1216+ hdi_data = array_stats .hdi (y , hdi_prob , axis = 0 )
1217+ if smooth :
1218+ if isinstance (x [0 ], np .datetime64 ):
1219+ raise TypeError ("Cannot deal with x as type datetime. Recommend setting smooth=False." )
1220+
1221+ x_data , y_data = _smooth_mean (x , hdi_data , smooth_kwargs = smooth_kwargs )
1222+ else :
1223+ idx = np .argsort (x )
1224+ x_data = x [idx ]
1225+ y_data = hdi_data [idx ]
1226+
1227+ ax .fill_between (x_data , y_data [:, 0 ], y_data [:, 1 ], color = color , alpha = alpha )
1228+ return ax
0 commit comments