brms integrates with emmeans for marginal mean calculations, but the results do not make sense to me. I am wondering if this is a user error or a bug in brms.
The reprex below uses the mmrm package’s FEV1 dataset, a simulation of a clinical trial with treatment groups in ARMCD and discrete time points for repeated measures in AVISIT. The example compares 4 different methods of estimating marginal means for each combination of ARMCD and AVISIT:
1. Data summaries: compute means and independent frequentist 95% confidence intervals on the raw data.
2. `lm()` + `emmeans`: fit a model with `lm()` and get marginal means with `emmeans`.
3. `brms` + custom: fit a model with `brms` and use a custom linear transformation to map model parameters to marginal means.
4. `brms` + `emmeans`: use the native `brms`/`emmeans` integration to estimate marginal means from the fitted `brms` model.
There is reasonable agreement among approaches (1), (2), and (3), but approach (4) gives very different results from all the others.
suppressPackageStartupMessages({
library(brms)
library(coda)
library(emmeans)
library(mmrm)
library(posterior)
library(tidyverse)
library(zoo)
})
# Use "|" as the separator emmeans puts between factor levels when it
# builds combined labels; the custom mapping further down pastes its row
# names with the same separator.
emm_options(sep = "|")

# FEV data from the mmrm package, using LOCF and then LOCF reversed
# to impute responses. (For this discussion, it is helpful to avoid
# the topic of missingness.)
data(fev_data, package = "mmrm")
data <- fev_data %>%
  mutate(FEV1_CHG = FEV1 - FEV1_BL, USUBJID = as.character(USUBJID)) %>%
  select(-FEV1) %>%
  group_by(USUBJID) %>%
  # Add a row for every visit a subject is missing, then copy that
  # subject's own covariates into the new rows. (Passing
  # `fill = as.list(.[1L, ...])` to complete() would instead use the first
  # row of the *whole* data set, because `.` is the full piped data frame,
  # not the current group.)
  complete(AVISIT) %>%
  tidyr::fill(ARMCD, FEV1_BL, RACE, SEX, WEIGHT, .direction = "downup") %>%
  ungroup() %>%
  arrange(USUBJID, AVISIT) %>%
  group_by(USUBJID) %>%
  # Last observation carried forward within subject ...
  mutate(FEV1_CHG = na.locf(FEV1_CHG, na.rm = FALSE)) %>%
  # ... then carried backward to cover leading gaps.
  mutate(FEV1_CHG = na.locf(FEV1_CHG, na.rm = FALSE, fromLast = TRUE)) %>%
  ungroup() %>%
  # Subjects with no observed response at any visit are still NA here.
  filter(!is.na(FEV1_CHG))
# Approach 1: raw-data means per arm and visit, with independent
# normal-approximation 95% confidence intervals.

# Half-width of a normal-approximation 95% interval for the mean of `x`.
ci_half_width <- function(x) {
  qnorm(0.975) * sd(x) / sqrt(length(x))
}

summary_data <- data %>%
  group_by(ARMCD, AVISIT) %>%
  summarise(
    source = "1_data",
    mean = mean(FEV1_CHG),
    lower = mean - ci_half_width(FEV1_CHG),
    upper = mean + ci_half_width(FEV1_CHG),
    .groups = "drop"
  )
# Model formula used by both the lm() fit and the brms fit.
formula <- FEV1_CHG ~ FEV1_BL + FEV1_BL:AVISIT + ARMCD + ARMCD:AVISIT +
  AVISIT + RACE + SEX + WEIGHT

# Approach 2: ordinary least squares fit, marginal means from emmeans.
model <- lm(formula = formula, data = data)
emm_lm <- emmeans(
  object = model,
  specs = ~ARMCD:AVISIT,
  nuisance = c("USUBJID", "RACE", "SEX"),
  wt.nuis = "proportional"
)

# Reduce the emmeans table to the columns shared by every summary.
summary_lm_emmeans <- as_tibble(as.data.frame(emm_lm)) %>%
  select(
    ARMCD,
    AVISIT,
    mean = emmean,
    lower = lower.CL,
    upper = upper.CL
  ) %>%
  mutate(source = "2_lm_emmeans")
# Approach 4: brms fit of the same formula, marginal means through the
# brms/emmeans integration. Note that `model` now refers to the brmsfit.
model <- brm(formula = brmsformula(formula), data = data)
emm_brms <- emmeans(
  object = model,
  specs = ~ARMCD:AVISIT,
  nuisance = c("USUBJID", "RACE", "SEX"),
  wt.nuis = "proportional"
)

