@@ -693,6 +693,50 @@ def _smooth_mean(
693693 return x_data , y_data
694694
695695
696+ def get_variable_inclusion (idata , X , labels = None , to_kulprit = False ):
697+ """
698+ Get the normalized variable inclusion from BART model.
699+
700+ Parameters
701+ ----------
702+ idata : InferenceData
703+ InferenceData containing a collection of BART_trees in sample_stats group
704+ X : npt.NDArray
705+ The covariate matrix.
706+ labels : Optional[list[str]]
707+ List of the names of the covariates. If X is a DataFrame the names of the covariables will
708+ be taken from it and this argument will be ignored.
709+ to_kulprit : bool
710+ If True, the function will return a list of list with the variables names.
711+ This list can be passed as a path to Kulprit's project method. Defaults to False.
712+ Returns
713+ -------
714+ VI_norm : npt.NDArray
715+ Normalized variable inclusion.
716+ labels : list[str]
717+ List of the names of the covariates.
718+ """
719+ VIs = idata ["sample_stats" ]["variable_inclusion" ].mean (("chain" , "draw" )).values
720+ VI_norm = VIs / VIs .sum ()
721+ idxs = np .argsort (VI_norm )
722+
723+ indices = idxs [::- 1 ]
724+ n_vars = len (indices )
725+
726+ if hasattr (X , "columns" ) and hasattr (X , "to_numpy" ):
727+ labels = X .columns
728+
729+ if labels is None :
730+ labels = np .arange (n_vars ).astype (str )
731+
732+ label_list = labels .to_list ()
733+
734+ if to_kulprit :
735+ return [label_list [:idx ] for idx in range (n_vars )]
736+ else :
737+ return VI_norm [indices ], label_list
738+
739+
696740def plot_variable_inclusion (idata , X , labels = None , figsize = None , plot_kwargs = None , ax = None ):
697741 """
698742 Plot normalized variable inclusion from BART model.
@@ -720,26 +764,15 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non
720764
721765 Returns
722766 -------
723- idxs: indexes of the covariates from higher to lower relative importance
724767 axes: matplotlib axes
725768 """
726769 if plot_kwargs is None :
727770 plot_kwargs = {}
728771
729- VIs = idata ["sample_stats" ]["variable_inclusion" ].mean (("chain" , "draw" )).values
730- VIs = VIs / VIs .sum ()
731- idxs = np .argsort (VIs )
732-
733- indices = idxs [::- 1 ]
734- n_vars = len (indices )
735-
736- if hasattr (X , "columns" ) and hasattr (X , "to_numpy" ):
737- labels = X .columns
772+ VI_norm , labels = get_variable_inclusion (idata , X , labels )
773+ n_vars = len (labels )
738774
739- if labels is None :
740- labels = np .arange (n_vars ).astype (str )
741-
742- new_labels = ["+ " + ele if index != 0 else ele for index , ele in enumerate (labels [indices ])]
775+ new_labels = ["+ " + ele if index != 0 else ele for index , ele in enumerate (labels )]
743776
744777 ticks = np .arange (n_vars , dtype = int )
745778
@@ -749,19 +782,18 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non
749782 if ax is None :
750783 _ , ax = plt .subplots (1 , 1 , figsize = figsize )
751784
785+ ax .axhline (1 / n_vars , color = "0.5" , linestyle = "--" )
752786 ax .plot (
753- VIs [ indices ] ,
787+ VI_norm ,
754788 color = plot_kwargs .get ("color" , "k" ),
755789 marker = plot_kwargs .get ("marker" , "o" ),
756790 ls = plot_kwargs .get ("ls" , "-" ),
757791 )
758792
759793 ax .set_xticks (ticks , new_labels , rotation = plot_kwargs .get ("rotation" , 0 ))
760-
761- ax .axhline (1 / n_vars , color = "0.5" , linestyle = "--" )
762794 ax .set_ylim (0 , 1 )
763795
764- return idxs , ax
796+ return ax
765797
766798
767799def compute_variable_importance ( # noqa: PLR0915 PLR0912
0 commit comments