Skip to content

Commit 74a2c25

Browse files
authored
Merge pull request #248 from StochasticTree/rfx-pred-patch
Promote rfx_beta_draws to consistent array dimensions in predict.bcf
2 parents 32acbd1 + 3cb4804 commit 74a2c25

File tree

4 files changed

+356
-0
lines changed

4 files changed

+356
-0
lines changed

R/bart.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2219,6 +2219,11 @@ predict.bartmodel <- function(
22192219
rfx_param_list <- object$rfx_samples$extract_parameter_samples()
22202220
rfx_beta_draws <- rfx_param_list$beta_samples * y_std
22212221

2222+
# Promote to an array with consistent dimensions when there's one rfx term
2223+
if (length(dim(rfx_beta_draws)) == 2) {
2224+
dim(rfx_beta_draws) <- c(1, dim(rfx_beta_draws))
2225+
}
2226+
22222227
# Construct a matrix with the appropriate group random effects arranged for each observation
22232228
rfx_predictions_raw <- array(
22242229
NA,

R/bcf.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3162,6 +3162,11 @@ predict.bcfmodel <- function(
31623162
rfx_beta_draws <- rfx_param_list$beta_samples *
31633163
object$model_params$outcome_scale
31643164

3165+
# Promote to an array with consistent dimensions when there's one rfx term
3166+
if (length(dim(rfx_beta_draws)) == 2) {
3167+
dim(rfx_beta_draws) <- c(1, dim(rfx_beta_draws))
3168+
}
3169+
31653170
# Construct a matrix with the appropriate group random effects arranged for each observation
31663171
rfx_predictions_raw <- array(
31673172
NA,

test/R/testthat/test-predict.R

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,76 @@ test_that("BART predictions with pre-summarization", {
285285
expect_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term)
286286
})
287287

288+
test_that("BART predictions with random effects", {
289+
# Generate data and test-train split
290+
n <- 100
291+
p <- 5
292+
X <- matrix(runif(n * p), ncol = p)
293+
# fmt: skip
294+
f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) +
295+
((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) +
296+
((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) +
297+
((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5))
298+
noise_sd <- 1
299+
rfx_group_ids <- sample(1:3, n, replace = TRUE)
300+
rfx_coefs <- c(-2, 0, 2)
301+
rfx_term <- rfx_coefs[rfx_group_ids]
302+
rfx_basis <- matrix(1, nrow = n, ncol = 1)
303+
y <- f_XW + rfx_term + rnorm(n, 0, noise_sd)
304+
test_set_pct <- 0.2
305+
n_test <- round(test_set_pct * n)
306+
n_train <- n - n_test
307+
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
308+
train_inds <- (1:n)[!((1:n) %in% test_inds)]
309+
X_test <- X[test_inds, ]
310+
X_train <- X[train_inds, ]
311+
rfx_group_ids_test <- rfx_group_ids[test_inds]
312+
rfx_group_ids_train <- rfx_group_ids[train_inds]
313+
rfx_basis_test <- rfx_basis[test_inds, ]
314+
rfx_basis_train <- rfx_basis[train_inds, ]
315+
y_test <- y[test_inds]
316+
y_train <- y[train_inds]
317+
318+
# Fit a "classic" BART model
319+
rfx_params <- list(model_spec = "intercept_only")
320+
bart_model <- bart(
321+
X_train = X_train,
322+
y_train = y_train,
323+
rfx_group_ids_train = rfx_group_ids_train,
324+
random_effects_params = rfx_params,
325+
num_gfr = 10,
326+
num_burnin = 0,
327+
num_mcmc = 10
328+
)
329+
330+
# Check that the default predict method returns a list
331+
pred <- predict(bart_model, X = X_test, rfx_group_ids = rfx_group_ids_test)
332+
y_hat_posterior_test <- pred$y_hat
333+
expect_equal(dim(y_hat_posterior_test), c(20, 10))
334+
335+
# Check that the pre-aggregated predictions match with those computed by rowMeans
336+
pred_mean <- predict(
337+
bart_model,
338+
X = X_test,
339+
rfx_group_ids = rfx_group_ids_test,
340+
type = "mean"
341+
)
342+
y_hat_mean_test <- pred_mean$y_hat
343+
expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test))
344+
345+
# Check that we warn and return a NULL when requesting terms that weren't fit
346+
expect_warning({
347+
pred_mean <- predict(
348+
bart_model,
349+
X = X_test,
350+
rfx_group_ids = rfx_group_ids_test,
351+
type = "mean",
352+
terms = c("variance_forest")
353+
)
354+
})
355+
expect_equal(NULL, pred_mean)
356+
})
357+
288358
test_that("BCF predictions with pre-summarization", {
289359
# Generate data and test-train split
290360
n <- 100
@@ -443,3 +513,171 @@ test_that("BCF predictions with pre-summarization", {
443513
expect_equal(y_hat_mean_test, y_hat_mean_test_single_term)
444514
expect_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term)
445515
})
516+
517+
test_that("BCF predictions with random effects", {
518+
# Generate data and test-train split
519+
n <- 100
520+
g <- function(x) {
521+
ifelse(x[, 5] == 1, 2, ifelse(x[, 5] == 2, -1, -4))
522+
}
523+
x1 <- rnorm(n)
524+
x2 <- rnorm(n)
525+
x3 <- rnorm(n)
526+
x4 <- as.numeric(rbinom(n, 1, 0.5))
527+
x5 <- as.numeric(sample(1:3, n, replace = TRUE))
528+
X <- cbind(x1, x2, x3, x4, x5)
529+
p <- ncol(X)
530+
mu_x <- 1 + g(X) + X[, 1] * X[, 3]
531+
tau_x <- 1 + 2 * X[, 2] * X[, 4]
532+
pi_x <- 0.8 *
533+
pnorm((3 * mu_x / sd(mu_x)) - 0.5 * X[, 1]) +
534+
0.05 +
535+
runif(n) / 10
536+
Z <- rbinom(n, 1, pi_x)
537+
E_XZ <- mu_x + Z * tau_x
538+
rfx_group_ids <- sample(1:3, n, replace = TRUE)
539+
rfx_basis <- cbind(1, Z)
540+
rfx_coefs <- matrix(
541+
c(
542+
-2,
543+
-0.5,
544+
0,
545+
0.0,
546+
2,
547+
0.5
548+
),
549+
byrow = T,
550+
ncol = 2
551+
)
552+
rfx_term <- rowSums(rfx_basis * rfx_coefs[rfx_group_ids, ])
553+
snr <- 2
554+
y <- E_XZ + rfx_term + rnorm(n, 0, 1) * (sd(E_XZ + rfx_term) / snr)
555+
X <- as.data.frame(X)
556+
X$x4 <- factor(X$x4, ordered = TRUE)
557+
X$x5 <- factor(X$x5, ordered = TRUE)
558+
test_set_pct <- 0.2
559+
n_test <- round(test_set_pct * n)
560+
n_train <- n - n_test
561+
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
562+
train_inds <- (1:n)[!((1:n) %in% test_inds)]
563+
X_test <- X[test_inds, ]
564+
X_train <- X[train_inds, ]
565+
pi_test <- pi_x[test_inds]
566+
pi_train <- pi_x[train_inds]
567+
rfx_group_ids_test <- rfx_group_ids[test_inds]
568+
rfx_group_ids_train <- rfx_group_ids[train_inds]
569+
rfx_basis_test <- rfx_basis[test_inds, ]
570+
rfx_basis_train <- rfx_basis[train_inds, ]
571+
Z_test <- Z[test_inds]
572+
Z_train <- Z[train_inds]
573+
y_test <- y[test_inds]
574+
y_train <- y[train_inds]
575+
576+
# Fit a BCF model with random intercept and random slope on Z
577+
rfx_params = list(model_spec = "intercept_plus_treatment")
578+
bcf_model <- bcf(
579+
X_train = X_train,
580+
Z_train = Z_train,
581+
y_train = y_train,
582+
propensity_train = pi_train,
583+
rfx_group_ids_train = rfx_group_ids_train,
584+
X_test = X_test,
585+
Z_test = Z_test,
586+
propensity_test = pi_test,
587+
rfx_group_ids_test = rfx_group_ids_test,
588+
random_effects_params = rfx_params,
589+
num_gfr = 10,
590+
num_burnin = 0,
591+
num_mcmc = 10
592+
)
593+
594+
# Check that the default predict method returns a list
595+
pred <- predict(
596+
bcf_model,
597+
X = X_test,
598+
Z = Z_test,
599+
propensity = pi_test,
600+
rfx_group_ids = rfx_group_ids_test
601+
)
602+
y_hat_posterior_test <- pred$y_hat
603+
expect_equal(dim(y_hat_posterior_test), c(20, 10))
604+
605+
# Check that the pre-aggregated predictions match with those computed by rowMeans
606+
pred_mean <- predict(
607+
bcf_model,
608+
X = X_test,
609+
Z = Z_test,
610+
propensity = pi_test,
611+
rfx_group_ids = rfx_group_ids_test,
612+
type = "mean"
613+
)
614+
y_hat_mean_test <- pred_mean$y_hat
615+
expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test))
616+
617+
# Check that we warn and return a NULL when requesting terms that weren't fit
618+
expect_warning({
619+
pred_mean <- predict(
620+
bcf_model,
621+
X = X_test,
622+
Z = Z_test,
623+
propensity = pi_test,
624+
type = "mean",
625+
terms = c("variance_forest")
626+
)
627+
})
628+
expect_equal(NULL, pred_mean)
629+
630+
# Fit a BCF model with random intercept only
631+
# Fit a BCF model with random intercept and random slope on Z
632+
rfx_params = list(model_spec = "intercept_only")
633+
bcf_model <- bcf(
634+
X_train = X_train,
635+
Z_train = Z_train,
636+
y_train = y_train,
637+
propensity_train = pi_train,
638+
rfx_group_ids_train = rfx_group_ids_train,
639+
X_test = X_test,
640+
Z_test = Z_test,
641+
propensity_test = pi_test,
642+
rfx_group_ids_test = rfx_group_ids_test,
643+
random_effects_params = rfx_params,
644+
num_gfr = 10,
645+
num_burnin = 0,
646+
num_mcmc = 10
647+
)
648+
649+
# Check that the default predict method returns a list
650+
pred <- predict(
651+
bcf_model,
652+
X = X_test,
653+
Z = Z_test,
654+
propensity = pi_test,
655+
rfx_group_ids = rfx_group_ids_test
656+
)
657+
y_hat_posterior_test <- pred$y_hat
658+
expect_equal(dim(y_hat_posterior_test), c(20, 10))
659+
660+
# Check that the pre-aggregated predictions match with those computed by rowMeans
661+
pred_mean <- predict(
662+
bcf_model,
663+
X = X_test,
664+
Z = Z_test,
665+
propensity = pi_test,
666+
rfx_group_ids = rfx_group_ids_test,
667+
type = "mean"
668+
)
669+
y_hat_mean_test <- pred_mean$y_hat
670+
expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test))
671+
672+
# Check that we warn and return a NULL when requesting terms that weren't fit
673+
expect_warning({
674+
pred_mean <- predict(
675+
bcf_model,
676+
X = X_test,
677+
Z = Z_test,
678+
propensity = pi_test,
679+
type = "mean",
680+
terms = c("variance_forest")
681+
)
682+
})
683+
})