# Reduce the emmeans table (HPD interval bounds) to the shared columns.
summary_brms_emmeans <- as_tibble(as.data.frame(emm_brms)) %>%
  select(
    ARMCD,
    AVISIT,
    mean = emmean,
    lower = lower.HPD,
    upper = upper.HPD
  ) %>%
  mutate(source = "4_brms_emmeans")
# Approach 3: custom marginal means from brms draws using a custom mapping
# from model parameters to marginal means. This is what I think
# emmeans *should* be doing to a brms model, based on investigations
# using lm() (cf. https://github.com/openpharma/brms.mmrm/issues/53)
#
# Step 1: column means of the design matrix for SEX and RACE over the full
# data set, as a 1-row matrix. With indicator coding these are presumably
# the observed level proportions (i.e. "proportional" nuisance weights).
proportional_factors <- brmsformula(FEV1_CHG ~ 0 + SEX + RACE) %>%
make_standata(data = data) %>%
.subset2("X") %>%
colMeans() %>%
t()
# Step 2: one row per distinct ARMCD/AVISIT combination, with FEV1_BL and
# WEIGHT fixed at their overall means. FEV1_CHG = 0 is only a placeholder
# response so that make_standata() below has a response column.
grid <- data %>%
mutate(FEV1_BL = mean(FEV1_BL), FEV1_CHG = 0, WEIGHT = mean(WEIGHT)) %>%
distinct(ARMCD, AVISIT, FEV1_BL, WEIGHT, FEV1_CHG)
# Step 3: posterior draws of the columns whose names start with "b_",
# excluding any that start with "b_sigma".
draws_parameters <- model %>%
as_draws_df() %>%
as_tibble() %>%
select(starts_with("b_"), -starts_with("b_sigma"))
# Step 4: design matrix of the model terms other than RACE and SEX,
# evaluated on the grid, with the SEX/RACE column means appended as extra
# columns. Columns are renamed with a "b_" prefix to line up with the
# draw names.
mapping <- brmsformula(
FEV1_CHG ~ FEV1_BL + FEV1_BL:AVISIT + ARMCD + ARMCD:AVISIT + AVISIT + WEIGHT
) %>%
make_standata(data = grid) %>%
.subset2("X") %>%
bind_cols(proportional_factors) %>%
setNames(paste0("b_", colnames(.)))
# Every parameter draw must have a matching column in the mapping.
stopifnot(all(colnames(draws_parameters) %in% colnames(mapping)))
# Put the mapping columns in the same order as the draws; mapping columns
# with no corresponding parameter are dropped by this subsetting.
mapping <- as.matrix(mapping)[, colnames(draws_parameters)]
# Row labels "ARMCD|AVISIT" become the column names of the product below.
rownames(mapping) <- paste(grid$ARMCD, grid$AVISIT, sep = "|")
# Step 5: (draws x parameters) %*% (parameters x grid rows) gives one column
# of marginal-mean draws per ARMCD/AVISIT combination. `%*%` and `%>%` have
# equal precedence and associate left to right, so the product is computed
# before it is piped on.
draws_custom <- as.matrix(draws_parameters) %*% t(mapping) %>%
as.data.frame() %>%
as_tibble()
# Summarize the custom marginal-mean draws: posterior mean and equal-tailed
# 95% interval for each ARMCD/AVISIT combination.
summary_brms_custom <- draws_custom %>%
  pivot_longer(everything()) %>%
  # Column names were built as "ARMCD|AVISIT". Split on exactly that
  # delimiter rather than on any run of non-alphanumeric characters (the
  # default of separate()), so level names containing e.g. "_" or " " are
  # not broken apart.
  separate_wider_delim(name, delim = "|", names = c("ARMCD", "AVISIT")) %>%
  group_by(ARMCD, AVISIT) %>%
  summarize(
    source = "3_brms_custom",
    mean = mean(value),
    # names = FALSE drops the "2.5%" / "97.5%" element names that
    # quantile() would otherwise attach to the summary columns.
    lower = quantile(value, 0.025, names = FALSE),
    upper = quantile(value, 0.975, names = FALSE),
    .groups = "drop"
  )
# Stack the four summaries and draw them side by side, one panel per
# treatment arm (rows) and visit (columns).
summary <- bind_rows(
  summary_data,
  summary_lm_emmeans,
  summary_brms_custom,
  summary_brms_emmeans
)
ggplot(summary, aes(x = source, color = source)) +
  geom_point(aes(y = mean)) +
  geom_errorbar(aes(ymin = lower, ymax = upper)) +
  facet_grid(ARMCD ~ AVISIT) +
  theme_gray(base_size = 16) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1)) +
  labs(y = "FEV1_CHG")
