File tree Expand file tree Collapse file tree 2 files changed +17
-7
lines changed Expand file tree Collapse file tree 2 files changed +17
-7
lines changed Original file line number Diff line number Diff line change @@ -74,9 +74,9 @@ class BART(Distribution):
7474
7575 Parameters
7676 ----------
77- X : TensorLike
77+ X : PyTensor Variable, Pandas/Polars DataFrame or Numpy array
7878 The covariate matrix.
79- Y : TensorLike
79+ Y : PyTensor Variable, Pandas/Polar DataFrame/Series,or Numpy array
8080 The response vector.
8181 m : int
8282 Number of trees.
@@ -204,6 +204,16 @@ def preprocess_xy(
204204 if isinstance (X , (Series , DataFrame )):
205205 X = X .to_numpy ()
206206
207+ try :
208+ import polars as pl
209+
210+ if isinstance (X , (pl .Series , pl .DataFrame )):
211+ X = X .to_numpy ()
212+ if isinstance (Y , (pl .Series , pl .DataFrame )):
213+ Y = Y .to_numpy ()
214+ except ImportError :
215+ pass
216+
207217 Y = Y .astype (float )
208218 X = X .astype (float )
209219
Original file line number Diff line number Diff line change @@ -546,7 +546,7 @@ def _prepare_plot_data(
546546
547547 Parameters
548548 ----------
549- X : PyTensor Variable, Pandas DataFrame or Numpy array
549+ X : PyTensor Variable, Pandas DataFrame, Polars DataFrame or Numpy array
550550 Input data.
551551 Y : array-like
552552 Target data.
@@ -585,9 +585,9 @@ def _prepare_plot_data(
585585 if isinstance (X , Variable ):
586586 X = X .eval ()
587587
588- if hasattr (X , "columns" ) and hasattr (X , "values " ):
588+ if hasattr (X , "columns" ) and hasattr (X , "to_numpy " ):
589589 x_names = list (X .columns )
590- X = X .values
590+ X = X .to_numpy ()
591591 else :
592592 x_names = []
593593
@@ -750,9 +750,9 @@ def plot_variable_importance( # noqa: PLR0915
750750 else :
751751 shape = bartrv .eval ().shape [0 ]
752752
753- if hasattr (X , "columns" ) and hasattr (X , "values " ):
753+ if hasattr (X , "columns" ) and hasattr (X , "to_numpy " ):
754754 labels = X .columns
755- X = X .values
755+ X = X .to_numpy ()
756756
757757 n_vars = X .shape [1 ]
758758
You can’t perform that action at this time.
0 commit comments