@@ -285,6 +285,76 @@ test_that("BART predictions with pre-summarization", {
285285 expect_equal(sigma2_hat_mean_test , sigma2_hat_mean_test_single_term )
286286})
287287
288+ test_that(" BART predictions with random effects" , {
289+ # Generate data and test-train split
290+ n <- 100
291+ p <- 5
292+ X <- matrix (runif(n * p ), ncol = p )
293+ # fmt: skip
294+ f_XW <- (((0 < = X [, 1 ]) & (0.25 > X [, 1 ])) * (- 7.5 ) +
295+ ((0.25 < = X [, 1 ]) & (0.5 > X [, 1 ])) * (- 2.5 ) +
296+ ((0.5 < = X [, 1 ]) & (0.75 > X [, 1 ])) * (2.5 ) +
297+ ((0.75 < = X [, 1 ]) & (1 > X [, 1 ])) * (7.5 ))
298+ noise_sd <- 1
299+ rfx_group_ids <- sample(1 : 3 , n , replace = TRUE )
300+ rfx_coefs <- c(- 2 , 0 , 2 )
301+ rfx_term <- rfx_coefs [rfx_group_ids ]
302+ rfx_basis <- matrix (1 , nrow = n , ncol = 1 )
303+ y <- f_XW + rfx_term + rnorm(n , 0 , noise_sd )
304+ test_set_pct <- 0.2
305+ n_test <- round(test_set_pct * n )
306+ n_train <- n - n_test
307+ test_inds <- sort(sample(1 : n , n_test , replace = FALSE ))
308+ train_inds <- (1 : n )[! ((1 : n ) %in% test_inds )]
309+ X_test <- X [test_inds , ]
310+ X_train <- X [train_inds , ]
311+ rfx_group_ids_test <- rfx_group_ids [test_inds ]
312+ rfx_group_ids_train <- rfx_group_ids [train_inds ]
313+ rfx_basis_test <- rfx_basis [test_inds , ]
314+ rfx_basis_train <- rfx_basis [train_inds , ]
315+ y_test <- y [test_inds ]
316+ y_train <- y [train_inds ]
317+
318+ # Fit a "classic" BART model
319+ rfx_params <- list (model_spec = " intercept_only" )
320+ bart_model <- bart(
321+ X_train = X_train ,
322+ y_train = y_train ,
323+ rfx_group_ids_train = rfx_group_ids_train ,
324+ random_effects_params = rfx_params ,
325+ num_gfr = 10 ,
326+ num_burnin = 0 ,
327+ num_mcmc = 10
328+ )
329+
330+ # Check that the default predict method returns a list
331+ pred <- predict(bart_model , X = X_test , rfx_group_ids = rfx_group_ids_test )
332+ y_hat_posterior_test <- pred $ y_hat
333+ expect_equal(dim(y_hat_posterior_test ), c(20 , 10 ))
334+
335+ # Check that the pre-aggregated predictions match with those computed by rowMeans
336+ pred_mean <- predict(
337+ bart_model ,
338+ X = X_test ,
339+ rfx_group_ids = rfx_group_ids_test ,
340+ type = " mean"
341+ )
342+ y_hat_mean_test <- pred_mean $ y_hat
343+ expect_equal(y_hat_mean_test , rowMeans(y_hat_posterior_test ))
344+
345+ # Check that we warn and return a NULL when requesting terms that weren't fit
346+ expect_warning({
347+ pred_mean <- predict(
348+ bart_model ,
349+ X = X_test ,
350+ rfx_group_ids = rfx_group_ids_test ,
351+ type = " mean" ,
352+ terms = c(" variance_forest" )
353+ )
354+ })
355+ expect_equal(NULL , pred_mean )
356+ })
357+
288358test_that(" BCF predictions with pre-summarization" , {
289359 # Generate data and test-train split
290360 n <- 100
@@ -443,3 +513,171 @@ test_that("BCF predictions with pre-summarization", {
443513 expect_equal(y_hat_mean_test , y_hat_mean_test_single_term )
444514 expect_equal(sigma2_hat_mean_test , sigma2_hat_mean_test_single_term )
445515})
516+
517+ test_that(" BCF predictions with random effects" , {
518+ # Generate data and test-train split
519+ n <- 100
520+ g <- function (x ) {
521+ ifelse(x [, 5 ] == 1 , 2 , ifelse(x [, 5 ] == 2 , - 1 , - 4 ))
522+ }
523+ x1 <- rnorm(n )
524+ x2 <- rnorm(n )
525+ x3 <- rnorm(n )
526+ x4 <- as.numeric(rbinom(n , 1 , 0.5 ))
527+ x5 <- as.numeric(sample(1 : 3 , n , replace = TRUE ))
528+ X <- cbind(x1 , x2 , x3 , x4 , x5 )
529+ p <- ncol(X )
530+ mu_x <- 1 + g(X ) + X [, 1 ] * X [, 3 ]
531+ tau_x <- 1 + 2 * X [, 2 ] * X [, 4 ]
532+ pi_x <- 0.8 *
533+ pnorm((3 * mu_x / sd(mu_x )) - 0.5 * X [, 1 ]) +
534+ 0.05 +
535+ runif(n ) / 10
536+ Z <- rbinom(n , 1 , pi_x )
537+ E_XZ <- mu_x + Z * tau_x
538+ rfx_group_ids <- sample(1 : 3 , n , replace = TRUE )
539+ rfx_basis <- cbind(1 , Z )
540+ rfx_coefs <- matrix (
541+ c(
542+ - 2 ,
543+ - 0.5 ,
544+ 0 ,
545+ 0.0 ,
546+ 2 ,
547+ 0.5
548+ ),
549+ byrow = T ,
550+ ncol = 2
551+ )
552+ rfx_term <- rowSums(rfx_basis * rfx_coefs [rfx_group_ids , ])
553+ snr <- 2
554+ y <- E_XZ + rfx_term + rnorm(n , 0 , 1 ) * (sd(E_XZ + rfx_term ) / snr )
555+ X <- as.data.frame(X )
556+ X $ x4 <- factor (X $ x4 , ordered = TRUE )
557+ X $ x5 <- factor (X $ x5 , ordered = TRUE )
558+ test_set_pct <- 0.2
559+ n_test <- round(test_set_pct * n )
560+ n_train <- n - n_test
561+ test_inds <- sort(sample(1 : n , n_test , replace = FALSE ))
562+ train_inds <- (1 : n )[! ((1 : n ) %in% test_inds )]
563+ X_test <- X [test_inds , ]
564+ X_train <- X [train_inds , ]
565+ pi_test <- pi_x [test_inds ]
566+ pi_train <- pi_x [train_inds ]
567+ rfx_group_ids_test <- rfx_group_ids [test_inds ]
568+ rfx_group_ids_train <- rfx_group_ids [train_inds ]
569+ rfx_basis_test <- rfx_basis [test_inds , ]
570+ rfx_basis_train <- rfx_basis [train_inds , ]
571+ Z_test <- Z [test_inds ]
572+ Z_train <- Z [train_inds ]
573+ y_test <- y [test_inds ]
574+ y_train <- y [train_inds ]
575+
576+ # Fit a BCF model with random intercept and random slope on Z
577+ rfx_params = list (model_spec = " intercept_plus_treatment" )
578+ bcf_model <- bcf(
579+ X_train = X_train ,
580+ Z_train = Z_train ,
581+ y_train = y_train ,
582+ propensity_train = pi_train ,
583+ rfx_group_ids_train = rfx_group_ids_train ,
584+ X_test = X_test ,
585+ Z_test = Z_test ,
586+ propensity_test = pi_test ,
587+ rfx_group_ids_test = rfx_group_ids_test ,
588+ random_effects_params = rfx_params ,
589+ num_gfr = 10 ,
590+ num_burnin = 0 ,
591+ num_mcmc = 10
592+ )
593+
594+ # Check that the default predict method returns a list
595+ pred <- predict(
596+ bcf_model ,
597+ X = X_test ,
598+ Z = Z_test ,
599+ propensity = pi_test ,
600+ rfx_group_ids = rfx_group_ids_test
601+ )
602+ y_hat_posterior_test <- pred $ y_hat
603+ expect_equal(dim(y_hat_posterior_test ), c(20 , 10 ))
604+
605+ # Check that the pre-aggregated predictions match with those computed by rowMeans
606+ pred_mean <- predict(
607+ bcf_model ,
608+ X = X_test ,
609+ Z = Z_test ,
610+ propensity = pi_test ,
611+ rfx_group_ids = rfx_group_ids_test ,
612+ type = " mean"
613+ )
614+ y_hat_mean_test <- pred_mean $ y_hat
615+ expect_equal(y_hat_mean_test , rowMeans(y_hat_posterior_test ))
616+
617+ # Check that we warn and return a NULL when requesting terms that weren't fit
618+ expect_warning({
619+ pred_mean <- predict(
620+ bcf_model ,
621+ X = X_test ,
622+ Z = Z_test ,
623+ propensity = pi_test ,
624+ type = " mean" ,
625+ terms = c(" variance_forest" )
626+ )
627+ })
628+ expect_equal(NULL , pred_mean )
629+
630+ # Fit a BCF model with random intercept only
631+ # Fit a BCF model with random intercept and random slope on Z
632+ rfx_params = list (model_spec = " intercept_only" )
633+ bcf_model <- bcf(
634+ X_train = X_train ,
635+ Z_train = Z_train ,
636+ y_train = y_train ,
637+ propensity_train = pi_train ,
638+ rfx_group_ids_train = rfx_group_ids_train ,
639+ X_test = X_test ,
640+ Z_test = Z_test ,
641+ propensity_test = pi_test ,
642+ rfx_group_ids_test = rfx_group_ids_test ,
643+ random_effects_params = rfx_params ,
644+ num_gfr = 10 ,
645+ num_burnin = 0 ,
646+ num_mcmc = 10
647+ )
648+
649+ # Check that the default predict method returns a list
650+ pred <- predict(
651+ bcf_model ,
652+ X = X_test ,
653+ Z = Z_test ,
654+ propensity = pi_test ,
655+ rfx_group_ids = rfx_group_ids_test
656+ )
657+ y_hat_posterior_test <- pred $ y_hat
658+ expect_equal(dim(y_hat_posterior_test ), c(20 , 10 ))
659+
660+ # Check that the pre-aggregated predictions match with those computed by rowMeans
661+ pred_mean <- predict(
662+ bcf_model ,
663+ X = X_test ,
664+ Z = Z_test ,
665+ propensity = pi_test ,
666+ rfx_group_ids = rfx_group_ids_test ,
667+ type = " mean"
668+ )
669+ y_hat_mean_test <- pred_mean $ y_hat
670+ expect_equal(y_hat_mean_test , rowMeans(y_hat_posterior_test ))
671+
672+ # Check that we warn and return a NULL when requesting terms that weren't fit
673+ expect_warning({
674+ pred_mean <- predict(
675+ bcf_model ,
676+ X = X_test ,
677+ Z = Z_test ,
678+ propensity = pi_test ,
679+ type = " mean" ,
680+ terms = c(" variance_forest" )
681+ )
682+ })
683+ })
0 commit comments