Related: unless I am missing something, the X matrix of brms:::emm_basis.brmsfit() should map model parameters to marginal means. (At least emmeans:::emm_basis.lm() works this way.) But when I try it, I get the identity matrix:
# Probe emm_basis() directly on the fitted model to inspect the X matrix
# it returns. `model` is the brmsfit at this point (it was reassigned by
# the brm() call above).
#
# Unique levels, as character vectors, of each categorical predictor.
custom_xlev <- lapply(select(data, AVISIT, ARMCD, SEX, RACE), function(x) unique(as.character(x)))
# Terms object for the right-hand side of the model.
# NOTE(review): WEIGHT is in the model formula above but not in these
# terms -- confirm whether that omission is intended.
custom_trms <- terms(x = ~ FEV1_BL + FEV1_BL:AVISIT + ARMCD + ARMCD:AVISIT + AVISIT + RACE + SEX, data = data)
# Reference grid with RACE and SEX declared as nuisance factors.
# NOTE(review): `specs` is normally an argument of emmeans(), not of
# ref_grid(); it is presumably swallowed by `...` here -- confirm it has
# any effect.
custom_grid <- ref_grid(
model,
specs = ~ARMCD:AVISIT,
nuisance = c("RACE", "SEX")
)
# Extract only the X component of the basis; the output pasted below shows
# an 8 x 8 identity matrix.
emm_basis(
object = model,
trms = custom_trms,
xlev = custom_xlev,
grid = custom_grid, # I get the same result if I supply custom_grid@grid.
resp = "FEV1_CHG"
)$X
#> [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8]
#> [1,] 1 0 0 0 0 0 0 0
#> [2,] 0 1 0 0 0 0 0 0
#> [3,] 0 0 1 0 0 0 0 0
#> [4,] 0 0 0 1 0 0 0 0
#> [5,] 0 0 0 0 1 0 0 0
#> [6,] 0 0 0 0 0 1 0 0
#> [7,] 0 0 0 0 0 0 1 0
#> [8,] 0 0 0 0 0 0 0 1
Session info:
R version 4.3.2 (2023-10-31)
Platform: aarch64-apple-darwin20 (64-bit)
Running under: macOS Sonoma 14.4.1
Matrix products: default
BLAS: /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRlapack.dylib; LAPACK version 3.11.0
locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
time zone: America/Indiana/Indianapolis
tzcode source: internal
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] zoo_1.8-12 lubridate_1.9.3 forcats_1.0.0 stringr_1.5.1 dplyr_1.1.4 purrr_1.0.2
[7] readr_2.1.5 tidyr_1.3.1 tibble_3.2.1 ggplot2_3.4.4 tidyverse_2.0.0 posterior_1.5.0
[13] mmrm_0.3.11 emmeans_1.10.0 coda_0.19-4.1 brms_2.21.0 Rcpp_1.0.12
loaded via a namespace (and not attached):
[1] gtable_0.3.4 tensorA_0.36.2.1 QuickJSR_1.1.3 processx_3.8.3
[5] inline_0.3.19 lattice_0.22-5 callr_3.7.3 tzdb_0.4.0
[9] ps_1.7.6 vctrs_0.6.5 tools_4.3.2 Rdpack_2.6
[13] generics_0.1.3 stats4_4.3.2 curl_5.2.0 parallel_4.3.2
[17] sandwich_3.1-0 fansi_1.0.6 pkgconfig_2.0.3 Matrix_1.6-5
[21] checkmate_2.3.1 distributional_0.4.0 RcppParallel_5.1.7 lifecycle_1.0.4
[25] farver_2.1.1 compiler_4.3.2 textshaping_0.3.7 Brobdingnag_1.2-9
[29] munsell_0.5.0 codetools_0.2-19 bayesplot_1.11.0 pillar_1.9.0
[33] MASS_7.3-60.0.1 StanHeaders_2.32.6 bridgesampling_1.1-2 abind_1.4-5
[37] multcomp_1.4-25 nlme_3.1-164 rstan_2.32.6 tidyselect_1.2.0
[41] mvtnorm_1.2-4 stringi_1.8.3 labeling_0.4.3 splines_4.3.2
[45] grid_4.3.2 colorspace_2.1-0 cli_3.6.2 magrittr_2.0.3
[49] loo_2.6.0 survival_3.5-7 pkgbuild_1.4.3 utf8_1.2.4
[53] TH.data_1.1-2 withr_3.0.0 scales_1.3.0 backports_1.4.1
[57] timechange_0.3.0 estimability_1.4.1 matrixStats_1.2.0 gridExtra_2.3
[61] ragg_1.2.7 hms_1.1.3 rbibutils_2.2.16 V8_4.4.1
[65] rstantools_2.4.0 rlang_1.1.3 xtable_1.8-4 glue_1.7.0
[69] rstudioapi_0.15.0 jsonlite_1.8.8 R6_2.5.1 systemfonts_1.0.5
