`brms` integrates with `emmeans` for marginal mean calculations, but the results do not make sense to me. I am wondering whether this is a user error or a bug in `brms`.
The reprex below uses the `mmrm` package's FEV1 dataset, a simulation of a clinical trial with treatment groups in `ARMCD` and discrete time points for repeated measures in `AVISIT`. The example compares 4 different methods of estimating marginal means for each combination of `ARMCD` and `AVISIT`:
1. Data summaries: compute means and independent frequentist 95% confidence intervals on the raw data.
2. `lm()` + `emmeans`: fit a model with `lm()` and get marginal means with `emmeans`.
3. `brms` + custom: fit a model with `brms` and use a custom linear transformation to map model parameters to marginal means.
4. `brms` + `emmeans`: use the native `brms`/`emmeans` integration to estimate marginal means from the fitted `brms` model.
There is reasonable agreement among approaches (1), (2), and (3), and approach (4) gives very different results from all the others.
suppressPackageStartupMessages({
library(brms)
library(coda)
library(emmeans)
library(mmrm)
library(posterior)
library(tidyverse)
library(zoo)
})
# Separator emmeans uses when it combines factor levels into one label.
emm_options(sep = "|")
# FEV data from the mmrm package, using LOCF and then LOCF reversed
# to impute responses. (For this discussion, it is helpful to avoid
# the topic of missingness.)
data(fev_data, package = "mmrm")
data <- fev_data %>%
  mutate(FEV1_CHG = FEV1 - FEV1_BL, USUBJID = as.character(USUBJID)) %>%
  select(-FEV1) %>%
  group_by(USUBJID) %>%
  # Give every subject a row for every AVISIT level.
  complete(AVISIT) %>%
  # Copy the subject-level covariates (assumed constant within a subject)
  # from the subject's own observed rows into any rows added above.
  # NOTE: a `fill = as.list(.[1L, ...])` argument to complete() is evaluated
  # once on the whole piped data, so it would stamp the FIRST subject's
  # covariates onto every added row of every subject.
  fill(ARMCD, FEV1_BL, RACE, SEX, WEIGHT, .direction = "downup") %>%
  # Sort by AVISIT level within subject so LOCF runs in visit order.
  arrange(AVISIT, .by_group = TRUE) %>%
  mutate(
    # LOCF, then reverse LOCF to cover values missing before the first
    # observed visit.
    FEV1_CHG = na.locf(FEV1_CHG, na.rm = FALSE),
    FEV1_CHG = na.locf(FEV1_CHG, na.rm = FALSE, fromLast = TRUE)
  ) %>%
  ungroup() %>%
  filter(!is.na(FEV1_CHG))
# Approach 1: per-cell raw-data means with independent normal-theory 95%
# confidence intervals, mean +/- qnorm(0.975) * SD / sqrt(n).
summary_data <- data %>%
  group_by(ARMCD, AVISIT) %>%
  summarize(
    source = "1_data",
    mean = mean(FEV1_CHG),
    half_width = qnorm(0.975) * sd(FEV1_CHG) / sqrt(n()),
    lower = mean - half_width,
    upper = mean + half_width,
    .groups = "drop"
  ) %>%
  select(-half_width)
# Model formula reused by every fit below (lm here, brms later).
formula <- FEV1_CHG ~ FEV1_BL + FEV1_BL:AVISIT + ARMCD + ARMCD:AVISIT +
  AVISIT + RACE + SEX + WEIGHT
# Approach 2: least-squares fit, marginal means from emmeans. RACE and SEX
# are treated as nuisance factors averaged with proportional weights.
model <- lm(formula = formula, data = data)
summary_lm_emmeans <- model %>%
  emmeans(
    specs = ~ARMCD:AVISIT,
    wt.nuis = "proportional",
    nuisance = c("USUBJID", "RACE", "SEX")
  ) %>%
  as.data.frame() %>%
  as_tibble() %>%
  select(
    ARMCD, AVISIT,
    mean = emmean, lower = lower.CL, upper = upper.CL
  ) %>%
  mutate(source = "2_lm_emmeans")
