11import numpy as np
22import pymc as pm
33import pytest
4- from numpy .testing import assert_almost_equal , assert_array_equal
4+ from numpy .testing import assert_almost_equal
55from pymc .initial_point import make_initial_point_fn
66from pymc .logprob .basic import transformed_conditional_logp
77
88import pymc_bart as pmb
9+ from pymc_bart .utils import _decode_vi
910
1011
1112def assert_moment_is_expected (model , expected , check_finite_logp = True ):
@@ -52,14 +53,12 @@ def test_bart_vi(response):
5253 with pm .Model () as model :
5354 mu = pmb .BART ("mu" , X , Y , m = 10 , response = response )
5455 sigma = pm .HalfNormal ("sigma" , 1 )
55- y = pm .Normal ("y" , mu , sigma , observed = Y )
56+ pm .Normal ("y" , mu , sigma , observed = Y )
5657 idata = pm .sample (tune = 200 , draws = 200 , random_seed = 3415 )
57- var_imp = (
58- idata .sample_stats ["variable_inclusion" ]
59- .stack (samples = ("chain" , "draw" ))
60- .mean ("samples" )
61- )
62- var_imp /= var_imp .sum ()
58+ vi_vals = idata ["sample_stats" ]["variable_inclusion" ].values .ravel ()
59+ var_imp = np .array ([_decode_vi (val , 3 ) for val in vi_vals ]).sum (axis = 0 )
60+
61+ var_imp = var_imp / var_imp .sum ()
6362 assert var_imp [0 ] > var_imp [1 :].sum ()
6463 assert_almost_equal (var_imp .sum (), 1 )
6564
@@ -123,92 +122,6 @@ def test_shape(response):
123122 assert idata .posterior .coords ["w_dim_1" ].data .size == 250
124123
125124
126- class TestUtils :
127- X_norm = np .random .normal (0 , 1 , size = (50 , 2 ))
128- X_binom = np .random .binomial (1 , 0.5 , size = (50 , 1 ))
129- X = np .hstack ([X_norm , X_binom ])
130- Y = np .random .normal (0 , 1 , size = 50 )
131-
132- with pm .Model () as model :
133- mu = pmb .BART ("mu" , X , Y , m = 10 )
134- sigma = pm .HalfNormal ("sigma" , 1 )
135- y = pm .Normal ("y" , mu , sigma , observed = Y )
136- idata = pm .sample (tune = 200 , draws = 200 , random_seed = 3415 )
137-
138- def test_sample_posterior (self ):
139- all_trees = self .mu .owner .op .all_trees
140- rng = np .random .default_rng (3 )
141- pred_all = pmb .utils ._sample_posterior (all_trees , X = self .X , rng = rng , size = 2 )
142- rng = np .random .default_rng (3 )
143- pred_first = pmb .utils ._sample_posterior (all_trees , X = self .X [:10 ], rng = rng )
144-
145- assert_almost_equal (pred_first [0 ], pred_all [0 , :10 ], decimal = 4 )
146- assert pred_all .shape == (2 , 50 , 1 )
147- assert pred_first .shape == (1 , 10 , 1 )
148-
149- @pytest .mark .parametrize (
150- "kwargs" ,
151- [
152- {},
153- {
154- "samples" : 2 ,
155- "var_discrete" : [3 ],
156- },
157- {"instances" : 2 },
158- {"var_idx" : [0 ], "smooth" : False , "color" : "k" },
159- {"grid" : (1 , 2 ), "sharey" : "none" , "alpha" : 1 },
160- {"var_discrete" : [0 ]},
161- ],
162- )
163- def test_ice (self , kwargs ):
164- pmb .plot_ice (self .mu , X = self .X , Y = self .Y , ** kwargs )
165-
166- @pytest .mark .parametrize (
167- "kwargs" ,
168- [
169- {},
170- {
171- "samples" : 2 ,
172- "xs_interval" : "quantiles" ,
173- "xs_values" : [0.25 , 0.5 , 0.75 ],
174- "var_discrete" : [3 ],
175- },
176- {"var_idx" : [0 ], "smooth" : False , "color" : "k" },
177- {"grid" : (1 , 2 ), "sharey" : "none" , "alpha" : 1 },
178- {"var_discrete" : [0 ]},
179- ],
180- )
181- def test_pdp (self , kwargs ):
182- pmb .plot_pdp (self .mu , X = self .X , Y = self .Y , ** kwargs )
183-
184- @pytest .mark .parametrize (
185- "kwargs" ,
186- [
187- {"samples" : 50 },
188- {"labels" : ["A" , "B" , "C" ], "samples" : 2 , "figsize" : (6 , 6 )},
189- ],
190- )
191- def test_vi (self , kwargs ):
192- samples = kwargs .pop ("samples" )
193- vi_results = pmb .compute_variable_importance (
194- self .idata , bartrv = self .mu , X = self .X , samples = samples
195- )
196- pmb .plot_variable_importance (vi_results , ** kwargs )
197- pmb .plot_scatter_submodels (vi_results , ** kwargs )
198-
199- def test_pdp_pandas_labels (self ):
200- pd = pytest .importorskip ("pandas" )
201-
202- X_names = ["norm1" , "norm2" , "binom" ]
203- X_pd = pd .DataFrame (self .X , columns = X_names )
204- Y_pd = pd .Series (self .Y , name = "response" )
205- axes = pmb .plot_pdp (self .mu , X = X_pd , Y = Y_pd )
206-
207- figure = axes [0 ].figure
208- assert figure .texts [0 ].get_text () == "Partial response"
209- assert_array_equal ([ax .get_xlabel () for ax in axes ], X_names )
210-
211-
212125@pytest .mark .parametrize (
213126 "size, expected" ,
214127 [
@@ -275,7 +188,7 @@ def test_multiple_bart_variables():
275188
276189 # Combined model
277190 sigma = pm .HalfNormal ("sigma" , 1 )
278- y = pm .Normal ("y" , mu1 + mu2 , sigma , observed = Y )
191+ pm .Normal ("y" , mu1 + mu2 , sigma , observed = Y )
279192
280193 # Sample with automatic assignment of BART samplers
281194 idata = pm .sample (tune = 50 , draws = 50 , chains = 1 , random_seed = 3415 )
@@ -291,6 +204,16 @@ def test_multiple_bart_variables():
291204 assert idata .posterior ["mu1" ].shape == (1 , 50 , 50 )
292205 assert idata .posterior ["mu2" ].shape == (1 , 50 , 50 )
293206
207+ vi_results = pmb .compute_variable_importance (idata , mu1 , X1 , model = model )
208+ assert vi_results ["labels" ].shape == (2 ,)
209+ assert vi_results ["preds" ].shape == (2 , 50 , 50 )
210+ assert vi_results ["preds_all" ].shape == (50 , 50 )
211+
212+ vi_tuple = pmb .get_variable_inclusion (idata , X1 , model = model , bart_var_name = "mu1" )
213+ assert vi_tuple [0 ].shape == (2 ,)
214+ assert len (vi_tuple [1 ]) == 2
215+ assert isinstance (vi_tuple [1 ][0 ], str )
216+
294217
295218def test_multiple_bart_variables_manual_step ():
296219 """Test that multiple BART variables work with manually assigned PGBART samplers."""
0 commit comments