88import numpy as np
99import numpy .typing as npt
1010import pytensor .tensor as pt
11+ from numba import jit
1112from pytensor .tensor .variable import Variable
1213from scipy .interpolate import griddata
1314from scipy .signal import savgol_filter
14- from scipy .stats import norm , pearsonr
15+ from scipy .stats import norm
1516
1617from .tree import Tree
1718
@@ -700,8 +701,9 @@ def plot_variable_importance( # noqa: PLR0915
700701 method : str = "VI" ,
701702 figsize : Optional [Tuple [float , float ]] = None ,
702703 xlabel_angle : float = 0 ,
703- samples : int = 100 ,
704+ samples : int = 50 ,
704705 random_seed : Optional [int ] = None ,
706+ plot_kwargs : Optional [Dict [str , Any ]] = None ,
705707 ax : Optional [plt .Axes ] = None ,
706708) -> Tuple [List [int ], Union [List [plt .Axes ], Any ]]:
707709 """
@@ -733,6 +735,14 @@ def plot_variable_importance( # noqa: PLR0915
733735 Number of predictions used to compute correlation for subsets of variables. Defaults to 50.
734736 random_seed : Optional[int]
735737 random_seed used to sample from the posterior. Defaults to None.
738+ plot_kwargs : dict
739+ Additional keyword arguments for the plot. Defaults to None.
740+ Valid keys are:
741+ - color_r2: matplotlib valid color for error bars
742+ - marker_r2: matplotlib valid marker for the mean R squared
743+ - marker_fc_r2: matplotlib valid marker face color for the mean R squared
744+ - ls_ref: matplotlib valid linestyle for the reference line
745+ - color_ref: matplotlib valid color for the reference line
736746 ax : axes
737747 Matplotlib axes.
738748
@@ -745,6 +755,9 @@ def plot_variable_importance( # noqa: PLR0915
745755
746756 all_trees = bartrv .owner .op .all_trees
747757
758+ if plot_kwargs is None :
759+ plot_kwargs = {}
760+
748761 if bartrv .ndim == 1 : # type: ignore
749762 shape = 1
750763 else :
@@ -773,6 +786,10 @@ def plot_variable_importance( # noqa: PLR0915
773786 all_trees , X = X , rng = rng , size = samples , excluded = None , shape = shape
774787 )
775788
789+ r_2_ref = np .array (
790+ [pearsonr2 (predicted_all [j ], predicted_all [j + 1 ]) for j in range (samples - 1 )]
791+ )
792+
776793 if method == "VI" :
777794 idxs = np .argsort (
778795 idata ["sample_stats" ]["variable_inclusion" ].mean (("chain" , "draw" )).values
@@ -794,10 +811,7 @@ def plot_variable_importance( # noqa: PLR0915
794811 shape = shape ,
795812 )
796813 r_2 = np .array (
797- [
798- pearsonr (predicted_all [j ].flatten (), predicted_subset [j ].flatten ())[0 ] ** 2
799- for j in range (samples )
800- ]
814+ [pearsonr2 (predicted_all [j ], predicted_subset [j ]) for j in range (samples )]
801815 )
802816 r2_mean [idx ] = np .mean (r_2 )
803817 r2_hdi [idx ] = az .hdi (r_2 )
@@ -833,10 +847,7 @@ def plot_variable_importance( # noqa: PLR0915
833847 # Calculate Pearson correlation for each sample and find the mean
834848 r_2 = np .zeros (samples )
835849 for j in range (samples ):
836- r_2 [j ] = (
837- (pearsonr (predicted_all [j ].flatten (), predicted_subset [j ].flatten ())[0 ])
838- ** 2
839- )
850+ r_2 [j ] = pearsonr2 (predicted_all [j ], predicted_subset [j ])
840851 mean_r_2 = np .mean (r_2 , dtype = float )
841852 # Identify the least important combination of variables
842853 # based on the maximum mean squared Pearson correlation
@@ -872,9 +883,21 @@ def plot_variable_importance( # noqa: PLR0915
872883 ticks ,
873884 r2_mean ,
874885 np .array ((r2_yerr_min , r2_yerr_max )),
875- color = "C0" ,
886+ color = plot_kwargs .get ("color_r2" , "k" ),
887+ fmt = plot_kwargs .get ("marker_r2" , "o" ),
888+ mfc = plot_kwargs .get ("marker_fc_r2" , "white" ),
889+ )
890+ ax .axhline (
891+ np .mean (r_2_ref ),
892+ ls = plot_kwargs .get ("ls_ref" , "--" ),
893+ color = plot_kwargs .get ("color_ref" , "grey" ),
894+ )
895+ ax .fill_between (
896+ [- 0.5 , n_vars - 0.5 ],
897+ * az .hdi (r_2_ref ),
898+ alpha = 0.1 ,
899+ color = plot_kwargs .get ("color_ref" , "grey" ),
876900 )
877- ax .axhline (r2_mean [- 1 ], ls = "--" , color = "0.5" )
878901 ax .set_xticks (ticks , new_labels , rotation = xlabel_angle )
879902 ax .set_ylabel ("R²" , rotation = 0 , labelpad = 12 )
880903 ax .set_ylim (0 , 1 )
@@ -890,3 +913,13 @@ def generate_sequences(n_vars, i_var, include):
890913 else :
891914 sequences = [()]
892915 return sequences
916+
917+
@jit(nopython=True)
def pearsonr2(A, B):
    """Compute the squared Pearson correlation coefficient.

    Parameters
    ----------
    A, B : array
        Arrays holding the same number of elements. Both are flattened
        before the correlation is computed.

    Returns
    -------
    float
        Squared Pearson correlation between the flattened inputs, or NaN when
        the denominator is zero (e.g. one of the inputs is constant), in which
        case the correlation is undefined.
    """
    A = A.flatten()
    B = B.flatten()
    am = A - np.mean(A)
    bm = B - np.mean(B)
    denom = np.sum(am**2) * np.sum(bm**2)
    # A zero-variance input makes the denominator zero. In nopython mode the
    # default error model raises ZeroDivisionError on that division, so return
    # NaN explicitly instead of aborting the caller's loop over samples.
    if denom == 0.0:
        return np.nan
    return (am @ bm) ** 2 / denom
0 commit comments