@@ -123,6 +123,7 @@ def __init__(
123123 else :
124124 self .batch = (batch , batch )
125125
126+ self .num_particles = num_particles
126127 self .log_num_particles = np .log (num_particles )
127128 self .indices = list (range (2 , num_particles ))
128129 self .len_indices = len (self .indices )
@@ -185,14 +186,10 @@ def astep(self, _):
185186
186187 _ , normalized_weights = self .normalize (particles )
187188 # Get the new tree and update
188- new_particle = np .random .choice (particles , p = normalized_weights )
189- new_tree = new_particle .tree
190-
191- new_particle .log_weight = new_particle .old_likelihood_logp - self .log_num_particles
189+ new_particle , new_tree = self .get_particle_tree (particles , normalized_weights )
192190 self .all_particles [tree_id ] = new_particle
193191 self .sum_trees = self .sum_trees_noi + new_tree ._predict ()
194192 self .all_trees [tree_id ] = new_tree .trim ()
195-
196193 used_variates = new_tree .get_split_variables ()
197194
198195 if self .tune :
@@ -230,7 +227,7 @@ def resample(self, particles, normalized_weights):
230227
231228 Ensure particles are copied only if needed.
232229 """
233- new_indices = systematic (normalized_weights )
230+ new_indices = self . systematic (normalized_weights )
234231 seen = []
235232 new_particles = []
236233 for idx in new_indices :
@@ -244,6 +241,29 @@ def resample(self, particles, normalized_weights):
244241
245242 return particles
246243
244+ def get_particle_tree (self , particles , normalized_weights ):
245+ """
246+ Sample a new particle, new tree and update log_weight
247+ """
248+ new_index = self .systematic (normalized_weights )[
249+ discrete_uniform_sampler (self .num_particles )
250+ ]
251+ new_particle = particles [new_index - 2 ]
252+ new_particle .log_weight = new_particle .old_likelihood_logp - self .log_num_particles
253+ return new_particle , new_particle .tree
254+
255+ def systematic (self , normalized_weights ):
256+ """
257+ Systematic resampling.
258+
259+ Return indices in the range 2, ..., len(normalized_weights)+2
260+
261+ Note: adapted from https://github.com/nchopin/particles
262+ """
263+ lnw = len (normalized_weights )
264+ single_uniform = (self .uniform .random () + np .arange (lnw )) / lnw
265+ return inverse_cdf (single_uniform , normalized_weights ) + 2
266+
247267 def init_particles (self , tree_id : int ) -> np .ndarray :
248268 """Initialize particles."""
249269 p0 = self .all_particles [tree_id ]
@@ -584,19 +604,6 @@ def update(self):
584604 )
585605
586606
587- def systematic (normalized_weights ):
588- """
589- Systematic resampling.
590-
591- Return indices in the range 2, ..., len(normalized_weights)+2
592-
593- Note: adapted from https://github.com/nchopin/particles
594- """
595- lnw = len (normalized_weights )
596- single_uniform = (np .random .rand (1 ) + np .arange (lnw )) / lnw
597- return inverse_cdf (single_uniform , normalized_weights ) + 2
598-
599-
600607@njit
601608def inverse_cdf (single_uniform , normalized_weights ):
602609 """
0 commit comments