33import pytest
44from numpy .testing import assert_almost_equal , assert_array_equal
55from pymc .initial_point import make_initial_point_fn
6- from pymc .logprob .basic import joint_logp
6+ from pymc .logprob .basic import transformed_conditional_logp
77
88import pymc_bart as pmb
99
@@ -12,7 +12,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True):
1212 fn = make_initial_point_fn (
1313 model = model ,
1414 return_transformed = False ,
15- default_strategy = "moment " ,
15+ default_strategy = "support_point " ,
1616 )
1717 moment = fn (0 )["x" ]
1818 expected = np .asarray (expected )
@@ -27,7 +27,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True):
2727
2828 if check_finite_logp :
2929 logp_moment = (
30- joint_logp (
30+ transformed_conditional_logp (
3131 (model ["x" ],),
3232 rvs_to_values = {model ["x" ]: pm .math .constant (moment )},
3333 rvs_to_transforms = {},
@@ -53,7 +53,7 @@ def test_bart_vi(response):
5353 mu = pmb .BART ("mu" , X , Y , m = 10 , response = response )
5454 sigma = pm .HalfNormal ("sigma" , 1 )
5555 y = pm .Normal ("y" , mu , sigma , observed = Y )
56- idata = pm .sample (random_seed = 3415 )
56+ idata = pm .sample (tune = 200 , draws = 200 , random_seed = 3415 )
5757 var_imp = (
5858 idata .sample_stats ["variable_inclusion" ]
5959 .stack (samples = ("chain" , "draw" ))
@@ -77,8 +77,8 @@ def test_missing_data(response):
7777 with pm .Model () as model :
7878 mu = pmb .BART ("mu" , X , Y , m = 10 , response = response )
7979 sigma = pm .HalfNormal ("sigma" , 1 )
80- y = pm .Normal ("y" , mu , sigma , observed = Y )
81- idata = pm .sample (tune = 100 , draws = 100 , chains = 1 , random_seed = 3415 )
80+ pm .Normal ("y" , mu , sigma , observed = Y )
81+ pm .sample (tune = 100 , draws = 100 , chains = 1 , random_seed = 3415 )
8282
8383
8484@pytest .mark .parametrize (
@@ -91,7 +91,7 @@ def test_shared_variable(response):
9191 Y = np .random .normal (0 , 1 , size = 50 )
9292
9393 with pm .Model () as model :
94- data_X = pm .MutableData ("data_X" , X )
94+ data_X = pm .Data ("data_X" , X )
9595 mu = pmb .BART ("mu" , data_X , Y , m = 2 , response = response )
9696 sigma = pm .HalfNormal ("sigma" , 1 )
9797 y = pm .Normal ("y" , mu , sigma , observed = Y , shape = mu .shape )
@@ -116,7 +116,7 @@ def test_shape(response):
116116 with pm .Model () as model :
117117 w = pmb .BART ("w" , X , Y , m = 2 , response = response , shape = (2 , 250 ))
118118 y = pm .Normal ("y" , w [0 ], pm .math .abs (w [1 ]), observed = Y )
119- idata = pm .sample (random_seed = 3415 )
119+ idata = pm .sample (tune = 50 , draws = 10 , random_seed = 3415 )
120120
121121 assert model .initial_point ()["w" ].shape == (2 , 250 )
122122 assert idata .posterior .coords ["w_dim_0" ].data .size == 2
@@ -133,7 +133,7 @@ class TestUtils:
133133 mu = pmb .BART ("mu" , X , Y , m = 10 )
134134 sigma = pm .HalfNormal ("sigma" , 1 )
135135 y = pm .Normal ("y" , mu , sigma , observed = Y )
136- idata = pm .sample (random_seed = 3415 )
136+ idata = pm .sample (tune = 200 , draws = 200 , random_seed = 3415 )
137137
138138 def test_sample_posterior (self ):
139139 all_trees = self .mu .owner .op .all_trees
0 commit comments