11"""Utility function for variable selection and bart interpretability."""
22
3+ import warnings
4+
35import arviz as az
46import matplotlib .pyplot as plt
57import numpy as np
@@ -287,6 +289,7 @@ def plot_dependence(
287289 y_mins .append (np .min (y_pred ))
288290 new_y .append (np .array (y_pred ).T )
289291
292+ new_y = np .array (new_y )
290293 if func is not None :
291294 new_y = [func (nyi ) for nyi in new_y ]
292295 shape = 1
@@ -299,6 +302,14 @@ def plot_dependence(
299302 fig , axes = plt .subplots (1 , len (var_idx ) * shape , sharey = sharey , figsize = figsize )
300303 elif isinstance (grid , tuple ):
301304 fig , axes = plt .subplots (grid [0 ], grid [1 ], sharey = sharey , figsize = figsize )
305+ grid_size = grid [0 ] * grid [1 ]
306+ n_plots = new_y .squeeze ().shape [0 ]
307+ if n_plots > grid_size :
308+ warnings .warn ("The grid is smaller than the number of available variables to plot" )
309+ elif n_plots < grid_size :
310+ for i in range (n_plots , grid [0 ] * grid [1 ]):
311+ fig .delaxes (axes .flatten ()[i ])
312+ axes = axes .flatten ()[:n_plots ]
302313 axes = np .ravel (axes )
303314 else :
304315 axes = [ax ]
@@ -307,10 +318,6 @@ def plot_dependence(
307318 x_idx = 0
308319 y_idx = 0
309320 for ax in axes : # pylint: disable=redefined-argument-from-local
310- if x_idx >= len (var_idx ):
311- ax .set_axis_off ()
312- fig .delaxes (ax )
313-
314321 nyi = new_y [x_idx ][y_idx ]
315322 nxi = new_x_target [x_idx ]
316323 var = var_idx [x_idx ]
0 commit comments