 
 import logging
 
-from copy import copy
+from copy import deepcopy
+from numba import njit
 
 import aesara
 import numpy as np
@@ -56,7 +57,7 @@ class PGBART(ArrayStepShared):
     def __init__(
         self,
         vars=None,
-        num_particles=40,
+        num_particles=20,
         batch="auto",
         model=None,
     ):
@@ -104,8 +105,6 @@ def __init__(
             idx_data_points=np.arange(self.num_observations, dtype="int32"),
             shape=self.shape,
         )
-        self.mean = fast_mean()
-
         self.normal = NormalSampler(mu_std, self.shape)
         self.uniform = UniformSampler(0.33, 0.75, self.shape)
         self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
@@ -158,7 +157,6 @@ def astep(self, _):
                         self.X,
                         self.missing_data,
                         self.sum_trees,
-                        self.mean,
                         self.m,
                         self.normal,
                         self.shape,
@@ -173,11 +171,8 @@ def astep(self, _):
                 # Normalize weights
                 w_t, normalized_weights = self.normalize(particles[2:])
 
-                # Resample all but first two particles
-                new_indices = np.random.choice(
-                    self.indices, size=self.len_indices, p=normalized_weights
-                )
-                particles[2:] = particles[new_indices]
+                # Resample
+                particles = self.resample(particles, normalized_weights)
 
                 # Set the new weight
                 for p in particles[2:]:
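
The removed np.random.choice call draws indices i.i.d. from the weights (multinomial resampling); the new self.resample path uses systematic resampling, which spreads a single uniform draw over evenly spaced points and gives lower-variance particle counts. A minimal sketch of the two schemes on a toy weight vector (illustrative only, not part of the patch):

    import numpy as np

    weights = np.array([0.1, 0.2, 0.3, 0.4])  # normalized particle weights
    n = len(weights)

    # Multinomial: n independent draws; counts can wander far from n * w
    multinomial_idx = np.random.choice(n, size=n, p=weights)

    # Systematic: one uniform shifted to n evenly spaced points; each index
    # is selected floor(n * w) or ceil(n * w) times
    u = (np.random.rand() + np.arange(n)) / n
    systematic_idx = np.searchsorted(np.cumsum(weights), u)
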
@@ -196,15 +191,17 @@ def astep(self, _):
             self.sum_trees = self.sum_trees_noi + new_tree._predict()
             self.all_trees[tree_id] = new_tree.trim()
 
+            used_variates = new_tree.get_split_variables()
+
             if self.tune:
                 self.ssv = SampleSplittingVariable(self.alpha_vec)
-                for index in new_particle.used_variates:
+                for index in used_variates:
                     self.alpha_vec[index] += 1
             else:
-                for index in new_particle.used_variates:
+                for index in used_variates:
                     variable_inclusion[index] += 1
 
-        stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)}
+        stats = {"variable_inclusion": variable_inclusion, "bart_trees": self.all_trees}
         return self.sum_trees, [stats]
 
     def normalize(self, particles):
@@ -225,18 +222,36 @@ def normalize(self, particles):
 
         return w_t, normalized_weights
 
+    def resample(self, particles, normalized_weights):
+        """
+        Use systematic resampling for all but the first two particles.
+
+        Ensure particles are copied only if needed.
+        """
+        new_indices = systematic(normalized_weights)
+        seen = []
+        new_particles = []
+        for idx in new_indices:
+            if idx in seen:
+                new_particles.append(deepcopy(particles[idx]))
+            else:
+                new_particles.append(particles[idx])
+                seen.append(idx)
+
+        particles[2:] = new_particles
+
+        return particles
+
     def init_particles(self, tree_id: int) -> np.ndarray:
         """Initialize particles."""
         p0 = self.all_particles[tree_id]
-        p1 = copy(p0)
+        p1 = deepcopy(p0)
         p1.sample_leafs(
             self.sum_trees,
-            self.mean,
             self.m,
             self.normal,
             self.shape,
         )
-
         # The old tree and the one with new leafs do not grow so we update the weights only once
         self.update_weight(p0, old=True)
         self.update_weight(p1, old=True)
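
The resample method above avoids unnecessary deep copies: the first time an index is drawn, the existing particle object is reused, and only repeated draws are deep-copied, since two slots pointing at the same tree would otherwise be mutated together in later growth steps. A self-contained sketch of the pattern (toy list particles, not the actual ParticleTree objects):

    from copy import deepcopy

    particles = [["t0"], ["t1"], ["t2"], ["t3"]]
    new_indices = [2, 2, 3]  # e.g. output of systematic resampling

    seen = []
    resampled = []
    for idx in new_indices:
        if idx in seen:
            # Repeated index: take an independent copy
            resampled.append(deepcopy(particles[idx]))
        else:
            # First occurrence: share the existing object, no copy needed
            resampled.append(particles[idx])
            seen.append(idx)

    assert resampled[0] is particles[2]      # shared, zero-copy
    assert resampled[1] is not particles[2]  # duplicate got its own copy
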
@@ -286,7 +301,6 @@ def __init__(self, tree):
         self.expansion_nodes = [0]
         self.log_weight = 0
         self.old_likelihood_logp = 0
-        self.used_variates = []
         self.kf = 0.75
 
     def sample_tree(
@@ -297,7 +311,6 @@ def sample_tree(
         X,
         missing_data,
         sum_trees,
-        mean,
         m,
         normal,
         shape,
@@ -317,7 +330,6 @@ def sample_tree(
                     X,
                     missing_data,
                     sum_trees,
-                    mean,
                     m,
                     normal,
                     self.kf,
@@ -326,20 +338,18 @@ def sample_tree(
                 if index_selected_predictor is not None:
                     new_indexes = self.tree.idx_leaf_nodes[-2:]
                     self.expansion_nodes.extend(new_indexes)
-                    self.used_variates.append(index_selected_predictor)
                     tree_grew = True
 
         return tree_grew
 
-    def sample_leafs(self, sum_trees, mean, m, normal, shape):
+    def sample_leafs(self, sum_trees, m, normal, shape):
 
         for idx in self.tree.idx_leaf_nodes:
             if idx > 0:
                 leaf = self.tree[idx]
                 idx_data_points = leaf.idx_data_points
                 node_value = draw_leaf_value(
                     sum_trees[:, idx_data_points],
-                    mean,
                     m,
                     normal,
                     self.kf,
@@ -400,7 +410,6 @@ def grow_tree(
     X,
     missing_data,
     sum_trees,
-    mean,
     m,
     normal,
     kf,
@@ -429,7 +438,6 @@ def grow_tree(
         idx_data_point = new_idx_data_points[idx]
         node_value = draw_leaf_value(
             sum_trees[:, idx_data_point],
-            mean,
             m,
             normal,
             kf,
@@ -482,7 +490,7 @@ def get_split_value(available_splitting_values, idx_data_points, missing_data):
     return split_value
 
 
-def draw_leaf_value(Y_mu_pred, mean, m, normal, kf, shape):
+def draw_leaf_value(Y_mu_pred, m, normal, kf, shape):
     """Draw Gaussian distributed leaf values."""
     if Y_mu_pred.size == 0:
         return np.zeros(shape)
@@ -491,38 +499,29 @@ def draw_leaf_value(Y_mu_pred, mean, m, normal, kf, shape):
     if Y_mu_pred.size == 1:
         mu_mean = np.full(shape, Y_mu_pred.item() / m)
     else:
-        mu_mean = mean(Y_mu_pred) / m
+        mu_mean = fast_mean(Y_mu_pred) / m
 
     draw = norm + mu_mean
     return draw
 
 
-def fast_mean():
-    """If available use Numba to speed up the computation of the mean."""
-    try:
-        from numba import jit
-    except ImportError:
-        from functools import partial
-
-        return partial(np.mean, axis=1)
-
-    @jit
-    def mean(a):
-        if a.ndim == 1:
-            count = a.shape[0]
-            suma = 0
+@njit
+def fast_mean(a):
+    """Use Numba to speed up the computation of the mean."""
+
+    if a.ndim == 1:
+        count = a.shape[0]
+        suma = 0
+        for i in range(count):
+            suma += a[i]
+        return suma / count
+    elif a.ndim == 2:
+        res = np.zeros(a.shape[0])
+        count = a.shape[1]
+        for j in range(a.shape[0]):
             for i in range(count):
-                suma += a[i]
-            return suma / count
-        elif a.ndim == 2:
-            res = np.zeros(a.shape[0])
-            count = a.shape[1]
-            for j in range(a.shape[0]):
-                for i in range(count):
-                    res[j] += a[j, i]
-            return res / count
-
-    return mean
+                res[j] += a[j, i]
+        return res / count
 
 
 def discrete_uniform_sampler(upper_value):
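
A usage note on fast_mean above: with the module-level @njit decorator, compilation happens on the first call (Numba becomes a hard dependency via the new top-level import, replacing the old try/except fallback to np.mean), and the 2-D branch computes a per-row mean, matching np.mean(a, axis=1). A quick sanity check, assuming fast_mean is in scope:

    import numpy as np

    a1 = np.random.rand(5)     # 1-D input: scalar mean
    a2 = np.random.rand(3, 7)  # 2-D input: one mean per row
    np.testing.assert_allclose(fast_mean(a1), a1.mean())
    np.testing.assert_allclose(fast_mean(a2), a2.mean(axis=1))
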
@@ -578,6 +577,51 @@ def update(self):
         )
 
 
+def systematic(normalized_weights):
+    """
+    Systematic resampling.
+
+    Return indices in the range 2, ..., len(normalized_weights) + 1,
+    i.e. offset by 2 so the first two particles are never selected.
+
+    Note: adapted from https://github.com/nchopin/particles
+    """
+    lnw = len(normalized_weights)
+    single_uniform = (np.random.rand(1) + np.arange(lnw)) / lnw
+    return inverse_cdf(single_uniform, normalized_weights) + 2
+
+
+@njit
+def inverse_cdf(single_uniform, normalized_weights):
+    """
+    Inverse CDF algorithm for a finite distribution.
+
+    Parameters
+    ----------
+    single_uniform: ndarray
+        Ordered points in [0, 1]
+    normalized_weights: ndarray
+        Normalized weights
+
+    Returns
+    -------
+    A: ndarray
+        A vector of indices in the range 0, ..., len(normalized_weights) - 1
+
+    Note: adapted from https://github.com/nchopin/particles
+    """
+    j = 0
+    s = normalized_weights[0]
+    M = single_uniform.shape[0]
+    A = np.empty(M, dtype=np.int64)
+    for n in range(M):
+        while single_uniform[n] > s:
+            j += 1
+            s += normalized_weights[j]
+        A[n] = j
+    return A
+
+
 def logp(point, out_vars, vars, shared):
     """Compile Aesara function of the model and the input and output variables.
 
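
A small usage sketch of the two resampling helpers above (assuming both are in scope); the + 2 offset makes the returned indices point into particles[2:], matching the resample method, which never replaces the first two particles:

    import numpy as np

    np.random.seed(0)
    weights = np.array([0.1, 0.2, 0.3, 0.4])
    print(systematic(weights))  # [3 4 5 5] with this seed

    # Systematic resampling keeps counts close to expectation: with equal
    # weights every particle is selected exactly once
    idx = systematic(np.array([0.5, 0.5])) - 2
    assert sorted(idx.tolist()) == [0, 1]
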