@@ -408,14 +408,14 @@ def plot_pdp(
408408 fig , axes , shape = _get_axes (bartrv , var_idx , grid , sharey , figsize , ax )
409409
410410 count = 0
411+ fake_X = _create_pdp_data (X , xs_interval , xs_values )
411412 for var in range (len (var_idx )):
412413 excluded = indices [:]
413414 excluded .remove (var )
414- fake_X , new_x = _create_pdp_data (X , xs_interval , var , xs_values , var_discrete )
415415 p_d = _sample_posterior (
416416 all_trees , X = fake_X , rng = rng , size = samples , excluded = excluded , shape = shape
417417 )
418-
418+ new_x = fake_X [:, var ]
419419 for s_i in range (shape ):
420420 p_di = func (p_d [:, :, s_i ])
421421 if var in var_discrete :
@@ -621,10 +621,8 @@ def _prepare_plot_data(
621621def _create_pdp_data (
622622 X : npt .NDArray [np .float_ ],
623623 xs_interval : str ,
624- var : int ,
625624 xs_values : Optional [Union [int , List [float ]]] = None ,
626- var_discrete : Optional [List [int ]] = None ,
627- ) -> Tuple [npt .NDArray [np .float_ ], npt .NDArray [np .float_ ]]:
625+ ) -> npt .NDArray [np .float_ ]:
628626 """
629627 Create data for partial dependence plot.
630628
@@ -636,28 +634,23 @@ def _create_pdp_data(
636634 Interval for x-axis. Available options are 'insample', 'linear' or 'quantiles'.
637635 xs_values : int or list
638636 Number of points for 'linear' or list of quantiles for 'quantiles'.
639- var : int
640- Index of variable of interest
641- var_discrete : None or list
642- Indices of discrete variables.
643637
644638 Returns
645639 -------
646- Tuple[ npt.NDArray[np.float_], npt.NDArray[np.float_] ]
647- A tuple containing a 2D array for the fake_X data and 1D array for new_x data.
640+ npt.NDArray[np.float_]
641+ A 2D array for the fake_X data.
648642 """
649643 if xs_interval == "insample" :
650- return X , X [:, var ]
644+ return X
651645 else :
652- if var_discrete is not None and var in var_discrete :
653- new_x = np .unique (X [:, var ])
654- else :
655- if xs_interval == "linear" and isinstance (xs_values , int ):
656- new_x = np .linspace (np .nanmin (X [:, var ]), np .nanmax (X [:, var ]), xs_values )
657- elif xs_interval == "quantiles" and isinstance (xs_values , list ):
658- new_x = np .quantile (X [:, var ], q = xs_values )
659-
660- return np .tile (new_x [:, None ], X .shape [1 ]), new_x
646+ if xs_interval == "linear" and isinstance (xs_values , int ):
647+ min_vals = np .min (X , axis = 0 )
648+ max_vals = np .max (X , axis = 0 )
649+ fake_X = np .linspace (min_vals , max_vals , num = xs_values , axis = 0 )
650+ elif xs_interval == "quantiles" and isinstance (xs_values , list ):
651+ fake_X = np .quantile (X , q = xs_values , axis = 0 )
652+
653+ return fake_X
661654
662655
663656def _smooth_mean (
0 commit comments