Skip to content

Commit 0e18190

Browse files
committed
Merge branch 'main' into release-news-update
2 parents 3d36699 + 36d59c3 commit 0e18190

File tree

14 files changed

+749
-451
lines changed

14 files changed

+749
-451
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: stochtree
22
Title: Stochastic Tree Ensembles (XBART and BART) for Supervised Learning and Causal Inference
3-
Version: 0.1.1
3+
Version: 0.2.0
44
Authors@R:
55
c(
66
person("Drew", "Herren", email = "drewherrenopensource@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0003-4109-6611")),

Doxyfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ PROJECT_NAME = "StochTree"
4848
# could be handy for archiving the generated documentation or if some version
4949
# control system is used.
5050

51-
PROJECT_NUMBER = 0.1.1
51+
PROJECT_NUMBER = 0.2.0
5252

5353
# Using the PROJECT_BRIEF tag one can provide an optional one line description
5454
# for a project that appears at the top of each page and should give viewer a

R/bart.R

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -423,37 +423,53 @@ bart <- function(
423423
floor(num_values / cutpoint_grid_size),
424424
1
425425
)
426+
x_is_df <- is.data.frame(X_train)
426427
covs_warning_1 <- NULL
427428
covs_warning_2 <- NULL
428429
covs_warning_3 <- NULL
430+
covs_warning_4 <- NULL
429431
for (i in 1:num_cov_orig) {
430-
# Determine the number of unique values
431-
num_unique_values <- length(unique(X_train[, i]))
432-
433-
# Determine a "name" for the covariate
434-
cov_name <- ifelse(
435-
is.null(colnames(X_train)),
436-
paste0("X", i),
437-
colnames(X_train)[i]
438-
)
439-
440-
# Check for a small relative number of unique values
441-
unique_full_ratio <- num_unique_values / num_values
442-
if (unique_full_ratio < 0.2) {
443-
covs_warning_1 <- c(covs_warning_1, cov_name)
432+
# Skip check for variables that are treated as categorical
433+
x_numeric <- T
434+
if (x_is_df) {
435+
if (is.factor(X_train[, i])) {
436+
x_numeric <- F
437+
}
444438
}
439+
if (x_numeric) {
440+
# Determine the number of unique values
441+
num_unique_values <- length(unique(X_train[, i]))
442+
443+
# Determine a "name" for the covariate
444+
cov_name <- ifelse(
445+
is.null(colnames(X_train)),
446+
paste0("X", i),
447+
colnames(X_train)[i]
448+
)
445449

446-
# Check for a small absolute number of unique values
447-
if (num_values > 100) {
448-
if (num_unique_values < 20) {
449-
covs_warning_2 <- c(covs_warning_2, cov_name)
450+
# Check for a small relative number of unique values
451+
unique_full_ratio <- num_unique_values / num_values
452+
if (unique_full_ratio < 0.2) {
453+
covs_warning_1 <- c(covs_warning_1, cov_name)
454+
}
455+
456+
# Check for a small absolute number of unique values
457+
if (num_values > 100) {
458+
if (num_unique_values < 20) {
459+
covs_warning_2 <- c(covs_warning_2, cov_name)
460+
}
461+
}
462+
463+
# Check for a large number of duplicates of any individual value
464+
x_j_hist <- table(X_train[, i])
465+
if (any(x_j_hist > 2 * max_grid_size)) {
466+
covs_warning_3 <- c(covs_warning_3, cov_name)
450467
}
451-
}
452468

453-
# Check for a large number of duplicates of any individual value
454-
x_j_hist <- table(X_train[, i])
455-
if (any(x_j_hist > 2 * max_grid_size)) {
456-
covs_warning_3 <- c(covs_warning_3, cov_name)
469+
# Check for binary variables
470+
if (num_unique_values == 2) {
471+
covs_warning_4 <- c(covs_warning_4, cov_name)
472+
}
457473
}
458474
}
459475

@@ -494,6 +510,18 @@ bart <- function(
494510
)
495511
)
496512
}
513+
514+
if (!is.null(covs_warning_4)) {
515+
warning(
516+
paste0(
517+
"Covariates ",
518+
paste(covs_warning_4, collapse = ", "),
519+
" appear to be binary but are currently treated by stochtree as continuous. ",
520+
"This might present some issues with the grow-from-root (GFR) algorithm. ",
521+
"Consider converting binary variables to ordered factor (i.e. `factor(..., ordered = T)`."
522+
)
523+
)
524+
}
497525
}
498526

