@@ -65,10 +65,6 @@ def new_leaf_node(
6565 linear_params = linear_params ,
6666 )
6767
68- @classmethod
69- def new_split_node (cls , split_value : npt .NDArray [np .float_ ], idx_split_variable : int ) -> "Node" :
70- return cls (value = split_value , idx_split_variable = idx_split_variable )
71-
7268 def is_split_node (self ) -> bool :
7369 return self .idx_split_variable >= 0
7470
@@ -282,42 +278,42 @@ def _traverse_tree(
282278 """
283279
284280 x_shape = (1 ,) if len (X .shape ) == 1 else X .shape [:- 1 ]
281+ nd_dims = (...,) + (None ,) * len (x_shape )
285282
286- stack = [(0 , np .ones (x_shape ))] # (node_index, weight) initial state
283+ stack = [(0 , np .ones (x_shape ), 0 )] # (node_index, weight, idx_split_variable ) initial state
287284 p_d = (
288285 np .zeros (shape + x_shape ) if isinstance (shape , tuple ) else np .zeros ((shape ,) + x_shape )
289286 )
290287 while stack :
291- node_index , weights = stack .pop ()
288+ node_index , weights , idx_split_variable = stack .pop ()
292289 node = self .get_node (node_index )
293290 if node .is_leaf_node ():
294291 params = node .linear_params
295- nd_dims = (...,) + (None ,) * len (x_shape )
296292 if params is None :
297293 p_d += weights * node .value [nd_dims ]
298294 else :
299- # this produce nonsensical results
300295 p_d += weights * (
301- params [0 ][nd_dims ] + params [1 ][nd_dims ] * X [..., node . idx_split_variable ]
296+ params [0 ][nd_dims ] + params [1 ][nd_dims ] * X [..., idx_split_variable ]
302297 )
303- # this produce reasonable result
304- # p_d += weight * node.value.mean()
305298 else :
306299 left_node_index , right_node_index = get_idx_left_child (
307300 node_index
308301 ), get_idx_right_child (node_index )
302+ idx_split_variable = node .idx_split_variable
309303 if excluded is not None and node .idx_split_variable in excluded :
310304 prop_nvalue_left = self .get_node (left_node_index ).nvalue / node .nvalue
311- stack .append ((left_node_index , weights * prop_nvalue_left ))
312- stack .append ((right_node_index , weights * (1 - prop_nvalue_left )))
305+ stack .append ((left_node_index , weights * prop_nvalue_left , idx_split_variable ))
306+ stack .append (
307+ (right_node_index , weights * (1 - prop_nvalue_left ), idx_split_variable )
308+ )
313309 else :
314310 to_left = (
315- self .split_rules [node . idx_split_variable ]
316- .divide (X [..., node . idx_split_variable ], node .value )
311+ self .split_rules [idx_split_variable ]
312+ .divide (X [..., idx_split_variable ], node .value )
317313 .astype ("float" )
318314 )
319- stack .append ((left_node_index , weights * to_left ))
320- stack .append ((right_node_index , weights * (1 - to_left )))
315+ stack .append ((left_node_index , weights * to_left , idx_split_variable ))
316+ stack .append ((right_node_index , weights * (1 - to_left ), idx_split_variable ))
321317
322318 if len (X .shape ) == 1 :
323319 p_d = p_d [..., 0 ]
0 commit comments