@@ -155,14 +155,12 @@ def plot_ice(
155155 bartrv : Variable ,
156156 X : npt .NDArray [np .float_ ],
157157 Y : Optional [npt .NDArray [np .float_ ]] = None ,
158- xs_interval : str = "quantiles" ,
159- xs_values : Optional [Union [int , List [float ]]] = None ,
160158 var_idx : Optional [List [int ]] = None ,
161159 var_discrete : Optional [List [int ]] = None ,
162160 func : Optional [Callable ] = None ,
163161 centered : Optional [bool ] = True ,
164- samples : int = 50 ,
165- instances : int = 10 ,
162+ samples : int = 100 ,
163+ instances : int = 30 ,
166164 random_seed : Optional [int ] = None ,
167165 sharey : bool = True ,
168166 smooth : bool = True ,
@@ -185,16 +183,6 @@ def plot_ice(
185183 The covariate matrix.
186184 Y : Optional[npt.NDArray[np.float_]], by default None.
187185 The response vector.
188- xs_interval : str
189- Method used to compute the values X used to evaluate the predicted function. "linear",
190- evenly spaced values in the range of X. "quantiles", the evaluation is done at the specified
191- quantiles of X. "insample", the evaluation is done at the values of X.
192- For discrete variables these options are ommited.
193- xs_values : Optional[Union[int, List[float]]], by default None.
194- Values of X used to evaluate the predicted function. If ``xs_interval="linear"`` number of
195- points in the evenly spaced grid. If ``xs_interval="quantiles"``quantile or sequence of
196- quantiles to compute, which must be between 0 and 1 inclusive.
197- Ignored when ``xs_interval="insample"``.
198186 var_idx : Optional[List[int]], by default None.
199187 List of the indices of the covariate for which to compute the pdp or ice.
200188 var_discrete : Optional[List[int]], by default None.
@@ -205,22 +193,20 @@ def plot_ice(
205193 If True the result is centered around the partial response evaluated at the lowest value in
206194 ``xs_interval``. Defaults to True.
207195 samples : int
208- Number of posterior samples used in the predictions. Defaults to 50
196+ Number of posterior samples used in the predictions. Defaults to 100
209197 instances : int
210- Number of instances of X to plot. Defaults to 10 .
198+ Number of instances of X to plot. Defaults to 30 .
211199 random_seed : Optional[int], by default None.
212200 Seed used to sample from the posterior. Defaults to None.
213201 sharey : bool
214202 Controls sharing of properties among y-axes. Defaults to True.
215- rug : bool
216- Whether to include a rugplot. Defaults to True.
217203 smooth : bool
218204 If True the result will be smoothed by first computing a linear interpolation of the data
219205 over a regular grid and then applying the Savitzky-Golay filter to the interpolated data.
220206 Defaults to True.
221207 grid : str or tuple
222208 How to arrange the subplots. Defaults to "long", one subplot below the other.
223- Other options are "wide", one subplot next to eachother or a tuple indicating the number of
209+ Other options are "wide", one subplot next to each other or a tuple indicating the number of
224210 rows and columns.
225211 color : matplotlib valid color
226212 Color used to plot the pdp or ice. Defaults to "C0"
@@ -257,17 +243,17 @@ def identity(x):
257243 indices ,
258244 var_idx ,
259245 var_discrete ,
260- xs_interval ,
261- xs_values ,
262- ) = _prepare_plot_data (X , Y , xs_interval , xs_values , var_idx , var_discrete )
246+ _ ,
247+ _ ,
248+ ) = _prepare_plot_data (X , Y , "linear" , None , var_idx , var_discrete )
263249
264250 fig , axes , shape = _get_axes (bartrv , var_idx , grid , sharey , figsize , ax )
265251
266252 instances_ary = rng .choice (range (X .shape [0 ]), replace = False , size = instances )
267253 idx_s = list (range (X .shape [0 ]))
268254
269255 count = 0
270- for var in range ( len ( var_idx ) ):
256+ for i_var , var in enumerate ( var_idx ):
271257 indices_mi = indices [:]
272258 indices_mi .remove (var )
273259 y_pred = []
@@ -283,6 +269,7 @@ def identity(x):
283269
284270 new_x = fake_X [:, var ]
285271 p_d = np .array (y_pred )
272+ print (p_d .shape )
286273
287274 for s_i in range (shape ):
288275 if centered :
@@ -301,7 +288,7 @@ def identity(x):
301288 idx = np .argsort (new_x )
302289 axes [count ].plot (new_x [idx ], p_di .mean (0 )[idx ], color = color_mean )
303290 axes [count ].plot (new_x [idx ], p_di .T [idx ], color = color , alpha = alpha )
304- axes [count ].set_xlabel (x_labels [var ])
291+ axes [count ].set_xlabel (x_labels [i_var ])
305292
306293 count += 1
307294
@@ -349,7 +336,7 @@ def plot_pdp(
349336 For discrete variables these options are ommited.
350337 xs_values : Optional[Union[int, List[float]]], by default None.
351338 Values of X used to evaluate the predicted function. If ``xs_interval="linear"`` number of
352- points in the evenly spaced grid. If ``xs_interval="quantiles"``quantile or sequence of
339+ points in the evenly spaced grid. If ``xs_interval="quantiles"`` quantile or sequence of
353340 quantiles to compute, which must be between 0 and 1 inclusive.
354341 Ignored when ``xs_interval="insample"``.
355342 var_idx : Optional[List[int]], by default None.
@@ -717,7 +704,8 @@ def plot_variable_importance(
717704 xlabel_angle : float = 0 ,
718705 samples : int = 100 ,
719706 random_seed : Optional [int ] = None ,
720- ) -> Tuple [List [int ], List [plt .Axes ]]:
707+ ax : Optional [plt .Axes ] = None ,
708+ ) -> Tuple [List [int ], Union [List [plt .Axes ], Any ]]:
721709 """
722710 Estimates variable importance from the BART-posterior.
723711
@@ -747,6 +735,8 @@ def plot_variable_importance(
747735 Number of predictions used to compute correlation for subsets of variables. Defaults to 100
748736 random_seed : Optional[int]
749737 random_seed used to sample from the posterior. Defaults to None.
738+ ax : axes
739+ Matplotlib axes.
750740
751741 Returns
752742 -------
@@ -771,7 +761,8 @@ def plot_variable_importance(
771761 if figsize is None :
772762 figsize = (8 , 3 )
773763
774- _ , ax = plt .subplots (1 , 1 , figsize = figsize )
764+ if ax is None :
765+ _ , ax = plt .subplots (1 , 1 , figsize = figsize )
775766
776767 if labels is None :
777768 labels_ary = np .arange (n_vars ).astype (str )
0 commit comments