499527
# Standardize the keep variable lists to numeric indices

R/bcf.R

Lines changed: 52 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -527,37 +527,54 @@ bcf <- function(
527527
floor(num_values / cutpoint_grid_size),
528528
1
529529
)
530+
x_is_df <- is.data.frame(X_train)
530531
covs_warning_1 <- NULL
531532
covs_warning_2 <- NULL
532533
covs_warning_3 <- NULL
534+
covs_warning_4 <- NULL
533535
for (i in 1:num_cov_orig) {
534-
# Determine the number of unique values
535-
num_unique_values <- length(unique(X_train[, i]))
536-
537-
# Determine a "name" for the covariate
538-
cov_name <- ifelse(
539-
is.null(colnames(X_train)),
540-
paste0("X", i),
541-
colnames(X_train)[i]
542-
)
543-
544-
# Check for a small relative number of unique values
545-
unique_full_ratio <- num_unique_values / num_values
546-
if (unique_full_ratio < 0.2) {
547-
covs_warning_1 <- c(covs_warning_1, cov_name)
536+
# Skip check for variables that are treated as categorical
537+
x_numeric <- T
538+
if (x_is_df) {
539+
if (is.factor(X_train[, i])) {
540+
x_numeric <- F
541+
}
548542
}
549543

550-
# Check for a small absolute number of unique values
551-
if (num_values > 100) {
552-
if (num_unique_values < 20) {
553-
covs_warning_2 <- c(covs_warning_2, cov_name)
544+
if (x_numeric) {
545+
# Determine the number of unique values
546+
num_unique_values <- length(unique(X_train[, i]))
547+
548+
# Determine a "name" for the covariate
549+
cov_name <- ifelse(
550+
is.null(colnames(X_train)),
551+
paste0("X", i),
552+
colnames(X_train)[i]
553+
)
554+
555+
# Check for a small relative number of unique values
556+
unique_full_ratio <- num_unique_values / num_values
557+
if (unique_full_ratio < 0.2) {
558+
covs_warning_1 <- c(covs_warning_1, cov_name)
554559
}
555-
}
556560

557-
# Check for a large number of duplicates of any individual value
558-
x_j_hist <- table(X_train[, i])
559-
if (any(x_j_hist > 2 * max_grid_size)) {
560-
covs_warning_3 <- c(covs_warning_3, cov_name)
561+
# Check for a small absolute number of unique values
562+
if (num_values > 100) {
563+
if (num_unique_values < 20) {
564+
covs_warning_2 <- c(covs_warning_2, cov_name)
565+
}
566+
}
567+
568+
# Check for a large number of duplicates of any individual value
569+
x_j_hist <- table(X_train[, i])
570+
if (any(x_j_hist > 2 * max_grid_size)) {
571+
covs_warning_3 <- c(covs_warning_3, cov_name)
572+
}
573+
574+
# Check for binary variables
575+
if (num_unique_values == 2) {
576+
covs_warning_4 <- c(covs_warning_4, cov_name)
577+
}
561578
}
562579
}
563580

@@ -598,6 +615,18 @@ bcf <- function(
598615
)
599616
)
600617
}
618+
619+
if (!is.null(covs_warning_4)) {
620+
warning(
621+
paste0(
622+
"Covariates ",
623+
paste(covs_warning_4, collapse = ", "),
624+
" appear to be binary but are currently treated by stochtree as continuous. ",
625+
"This might present some issues with the grow-from-root (GFR) algorithm. ",
626+
"Consider converting binary variables to ordered factor (i.e. `factor(..., ordered = T)`."
627+
)
628+
)
629+
}
601630
}
602631

603632
# Check delta_max is valid

cleanup

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
#!/bin/sh
2+
rm -f src/Makevars

