@@ -80,47 +80,43 @@ def __init__(
8080 else :
8181 self .X = self .bart .X
8282
83- self .Y = self .bart .Y
8483 self .missing_data = np .any (np .isnan (self .X ))
8584 self .m = self .bart .m
86- self .alpha = self .bart .alpha
8785 shape = initial_values [value_bart .name ].shape
8886 if len (shape ) == 1 :
8987 self .shape = 1
9088 else :
9189 self .shape = shape [0 ]
9290
93- # self.alpha_vec = self.bart.split_prior
9491 if self .bart .split_prior :
9592 self .alpha_vec = self .bart .split_prior
9693 else :
9794 self .alpha_vec = np .ones (self .X .shape [1 ])
98- self . init_mean = self .Y .mean ()
95+ init_mean = self . bart .Y .mean ()
9996 # if data is binary
100- y_unique = np .unique (self .Y )
97+ y_unique = np .unique (self .bart . Y )
10198 if y_unique .size == 2 and np .all (y_unique == [0 , 1 ]):
10299 mu_std = 3 / self .m ** 0.5
103- # maybe we need to check for count data
104100 else :
105- mu_std = self .Y .std () / self .m ** 0.5
101+ mu_std = self .bart . Y .std () / self .m ** 0.5
106102
107103 self .num_observations = self .X .shape [0 ]
108104 self .num_variates = self .X .shape [1 ]
109105 self .available_predictors = list (range (self .num_variates ))
110106
111- self .sum_trees = np .full ((self .shape , self .Y .shape [0 ]), self . init_mean ).astype (
107+ self .sum_trees = np .full ((self .shape , self .bart . Y .shape [0 ]), init_mean ).astype (
112108 config .floatX
113109 )
114- self .sum_trees_noi = self .sum_trees - (self . init_mean / self .m )
110+ self .sum_trees_noi = self .sum_trees - (init_mean / self .m )
115111 self .a_tree = Tree (
116- leaf_node_value = self . init_mean / self .m ,
112+ leaf_node_value = init_mean / self .m ,
117113 idx_data_points = np .arange (self .num_observations , dtype = "int32" ),
118114 num_observations = self .num_observations ,
119115 shape = self .shape ,
120116 )
121117 self .normal = NormalSampler (mu_std , self .shape )
122118 self .uniform = UniformSampler (0.33 , 0.75 , self .shape )
123- self .prior_prob_leaf_node = compute_prior_probability (self .alpha )
119+ self .prior_prob_leaf_node = compute_prior_probability (self .bart . alpha )
124120 self .ssv = SampleSplittingVariable (self .alpha_vec )
125121
126122 self .tune = True
@@ -143,7 +139,7 @@ def __init__(
143139 self .likelihood_logp = logp (initial_values , [model .datalogp ], vars , shared )
144140 self .all_particles = []
145141 for _ in range (self .m ):
146- self .a_tree .leaf_node_value = self . init_mean / self .m
142+ self .a_tree .leaf_node_value = init_mean / self .m
147143 p = ParticleTree (self .a_tree )
148144 self .all_particles .append (p )
149145 self .all_trees = np .array ([p .tree for p in self .all_particles ])
0 commit comments