test/python/test_predict.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,45 @@ def test_bart_prediction(self):
275275
sigma2_hat_mean_test, sigma2_hat_mean_test_single_term
276276
)
277277

278+
# Generate data with random effects
279+
rfx_group_ids = rng.choice(3, size=n)
280+
rfx_basis = np.ones((n, 1))
281+
rfx_coefs = np.array([-2.0, 0.0, 2.0])
282+
rfx_term = rfx_coefs[rfx_group_ids]
283+
noise_sd = 1
284+
y = f_XW + rfx_term + rng.normal(0, noise_sd, size=n)
285+
test_set_pct = 0.2
286+
train_inds, test_inds = train_test_split(
287+
np.arange(n), test_size=test_set_pct, random_state=1234
288+
)
289+
X_train = X[train_inds, :]
290+
X_test = X[test_inds, :]
291+
rfx_group_ids_train = rfx_group_ids[train_inds]
292+
rfx_group_ids_test = rfx_group_ids[test_inds]
293+
rfx_basis_train = rfx_basis[train_inds,:]
294+
rfx_basis_test = rfx_basis[test_inds,:]
295+
y_train = y[train_inds]
296+
y_test = y[test_inds]
297+
298+
# Fit a BART model with random intercepts
299+
rfx_params = {"model_spec": "intercept_only"}
300+
bart_model = BARTModel()
301+
bart_model.sample(
302+
X_train=X_train, y_train=y_train, rfx_group_ids_train=rfx_group_ids_train, random_effects_params=rfx_params, num_gfr=10, num_burnin=0, num_mcmc=10
303+
)
304+
305+
# Check that the default predict method returns a dictionary
306+
pred = bart_model.predict(X=X_test, rfx_group_ids=rfx_group_ids_test)
307+
y_hat_posterior_test = pred["y_hat"]
308+
assert y_hat_posterior_test.shape == (20, 10)
309+
310+
# Check that the pre-aggregated predictions match with those computed by np.mean
311+
pred_mean = bart_model.predict(X=X_test, rfx_group_ids=rfx_group_ids_test, type="mean")
312+
y_hat_mean_test = pred_mean["y_hat"]
313+
np.testing.assert_almost_equal(
314+
y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1)
315+
)
316+
278317
def test_bcf_prediction(self):
279318
# Generate data and test/train split
280319
rng = np.random.default_rng(1234)
@@ -417,3 +456,72 @@ def g(x5):
417456
np.testing.assert_almost_equal(
418457
sigma2_hat_mean_test, sigma2_hat_mean_test_single_term
419458
)
459+
460+
# Generate data with random effects
461+
rfx_group_ids = rng.choice(3, size=n)
462+
rfx_basis = np.concatenate((np.ones((n, 1)), np.expand_dims(Z, 1)), axis=1)
463+
rfx_coefs = np.array([[-2.0, -0.5], [0.0, 0.0], [2.0, 0.5]])
464+
rfx_term = np.multiply(rfx_coefs[rfx_group_ids,:], rfx_basis).sum(axis=1)
465+
E_XZ = mu_x + tau_x * Z + rfx_term
466+
snr = 2
467+
y = E_XZ + rng.normal(loc=0.0, scale=np.std(E_XZ) / snr, size=(n,))
468+
test_set_pct = 0.2
469+
train_inds, test_inds = train_test_split(
470+
np.arange(n), test_size=test_set_pct, random_state=1234
471+
)
472+
X_train = X.iloc[train_inds, :]
473+
X_test = X.iloc[test_inds, :]
474+
Z_train = Z[train_inds]
475+
Z_test = Z[test_inds]
476+
pi_x_train = pi_x[train_inds]
477+
pi_x_test = pi_x[test_inds]
478+
rfx_group_ids_train = rfx_group_ids[train_inds]
479+
rfx_group_ids_test = rfx_group_ids[test_inds]
480+
rfx_basis_train = rfx_basis[train_inds,:]
481+
rfx_basis_test = rfx_basis[test_inds,:]
482+
y_train = y[train_inds]
483+
y_test = y[test_inds]
484+
485+
# Fit a "classic" BCF model
486+
rfx_params = {"model_spec": "intercept_only"}
487+
bcf_model = BCFModel()
488+
bcf_model.sample(
489+
X_train=X_train,
490+
Z_train=Z_train,
491+
y_train=y_train,
492+
propensity_train=pi_x_train,
493+
rfx_group_ids_train=rfx_group_ids_train,
494+
X_test=X_test,
495+
Z_test=Z_test,
496+
propensity_test=pi_x_test,
497+
rfx_group_ids_test=rfx_group_ids_test,
498+
random_effects_params=rfx_params,
499+
num_gfr=10,
500+
num_burnin=0,
501+
num_mcmc=10,
502+
)
503+
504+
# Check that the default predict method returns a dictionary
505+
pred = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_x_test, rfx_group_ids=rfx_group_ids_test)
506+
y_hat_posterior_test = pred["y_hat"]
507+
assert y_hat_posterior_test.shape == (20, 10)
508+
509+
# Check that the pre-aggregated predictions match with those computed by np.mean
510+
pred_mean = bcf_model.predict(
511+
X=X_test, Z=Z_test, propensity=pi_x_test, rfx_group_ids=rfx_group_ids_test, type="mean"
512+
)
513+
y_hat_mean_test = pred_mean["y_hat"]
514+
np.testing.assert_almost_equal(
515+
y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1)
516+
)
517+
518+
# Check that we warn and return None when requesting terms that weren't fit
519+
with pytest.warns(UserWarning):
520+
pred_mean = bcf_model.predict(
521+
X=X_test,
522+
Z=Z_test,
523+
propensity=pi_x_test,
524+
rfx_group_ids=rfx_group_ids_test,
525+
type="mean",
526+
terms=["variance_forest"],
527+
)

0 commit comments

Comments
 (0)