55import numpy as np
66
77from aesara .tensor .var import Variable
8- from numpy .random import RandomState
98from scipy .interpolate import griddata
109from scipy .signal import savgol_filter
1110from scipy .stats import pearsonr
1211
1312
14- def predict ( bartrv , rng , X , size = None , excluded = None ):
13+ def _sample_posterior ( all_trees , X , rng , size = None , excluded = None ):
1514 """
1615 Generate samples from the BART-posterior.
1716
1817 Parameters
1918 ----------
20- bartrv : BART Random Variable
21- BART variable once the model that include it has been fitted.
22- rng: NumPy random generator
19+ all_trees : list
20+ List of all trees sampled from a posterior
2321 X : array-like
2422 A covariate matrix. Use the same used to fit BART for in-sample predictions or a new one for
2523 out-of-sample predictions.
24+ rng : numpy.random.Generator
2625 size : int or tuple
2726 Number of samples.
2827 excluded : list
29- indexes of the variables to exclude when computing predictions
28+ Indexes of the variables to exclude when computing predictions
3029 """
31- stacked_trees = bartrv . owner . op . all_trees
30+ stacked_trees = all_trees
3231 if isinstance (X , Variable ):
3332 X = X .eval ()
3433
@@ -41,7 +40,7 @@ def predict(bartrv, rng, X, size=None, excluded=None):
4140 for s in size :
4241 flatten_size *= s
4342
44- idx = rng .randint ( len (stacked_trees ), size = flatten_size )
43+ idx = rng .integers ( 0 , len (stacked_trees ), size = flatten_size )
4544 shape = stacked_trees [0 ][0 ].predict (X [0 ]).size
4645
4746 pred = np .zeros ((flatten_size , X .shape [0 ], shape ))
@@ -53,35 +52,6 @@ def predict(bartrv, rng, X, size=None, excluded=None):
5352 return pred
5453
5554
56- def sample_posterior (all_trees , X ):
57- """
58- Generate samples from the BART-posterior.
59-
60- Parameters
61- ----------
62- all_trees : list
63- List of all trees sampled from a posterior
64- X : array-like
65- A covariate matrix. Use the same used to fit BART for in-sample predictions or a new one for
66- out-of-sample predictions.
67- m : int
68- Number of trees
69- """
70- stacked_trees = all_trees
71- idx = np .random .randint (len (stacked_trees ))
72- if isinstance (X , Variable ):
73- X = X .eval ()
74-
75- shape = stacked_trees [0 ][0 ].predict (X [0 ]).size
76-
77- pred = np .zeros ((1 , X .shape [0 ], shape ))
78-
79- for p in pred :
80- for tree in stacked_trees [idx ]:
81- p += np .array ([tree .predict (x ) for x in X ])
82- return pred .squeeze ()
83-
84-
8555def plot_dependence (
8656 bartrv ,
8757 X ,
@@ -179,8 +149,6 @@ def plot_dependence(
179149 Available option are 'insample', 'linear' or 'quantiles'"""
180150 )
181151
182- rng = RandomState (seed = random_seed )
183-
184152 if isinstance (X , Variable ):
185153 X = X .eval ()
186154
@@ -195,6 +163,8 @@ def plot_dependence(
195163 else :
196164 y_label = "Predicted Y"
197165
166+ rng = np .random .default_rng (random_seed )
167+
198168 num_covariates = X .shape [1 ]
199169
200170 indices = list (range (num_covariates ))
@@ -216,14 +186,15 @@ def plot_dependence(
216186 xs_values = [0.05 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.95 ]
217187
218188 if kind == "ice" :
219- instances = np . random .choice (range (X .shape [0 ]), replace = False , size = instances )
189+ instances = rng .choice (range (X .shape [0 ]), replace = False , size = instances )
220190
221191 new_y = []
222192 new_x_target = []
223193 y_mins = []
224194
225195 new_X = np .zeros_like (X )
226196 idx_s = list (range (X .shape [0 ]))
197+ all_trees = bartrv .owner .op .all_trees
227198 for i in var_idx :
228199 indices_mi = indices [:]
229200 indices_mi .pop (i )
@@ -242,13 +213,17 @@ def plot_dependence(
242213 for x_i in new_x_i :
243214 new_X [:, indices_mi ] = X [:, indices_mi ]
244215 new_X [:, i ] = x_i
245- y_pred .append (np .mean (predict (bartrv , rng , X = new_X , size = samples ), 1 ))
216+ y_pred .append (
217+ np .mean (_sample_posterior (all_trees , X = new_X , rng = rng , size = samples ), 1 )
218+ )
246219 new_x_target .append (new_x_i )
247220 else :
248221 for instance in instances :
249222 new_X = X [idx_s ]
250223 new_X [:, indices_mi ] = X [:, indices_mi ][instance ]
251- y_pred .append (np .mean (predict (bartrv , rng , X = new_X , size = samples ), 0 ))
224+ y_pred .append (
225+ np .mean (_sample_posterior (all_trees , X = new_X , rng = rng , size = samples ), 0 )
226+ )
252227 new_x_target .append (new_X [:, i ])
253228 y_mins .append (np .min (y_pred ))
254229 new_y .append (np .array (y_pred ).T )
@@ -328,7 +303,7 @@ def plot_dependence(
328303 nxi ,
329304 nyi ,
330305 smooth = smooth ,
331- fill_kwargs = {"alpha" : alpha },
306+ fill_kwargs = {"alpha" : alpha , "color" : color },
332307 ax = ax ,
333308 )
334309 ax .plot (nxi [idx ], nyi [idx ].mean (0 ), color = color )
@@ -374,7 +349,6 @@ def plot_variable_importance(
374349 idxs: indexes of the covariates from higher to lower relative importance
375350 axes: matplotlib axes
376351 """
377- rng = RandomState (seed = random_seed )
378352 _ , axes = plt .subplots (2 , 1 , figsize = figsize )
379353
380354 if hasattr (X , "columns" ) and hasattr (X , "values" ):
@@ -387,6 +361,8 @@ def plot_variable_importance(
387361 else :
388362 labels = np .array (labels )
389363
364+ rng = np .random .default_rng (random_seed )
365+
390366 ticks = np .arange (len (var_imp ), dtype = int )
391367 idxs = np .argsort (var_imp )
392368 subsets = [idxs [:- i ] for i in range (1 , len (idxs ))]
@@ -402,12 +378,14 @@ def plot_variable_importance(
402378 axes [0 ].set_xlabel ("covariables" )
403379 axes [0 ].set_ylabel ("importance" )
404380
405- predicted_all = predict (bartrv , rng , X = X , size = samples , excluded = None )
381+ all_trees = bartrv .owner .op .all_trees
382+
383+ predicted_all = _sample_posterior (all_trees , X = X , rng = rng , size = samples , excluded = None )
406384
407385 ev_mean = np .zeros (len (var_imp ))
408386 ev_hdi = np .zeros ((len (var_imp ), 2 ))
409387 for idx , subset in enumerate (subsets ):
410- predicted_subset = predict ( bartrv , rng , X = X , size = samples , excluded = subset )
388+ predicted_subset = _sample_posterior ( all_trees , X = X , rng = rng , size = samples , excluded = subset )
411389 pearson = np .zeros (samples )
412390 for j in range (samples ):
413391 pearson [j ] = (
0 commit comments