# Approach 4: the built-in brms/emmeans integration, called with the same
# specs and nuisance settings as the lm() approach. `model` is overwritten
# with the brms fit, which later code in this script reuses.
model <- brm(data = data, formula = brmsformula(formula))
summary_brms_emmeans <- model %>%
  emmeans(
    specs = ~ARMCD:AVISIT,
    wt.nuis = "proportional",
    nuisance = c("USUBJID", "RACE", "SEX")
  ) %>%
  as.data.frame() %>%
  as_tibble() %>%
  # For this fit the interval limits come back as HPD columns.
  select(
    ARMCD, AVISIT,
    mean = emmean, lower = lower.HPD, upper = upper.HPD
  ) %>%
  mutate(source = "4_brms_emmeans")
# custom marginal means from brms draws using a custom mapping
# from model parameters to marginal means. This is what I think
# emmeans *should* be doing to a brms model, based on investigations
# using lm() (c.f. https://github.com/openpharma/brms.mmrm/issues/53)

# Proportional weights for the nuisance factors: column means of the
# SEX/RACE dummy columns over the analysis data.
proportional_factors <- brmsformula(FEV1_CHG ~ 0 + SEX + RACE) %>%
  make_standata(data = data) %>%
  .subset2("X") %>%
  colMeans() %>%
  t()
# Reference grid: one row per ARMCD/AVISIT combination, with FEV1_BL and
# WEIGHT held at their means. FEV1_CHG = 0 is only a placeholder response
# so make_standata() can build the design matrix.
grid <- data %>%
  mutate(FEV1_BL = mean(FEV1_BL), FEV1_CHG = 0, WEIGHT = mean(WEIGHT)) %>%
  distinct(ARMCD, AVISIT, FEV1_BL, WEIGHT, FEV1_CHG)
# Posterior draws of the "b_" coefficients, excluding any "b_sigma" ones.
draws_parameters <- model %>%
  as_draws_df() %>%
  as_tibble() %>%
  select(starts_with("b_"), -starts_with("b_sigma"))
# Linear map from coefficients to grid-cell means: grid design matrix
# (without RACE/SEX) plus the proportional nuisance-factor weights.
mapping <- brmsformula(
  FEV1_CHG ~ FEV1_BL + FEV1_BL:AVISIT + ARMCD + ARMCD:AVISIT + AVISIT + WEIGHT
) %>%
  make_standata(data = grid) %>%
  .subset2("X") %>%
  bind_cols(proportional_factors) %>%
  setNames(paste0("b_", colnames(.)))
# Every coefficient needs a mapping column; mapping columns with no matching
# coefficient are dropped by the column subset below.
stopifnot(all(colnames(draws_parameters) %in% colnames(mapping)))
mapping <- as.matrix(mapping)[, colnames(draws_parameters), drop = FALSE]
rownames(mapping) <- paste(grid$ARMCD, grid$AVISIT, sep = "|")
# Rows = posterior draws, columns = "ARMCD|AVISIT" grid cells.
draws_custom <- as.matrix(draws_parameters) %*% t(mapping) %>%
  as.data.frame() %>%
  as_tibble()
summary_brms_custom <- draws_custom %>%
  pivot_longer(everything()) %>%
  # Split on the literal "|" used to build the labels above, instead of
  # separate()'s default "any non-alphanumeric run" regex, which would
  # mis-split level labels containing spaces, dots, underscores, etc.
  separate_wider_delim(name, delim = "|", names = c("ARMCD", "AVISIT")) %>%
  group_by(ARMCD, AVISIT) %>%
  summarize(
    source = "3_brms_custom",
    mean = mean(value),
    # Equal-tailed 95% posterior interval.
    lower = quantile(value, 0.025, names = FALSE),
    upper = quantile(value, 0.975, names = FALSE),
    .groups = "drop"
  )
# Compare results: stack the four approaches and draw one panel per
# ARMCD x AVISIT cell, with a point (estimate) and an error bar (interval)
# per approach.
summary <- list(
  summary_data,
  summary_lm_emmeans,
  summary_brms_custom,
  summary_brms_emmeans
) %>%
  bind_rows()
ggplot(summary, aes(x = source, color = source)) +
  geom_point(aes(y = mean)) +
  geom_errorbar(aes(ymin = lower, ymax = upper)) +
  facet_grid(ARMCD ~ AVISIT) +
  theme_gray(16) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1)) +
  ylab("FEV1_CHG")
