@@ -256,3 +256,70 @@ def test_categorical_model(separate_trees, split_rule):
256256 # Fit should be good enough so right category is selected over 50% of time
257257 assert (idata .predictions .y .median (["chain" , "draw" ]) == Y ).all ()
258258 assert pmb .compute_variable_importance (idata , bartrv = lo , X = X )["preds" ].shape == (5 , 50 , 9 , 3 )
259+
260+
261+ def test_multiple_bart_variables ():
262+ """Test that multiple BART variables can coexist in a single model."""
263+ X1 = np .random .normal (0 , 1 , size = (50 , 2 ))
264+ X2 = np .random .normal (0 , 1 , size = (50 , 3 ))
265+ Y = np .random .normal (0 , 1 , size = 50 )
266+
267+ # Create correlated responses
268+ Y1 = X1 [:, 0 ] + np .random .normal (0 , 0.1 , size = 50 )
269+ Y2 = X2 [:, 0 ] + X2 [:, 1 ] + np .random .normal (0 , 0.1 , size = 50 )
270+
271+ with pm .Model () as model :
272+ # Two separate BART variables with different covariates
273+ mu1 = pmb .BART ("mu1" , X1 , Y1 , m = 5 )
274+ mu2 = pmb .BART ("mu2" , X2 , Y2 , m = 5 )
275+
276+ # Combined model
277+ sigma = pm .HalfNormal ("sigma" , 1 )
278+ y = pm .Normal ("y" , mu1 + mu2 , sigma , observed = Y )
279+
280+ # Sample with automatic assignment of BART samplers
281+ idata = pm .sample (tune = 50 , draws = 50 , chains = 1 , random_seed = 3415 )
282+
283+ # Verify both BART variables have their own tree collections
284+ assert hasattr (mu1 .owner .op , "all_trees" )
285+ assert hasattr (mu2 .owner .op , "all_trees" )
286+
287+ # Verify trees are stored separately (different object references)
288+ assert mu1 .owner .op .all_trees is not mu2 .owner .op .all_trees
289+
290+ # Verify sampling worked
291+ assert idata .posterior ["mu1" ].shape == (1 , 50 , 50 )
292+ assert idata .posterior ["mu2" ].shape == (1 , 50 , 50 )
293+
294+
295+ def test_multiple_bart_variables_manual_step ():
296+ """Test that multiple BART variables work with manually assigned PGBART samplers."""
297+ X1 = np .random .normal (0 , 1 , size = (30 , 2 ))
298+ X2 = np .random .normal (0 , 1 , size = (30 , 2 ))
299+ Y = np .random .normal (0 , 1 , size = 30 )
300+
301+ # Create simple responses
302+ Y1 = X1 [:, 0 ] + np .random .normal (0 , 0.1 , size = 30 )
303+ Y2 = X2 [:, 1 ] + np .random .normal (0 , 0.1 , size = 30 )
304+
305+ with pm .Model () as model :
306+ # Two separate BART variables
307+ mu1 = pmb .BART ("mu1" , X1 , Y1 , m = 3 )
308+ mu2 = pmb .BART ("mu2" , X2 , Y2 , m = 3 )
309+
310+ # Non-BART variable
311+ sigma = pm .HalfNormal ("sigma" , 1 )
312+ y = pm .Normal ("y" , mu1 + mu2 , sigma , observed = Y )
313+
314+ # Manually create PGBART samplers for each BART variable
315+ step1 = pmb .PGBART ([mu1 ], num_particles = 5 )
316+ step2 = pmb .PGBART ([mu2 ], num_particles = 5 )
317+
318+ # Sample with manual step assignment
319+ idata = pm .sample (tune = 20 , draws = 20 , chains = 1 , step = [step1 , step2 ], random_seed = 3415 )
320+
321+ # Verify both variables were sampled
322+ assert "mu1" in idata .posterior
323+ assert "mu2" in idata .posterior
324+ assert idata .posterior ["mu1" ].shape == (1 , 20 , 30 )
325+ assert idata .posterior ["mu2" ].shape == (1 , 20 , 30 )
0 commit comments