configure

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#! /bin/sh
22
# Guess values for system-dependent variables and create Makefiles.
3-
# Generated by GNU Autoconf 2.72 for stochtree 0.1.1.
3+
# Generated by GNU Autoconf 2.72 for stochtree 0.2.0.
44
#
55
#
66
# Copyright (C) 1992-1996, 1998-2017, 2020-2023 Free Software Foundation,
@@ -600,8 +600,8 @@ MAKEFLAGS=
600600
# Identity of this package.
601601
PACKAGE_NAME='stochtree'
602602
PACKAGE_TARNAME='stochtree'
603-
PACKAGE_VERSION='0.1.1'
604-
PACKAGE_STRING='stochtree 0.1.1'
603+
PACKAGE_VERSION='0.2.0'
604+
PACKAGE_STRING='stochtree 0.2.0'
605605
PACKAGE_BUGREPORT=''
606606
PACKAGE_URL=''
607607

@@ -1205,7 +1205,7 @@ if test "$ac_init_help" = "long"; then
12051205
# Omit some internal or obsolete options to make the list less imposing.
12061206
# This message is too long to be a string in the A/UX 3.1 sh.
12071207
cat <<_ACEOF
1208-
'configure' configures stochtree 0.1.1 to adapt to many kinds of systems.
1208+
'configure' configures stochtree 0.2.0 to adapt to many kinds of systems.
12091209
12101210
Usage: $0 [OPTION]... [VAR=VALUE]...
12111211
@@ -1267,7 +1267,7 @@ fi
12671267

12681268
if test -n "$ac_init_help"; then
12691269
case $ac_init_help in
1270-
short | recursive ) echo "Configuration of stochtree 0.1.1:";;
1270+
short | recursive ) echo "Configuration of stochtree 0.2.0:";;
12711271
esac
12721272
cat <<\_ACEOF
12731273
@@ -1335,7 +1335,7 @@ fi
13351335
test -n "$ac_init_help" && exit $ac_status
13361336
if $ac_init_version; then
13371337
cat <<\_ACEOF
1338-
stochtree configure 0.1.1
1338+
stochtree configure 0.2.0
13391339
generated by GNU Autoconf 2.72
13401340
13411341
Copyright (C) 2023 Free Software Foundation, Inc.
@@ -1372,7 +1372,7 @@ cat >config.log <<_ACEOF
13721372
This file contains any messages produced by compilers while
13731373
running configure, to aid debugging if configure makes a mistake.
13741374
1375-
It was created by stochtree $as_me 0.1.1, which was
1375+
It was created by stochtree $as_me 0.2.0, which was
13761376
generated by GNU Autoconf 2.72. Invocation command line was
13771377
13781378
$ $0$ac_configure_args_raw
@@ -2385,7 +2385,7 @@ cat >>$CONFIG_STATUS <<\_ACEOF || ac_write_fail=1
23852385
# report actual input values of CONFIG_FILES etc. instead of their
23862386
# values after options handling.
23872387
ac_log="
2388-
This file was extended by stochtree $as_me 0.1.1, which was
2388+
This file was extended by stochtree $as_me 0.2.0, which was
23892389
generated by GNU Autoconf 2.72. Invocation command line was
23902390
23912391
CONFIG_FILES = $CONFIG_FILES
@@ -2440,7 +2440,7 @@ ac_cs_config_escaped=`printf "%s\n" "$ac_cs_config" | sed "s/^ //; s/'/'\\\\\\\\
24402440
cat >>$CONFIG_STATUS <<_ACEOF || ac_write_fail=1
24412441
ac_cs_config='$ac_cs_config_escaped'
24422442
ac_cs_version="\\
2443-
stochtree config.status 0.1.1
2443+
stochtree config.status 0.2.0
24442444
configured by $0, generated by GNU Autoconf 2.72,
24452445
with options \\"\$ac_cs_config\\"
24462446
@@ -3000,3 +3000,4 @@ if test -n "$ac_unrecognized_opts" && test "$enable_option_checking" != no; then
30003000
printf "%s\n" "$as_me: WARNING: unrecognized options: $ac_unrecognized_opts" >&2;}
30013001
fi
30023002

3003+

configure.ac

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# https://github.com/microsoft/LightGBM/blob/master/R-package/configure.ac
44

55
AC_PREREQ(2.69)
6-
AC_INIT([stochtree], [0.1.1], [], [stochtree], [])
6+
AC_INIT([stochtree], [0.2.0], [], [stochtree], [])
77
# Note: consider making version number dynamic as in
88
# https://github.com/microsoft/LightGBM/blob/195c26fc7b00eb0fec252dfe841e2e66d6833954/build-cran-package.sh
99

0 commit comments

Comments
 (0)