Related: unless I am missing something, the `X` matrix of `brms:::emm_basis.brmsfit()` should map model parameters to marginal means. (At least `emmeans:::emm_basis.lm()` works this way.) But when I try it, I get the identity matrix:
# Inspect the design matrix X returned by emm_basis() for the brms fit, to
# compare with what emm_basis() returns for an lm() fit.
# NOTE(review): custom_trms omits WEIGHT, which is in the fitted formula, and
# `specs` is not a named argument of ref_grid() (it goes into `...`); also
# this ref_grid() call does not pass wt.nuis = "proportional" like the
# emmeans() calls above. Confirm none of these affect the result.
custom_xlev <- data %>%
  select(AVISIT, ARMCD, SEX, RACE) %>%
  map(function(column) unique(as.character(column)))
custom_trms <- terms(
  ~ FEV1_BL + FEV1_BL:AVISIT + ARMCD + ARMCD:AVISIT + AVISIT + RACE + SEX,
  data = data
)
custom_grid <- model %>%
  ref_grid(specs = ~ARMCD:AVISIT, nuisance = c("RACE", "SEX"))
emm_basis(
  object = model,
  trms = custom_trms,
  xlev = custom_xlev,
  grid = custom_grid, # I get the same result if I supply custom_grid@grid.
  resp = "FEV1_CHG"
)[["X"]]
#> [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8]
#> [1,] 1 0 0 0 0 0 0 0
#> [2,] 0 1 0 0 0 0 0 0
#> [3,] 0 0 1 0 0 0 0 0
#> [4,] 0 0 0 1 0 0 0 0
#> [5,] 0 0 0 0 1 0 0 0
#> [6,] 0 0 0 0 0 1 0 0
#> [7,] 0 0 0 0 0 0 1 0
#> [8,] 0 0 0 0 0 0 0 1
Session info:
R version 4.3.2 (2023-10-31)
Platform: aarch64-apple-darwin20 (64-bit)
Running under: macOS Sonoma 14.4.1
Matrix products: default
BLAS: /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRlapack.dylib; LAPACK version 3.11.0
locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
time zone: America/Indiana/Indianapolis
tzcode source: internal
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] zoo_1.8-12 lubridate_1.9.3 forcats_1.0.0 stringr_1.5.1 dplyr_1.1.4 purrr_1.0.2
[7] readr_2.1.5 tidyr_1.3.1 tibble_3.2.1 ggplot2_3.4.4 tidyverse_2.0.0 posterior_1.5.0
[13] mmrm_0.3.11 emmeans_1.10.0 coda_0.19-4.1 brms_2.21.0 Rcpp_1.0.12
loaded via a namespace (and not attached):
[1] gtable_0.3.4 tensorA_0.36.2.1 QuickJSR_1.1.3 processx_3.8.3
[5] inline_0.3.19 lattice_0.22-5 callr_3.7.3 tzdb_0.4.0
[9] ps_1.7.6 vctrs_0.6.5 tools_4.3.2 Rdpack_2.6
[13] generics_0.1.3 stats4_4.3.2 curl_5.2.0 parallel_4.3.2
[17] sandwich_3.1-0 fansi_1.0.6 pkgconfig_2.0.3 Matrix_1.6-5
[21] checkmate_2.3.1 distributional_0.4.0 RcppParallel_5.1.7 lifecycle_1.0.4
[25] farver_2.1.1 compiler_4.3.2 textshaping_0.3.7 Brobdingnag_1.2-9
[29] munsell_0.5.0 codetools_0.2-19 bayesplot_1.11.0 pillar_1.9.0
[33] MASS_7.3-60.0.1 StanHeaders_2.32.6 bridgesampling_1.1-2 abind_1.4-5
[37] multcomp_1.4-25 nlme_3.1-164 rstan_2.32.6 tidyselect_1.2.0
[41] mvtnorm_1.2-4 stringi_1.8.3 labeling_0.4.3 splines_4.3.2
[45] grid_4.3.2 colorspace_2.1-0 cli_3.6.2 magrittr_2.0.3
[49] loo_2.6.0 survival_3.5-7 pkgbuild_1.4.3 utf8_1.2.4
[53] TH.data_1.1-2 withr_3.0.0 scales_1.3.0 backports_1.4.1
[57] timechange_0.3.0 estimability_1.4.1 matrixStats_1.2.0 gridExtra_2.3
[61] ragg_1.2.7 hms_1.1.3 rbibutils_2.2.16 V8_4.4.1
[65] rstantools_2.4.0 rlang_1.1.3 xtable_1.8-4 glue_1.7.0
[69] rstudioapi_0.15.0 jsonlite_1.8.8 R6_2.5.1 systemfonts_1.0.5