44import matplotlib .pyplot as plt
55import numpy as np
66
7+ from aesara .tensor .var import Variable
78from numpy .random import RandomState
89from scipy .interpolate import griddata
910from scipy .signal import savgol_filter
1011from scipy .stats import pearsonr
1112
1213
13- def predict (idata , rng , X , size = None , excluded = None ):
14+ def predict (bartrv , rng , X , size = None , excluded = None ):
1415 """
1516 Generate samples from the BART-posterior.
1617
1718 Parameters
1819 ----------
19- idata : InferenceData
20- InferenceData containing a collection of BART_trees in sample_stats group
20+ bartrv : BART Random Variable
21+ BART variable once the model that include it has been fitted.
2122 rng: NumPy random generator
2223 X : array-like
2324 A covariate matrix. Use the same used to fit BART for in-sample predictions or a new one for
@@ -27,8 +28,10 @@ def predict(idata, rng, X, size=None, excluded=None):
2728 excluded : list
2829 indexes of the variables to exclude when computing predictions
2930 """
30- bart_trees = idata .sample_stats .bart_trees
31- stacked_trees = bart_trees .stack (trees = ["chain" , "draw" ])
31+ stacked_trees = bartrv .owner .op .all_trees
32+ if isinstance (X , Variable ):
33+ X = X .eval ()
34+
3235 if size is None :
3336 size = ()
3437 elif isinstance (size , int ):
@@ -38,20 +41,49 @@ def predict(idata, rng, X, size=None, excluded=None):
3841 for s in size :
3942 flatten_size *= s
4043
41- idx = rng .randint (len (stacked_trees . trees ), size = flatten_size )
42- shape = stacked_trees . isel ( trees = 0 ). values [0 ].predict (X [0 ]).size
44+ idx = rng .randint (len (stacked_trees ), size = flatten_size )
45+ shape = stacked_trees [ 0 ] [0 ].predict (X [0 ]).size
4346
4447 pred = np .zeros ((flatten_size , X .shape [0 ], shape ))
4548
4649 for ind , p in enumerate (pred ):
47- for tree in stacked_trees . isel ( trees = idx [ind ]). values :
50+ for tree in stacked_trees [ idx [ind ]] :
4851 p += np .array ([tree .predict (x , excluded ) for x in X ])
4952 pred .reshape ((* size , shape , - 1 ))
5053 return pred
5154
5255
def sample_posterior(all_trees, X):
    """
    Draw one sample from the BART-posterior.

    A single entry of ``all_trees`` is selected at random (using NumPy's global
    random state) and the predictions of every tree in that entry are summed
    for each row of ``X``.

    Parameters
    ----------
    all_trees : list
        List of all trees sampled from a posterior. Each entry is iterated as the
        collection of trees belonging to one posterior draw.
    X : array-like
        A covariate matrix. Use the same used to fit BART for in-sample predictions or a new one for
        out-of-sample predictions. A symbolic ``Variable`` is evaluated first.

    Returns
    -------
    numpy.ndarray
        Sum of the tree predictions for the selected draw, with all length-one
        axes removed.
    """
    # Pick the posterior draw before touching X, mirroring the original ordering.
    draw_index = np.random.randint(len(all_trees))

    if isinstance(X, Variable):
        X = X.eval()

    # Size of a single prediction, probed with the first tree on the first row.
    out_dim = all_trees[0][0].predict(X[0]).size

    total = np.zeros((X.shape[0], out_dim))
    for tree in all_trees[draw_index]:
        total += np.array([tree.predict(row) for row in X])

    return total.squeeze()
83+
84+
5385def plot_dependence (
54- idata ,
86+ bartrv ,
5587 X ,
5688 Y = None ,
5789 kind = "pdp" ,
@@ -79,8 +111,8 @@ def plot_dependence(
79111
80112 Parameters
81113 ----------
82- idata: InferenceData
83- InferenceData containing a collection of BART_trees in sample_stats group
114+ bartrv : BART Random Variable
115+ BART variable once the model that include it has been fitted.
84116 X : array-like
85117 The covariate matrix.
86118 Y : array-like
@@ -149,6 +181,9 @@ def plot_dependence(
149181
150182 rng = RandomState (seed = random_seed )
151183
184+ if isinstance (X , Variable ):
185+ X = X .eval ()
186+
152187 if hasattr (X , "columns" ) and hasattr (X , "values" ):
153188 x_names = list (X .columns )
154189 X = X .values
@@ -207,13 +242,13 @@ def plot_dependence(
207242 for x_i in new_x_i :
208243 new_X [:, indices_mi ] = X [:, indices_mi ]
209244 new_X [:, i ] = x_i
210- y_pred .append (np .mean (predict (idata , rng , X = new_X , size = samples ), 1 ))
245+ y_pred .append (np .mean (predict (bartrv , rng , X = new_X , size = samples ), 1 ))
211246 new_x_target .append (new_x_i )
212247 else :
213248 for instance in instances :
214249 new_X = X [idx_s ]
215250 new_X [:, indices_mi ] = X [:, indices_mi ][instance ]
216- y_pred .append (np .mean (predict (idata , rng , X = new_X , size = samples ), 0 ))
251+ y_pred .append (np .mean (predict (bartrv , rng , X = new_X , size = samples ), 0 ))
217252 new_x_target .append (new_X [:, i ])
218253 y_mins .append (np .min (y_pred ))
219254 new_y .append (np .array (y_pred ).T )
@@ -310,7 +345,7 @@ def plot_dependence(
310345
311346
312347def plot_variable_importance (
313- idata , X , labels = None , sort_vars = True , figsize = None , samples = 100 , random_seed = None
348+ idata , bartrv , X , labels = None , sort_vars = True , figsize = None , samples = 100 , random_seed = None
314349):
315350 """
316351 Estimates variable importance from the BART-posterior.
@@ -319,6 +354,8 @@ def plot_variable_importance(
319354 ----------
320355 idata: InferenceData
321356 InferenceData containing a collection of BART_trees in sample_stats group
357+ bartrv : BART Random Variable
358+ BART variable once the model that include it has been fitted.
322359 X : array-like
323360 The covariate matrix.
324361 labels : list
@@ -365,12 +402,12 @@ def plot_variable_importance(
365402 axes [0 ].set_xlabel ("covariables" )
366403 axes [0 ].set_ylabel ("importance" )
367404
368- predicted_all = predict (idata , rng , X = X , size = samples , excluded = None )
405+ predicted_all = predict (bartrv , rng , X = X , size = samples , excluded = None )
369406
370407 ev_mean = np .zeros (len (var_imp ))
371408 ev_hdi = np .zeros ((len (var_imp ), 2 ))
372409 for idx , subset in enumerate (subsets ):
373- predicted_subset = predict (idata , rng , X = X , size = samples , excluded = subset )
410+ predicted_subset = predict (bartrv , rng , X = X , size = samples , excluded = subset )
374411 pearson = np .zeros (samples )
375412 for j in range (samples ):
376413 pearson [j ] = (
0 commit comments