@@ -138,6 +138,11 @@ def __init__( # noqa: PLR0915
138138 else :
139139 self .X = self .bart .X
140140
141+ if isinstance (self .bart .Y , Variable ):
142+ self .Y = self .bart .Y .eval ()
143+ else :
144+ self .Y = self .bart .Y
145+
141146 self .missing_data = np .any (np .isnan (self .X ))
142147 self .m = self .bart .m
143148 self .response = self .bart .response
@@ -166,26 +171,26 @@ def __init__( # noqa: PLR0915
166171 if rule is ContinuousSplitRule :
167172 self .X [:, idx ] = jitter_duplicated (self .X [:, idx ], np .nanstd (self .X [:, idx ]))
168173
169- init_mean = self .bart . Y .mean ()
174+ init_mean = self .Y .mean ()
170175 self .num_observations = self .X .shape [0 ]
171176 self .num_variates = self .X .shape [1 ]
172177 self .available_predictors = list (range (self .num_variates ))
173178
174179 # if data is binary
175180 self .leaf_sd = np .ones ((self .trees_shape , self .leaves_shape ))
176181
177- y_unique = np .unique (self .bart . Y )
182+ y_unique = np .unique (self .Y )
178183 if y_unique .size == 2 and np .all (y_unique == [0 , 1 ]):
179184 self .leaf_sd *= 3 / self .m ** 0.5
180185 else :
181- self .leaf_sd *= self .bart . Y .std () / self .m ** 0.5
186+ self .leaf_sd *= self .Y .std () / self .m ** 0.5
182187
183188 self .running_sd = [
184189 RunningSd ((self .leaves_shape , self .num_observations )) for _ in range (self .trees_shape )
185190 ]
186191
187192 self .sum_trees = np .full (
188- (self .trees_shape , self .leaves_shape , self .bart . Y .shape [0 ]), init_mean
193+ (self .trees_shape , self .leaves_shape , self .Y .shape [0 ]), init_mean
189194 ).astype (config .floatX )
190195 self .sum_trees_noi = self .sum_trees - init_mean
191196 self .a_tree = Tree .new_tree (
0 commit comments