@@ -134,9 +134,17 @@ def __init__(
134134 self .missing_data = np .any (np .isnan (self .X ))
135135 self .m = self .bart .m
136136 self .response = self .bart .response
137+
137138 shape = initial_values [value_bart .name ].shape
139+
138140 self .shape = 1 if len (shape ) == 1 else shape [0 ]
139141
142+ # Set trees_shape (dim for separate tree structures)
143+ # and leaves_shape (dim for leaf node values)
144+ # One of the two is always one, the other equal to self.shape
145+ self .trees_shape = self .shape if self .bart .separate_trees else 1
146+ self .leaves_shape = self .shape if not self .bart .separate_trees else 1
147+
140148 if self .bart .split_prior :
141149 self .alpha_vec = self .bart .split_prior
142150 else :
@@ -153,27 +161,31 @@ def __init__(
153161 self .available_predictors = list (range (self .num_variates ))
154162
155163 # if data is binary
164+ self .leaf_sd = np .ones ((self .trees_shape , self .leaves_shape ))
165+
156166 y_unique = np .unique (self .bart .Y )
157167 if y_unique .size == 2 and np .all (y_unique == [0 , 1 ]):
158- self .leaf_sd = 3 / self .m ** 0.5
168+ self .leaf_sd * = 3 / self .m ** 0.5
159169 else :
160- self .leaf_sd = self .bart .Y .std () / self .m ** 0.5
170+ self .leaf_sd * = self .bart .Y .std () / self .m ** 0.5
161171
162- self .running_sd = RunningSd (shape )
172+ self .running_sd = [
173+ RunningSd ((self .leaves_shape , self .num_observations )) for _ in range (self .trees_shape )
174+ ]
163175
164- self .sum_trees = np .full (( self . shape , self . bart . Y . shape [ 0 ]), init_mean ). astype (
165- config . floatX
166- )
176+ self .sum_trees = np .full (
177+ ( self . trees_shape , self . leaves_shape , self . bart . Y . shape [ 0 ]), init_mean
178+ ). astype ( config . floatX )
167179 self .sum_trees_noi = self .sum_trees - init_mean
168180 self .a_tree = Tree .new_tree (
169181 leaf_node_value = init_mean / self .m ,
170182 idx_data_points = np .arange (self .num_observations , dtype = "int32" ),
171183 num_observations = self .num_observations ,
172- shape = self .shape ,
184+ shape = self .leaves_shape ,
173185 split_rules = self .split_rules ,
174186 )
175187
176- self .normal = NormalSampler (1 , self .shape )
188+ self .normal = NormalSampler (1 , self .leaves_shape )
177189 self .uniform = UniformSampler (0 , 1 )
178190 self .prior_prob_leaf_node = compute_prior_probability (self .bart .alpha , self .bart .beta )
179191 self .ssv = SampleSplittingVariable (self .alpha_vec )
@@ -188,8 +200,10 @@ def __init__(
188200 self .indices = list (range (1 , num_particles ))
189201 shared = make_shared_replacements (initial_values , vars , model )
190202 self .likelihood_logp = logp (initial_values , [model .datalogp ], vars , shared )
191- self .all_particles = [ParticleTree (self .a_tree ) for _ in range (self .m )]
192- self .all_trees = np .array ([p .tree for p in self .all_particles ])
203+ self .all_particles = [
204+ [ParticleTree (self .a_tree ) for _ in range (self .m )] for _ in range (self .trees_shape )
205+ ]
206+ self .all_trees = np .array ([[p .tree for p in pl ] for pl in self .all_particles ])
193207 self .lower = 0
194208 self .iter = 0
195209 super ().__init__ (vars , shared )
@@ -201,72 +215,75 @@ def astep(self, _):
201215 tree_ids = range (self .lower , upper )
202216 self .lower = upper if upper < self .m else 0
203217
204- for tree_id in tree_ids :
205- self .iter += 1
206- # Compute the sum of trees without the old tree that we are attempting to replace
207- self .sum_trees_noi = self .sum_trees - self .all_particles [tree_id ].tree ._predict ()
208- # Generate an initial set of particles
209- # at the end we return one of these particles as the new tree
210- particles = self .init_particles (tree_id )
211-
212- while True :
213- # Sample each particle (try to grow each tree), except for the first one
214- stop_growing = True
215- for p in particles [1 :]:
216- if p .sample_tree (
217- self .ssv ,
218- self .available_predictors ,
219- self .prior_prob_leaf_node ,
220- self .X ,
221- self .missing_data ,
222- self .sum_trees ,
223- self .leaf_sd ,
224- self .m ,
225- self .response ,
226- self .normal ,
227- self .shape ,
228- ):
229- self .update_weight (p )
230- if p .expansion_nodes :
231- stop_growing = False
232- if stop_growing :
233- break
234-
235- # Normalize weights
236- normalized_weights = self .normalize (particles [1 :])
237-
238- # Resample
239- particles = self .resample (particles , normalized_weights )
240-
241- normalized_weights = self .normalize (particles )
242- # Get the new particle and associated tree
243- self .all_particles [tree_id ], new_tree = self .get_particle_tree (
244- particles , normalized_weights
245- )
246- # Update the sum of trees
247- new = new_tree ._predict ()
248- self .sum_trees = self .sum_trees_noi + new
249- # To reduce memory usage, we trim the tree
250- self .all_trees [tree_id ] = new_tree .trim ()
251-
252- if self .tune :
253- # Update the splitting variable and the splitting variable sampler
254- if self .iter > self .m :
255- self .ssv = SampleSplittingVariable (self .alpha_vec )
256-
257- for index in new_tree .get_split_variables ():
258- self .alpha_vec [index ] += 1
259-
260- # update standard deviation at leaf nodes
261- if self .iter > 2 :
262- self .leaf_sd = self .running_sd .update (new )
263- else :
264- self .running_sd .update (new )
218+ for odim in range (self .trees_shape ):
219+ for tree_id in tree_ids :
220+ self .iter += 1
221+ # Compute the sum of trees without the old tree that we are attempting to replace
222+ self .sum_trees_noi [odim ] = (
223+ self .sum_trees [odim ] - self .all_particles [odim ][tree_id ].tree ._predict ()
224+ )
225+ # Generate an initial set of particles
226+ # at the end we return one of these particles as the new tree
227+ particles = self .init_particles (tree_id , odim )
228+
229+ while True :
230+ # Sample each particle (try to grow each tree), except for the first one
231+ stop_growing = True
232+ for p in particles [1 :]:
233+ if p .sample_tree (
234+ self .ssv ,
235+ self .available_predictors ,
236+ self .prior_prob_leaf_node ,
237+ self .X ,
238+ self .missing_data ,
239+ self .sum_trees [odim ],
240+ self .leaf_sd [odim ],
241+ self .m ,
242+ self .response ,
243+ self .normal ,
244+ self .leaves_shape ,
245+ ):
246+ self .update_weight (p , odim )
247+ if p .expansion_nodes :
248+ stop_growing = False
249+ if stop_growing :
250+ break
251+
252+ # Normalize weights
253+ normalized_weights = self .normalize (particles [1 :])
254+
255+ # Resample
256+ particles = self .resample (particles , normalized_weights )
257+
258+ normalized_weights = self .normalize (particles )
259+ # Get the new particle and associated tree
260+ self .all_particles [odim ][tree_id ], new_tree = self .get_particle_tree (
261+ particles , normalized_weights
262+ )
263+ # Update the sum of trees
264+ new = new_tree ._predict ()
265+ self .sum_trees [odim ] = self .sum_trees_noi [odim ] + new
266+ # To reduce memory usage, we trim the tree
267+ self .all_trees [odim ][tree_id ] = new_tree .trim ()
268+
269+ if self .tune :
270+ # Update the splitting variable and the splitting variable sampler
271+ if self .iter > self .m :
272+ self .ssv = SampleSplittingVariable (self .alpha_vec )
273+
274+ for index in new_tree .get_split_variables ():
275+ self .alpha_vec [index ] += 1
276+
277+ # update standard deviation at leaf nodes
278+ if self .iter > 2 :
279+ self .leaf_sd [odim ] = self .running_sd [odim ].update (new )
280+ else :
281+ self .running_sd [odim ].update (new )
265282
266- else :
267- # update the variable inclusion
268- for index in new_tree .get_split_variables ():
269- variable_inclusion [index ] += 1
283+ else :
284+ # update the variable inclusion
285+ for index in new_tree .get_split_variables ():
286+ variable_inclusion [index ] += 1
270287
271288 if not self .tune :
272289 self .bart .all_trees .append (self .all_trees )
@@ -331,23 +348,27 @@ def systematic(self, normalized_weights: npt.NDArray[np.float_]) -> npt.NDArray[
331348 single_uniform = (self .uniform .rvs () + np .arange (lnw )) / lnw
332349 return inverse_cdf (single_uniform , normalized_weights )
333350
334- def init_particles (self , tree_id : int ) -> List [ParticleTree ]:
351+ def init_particles (self , tree_id : int , odim : int ) -> List [ParticleTree ]:
335352 """Initialize particles."""
336- p0 : ParticleTree = self .all_particles [tree_id ]
353+ p0 : ParticleTree = self .all_particles [odim ][ tree_id ]
337354 # The old tree does not grow so we update the weight only once
338- self .update_weight (p0 )
355+ self .update_weight (p0 , odim )
339356 particles : List [ParticleTree ] = [p0 ]
340357
341358 particles .extend (ParticleTree (self .a_tree ) for _ in self .indices )
342359 return particles
343360
344- def update_weight (self , particle : ParticleTree ) -> None :
361+ def update_weight (self , particle : ParticleTree , odim : int ) -> None :
345362 """
346363 Update the weight of a particle.
347364 """
348- new_likelihood = self .likelihood_logp (
349- (self .sum_trees_noi + particle .tree ._predict ()).flatten ()
365+
366+ delta = (
367+ np .identity (self .trees_shape )[odim ][:, None , None ]
368+ * particle .tree ._predict ()[None , :, :]
350369 )
370+
371+ new_likelihood = self .likelihood_logp ((self .sum_trees_noi + delta ).flatten ())
351372 particle .log_weight = new_likelihood
352373
353374 @staticmethod
0 commit comments