Trouble with {brms}/{emmeans} integration

brms integrates with emmeans for marginal mean calculations, but the results do not make sense to me. I am wondering if this is a user error or a bug in brms.

The reprex below uses the mmrm package’s FEV1 dataset, a simulation of a clinical trial with treatment groups in ARMCD and discrete time points for repeated measures in AVISIT. The example compares 4 different methods of estimating marginal means for each combination of ARMCD and AVISIT:

  1. Data summaries: compute means and independent frequentist 95% confidence intervals on the raw data.
  2. lm() + emmeans: fit a model with lm() and get marginal means with emmeans.
  3. brms + custom: fit a model with brms and use a custom linear transformation to map model parameters to marginal means.
  4. brms + emmeans: use the native brms/emmeans integration to estimate marginal means from the fitted brms model.

There is reasonable agreement among approaches (1), (2), and (3), but approach (4) gives very different results from all the others.

suppressPackageStartupMessages({
  library(brms)
  library(coda)
  library(emmeans)
  library(mmrm)
  library(posterior)
  library(tidyverse)
  library(zoo)
})
# Set the emmeans `sep` option to "|" for the whole session.
# NOTE(review): presumably this is meant to match the "ARMCD|AVISIT" labels
# that the custom brms mapping builds by hand with paste(..., sep = "|").
emm_options(sep = "|")

# FEV data from the mmrm package, using LOCF and then LOCF reversed
# to impute responses. (For this discussion, it is helpful to avoid
# the topic of missingness.)
#
# Steps:
#   1. Compute the change from baseline (FEV1_CHG) and drop raw FEV1.
#   2. Make sure every subject has one row per AVISIT level.
#   3. Copy each subject's *own* covariates into any rows created in step 2.
#   4. Impute FEV1_CHG within subject: carry the last observation forward,
#      then carry the next observation backward for leading gaps.
#   5. Drop rows that are still missing (subjects with no observed response).
data(fev_data, package = "mmrm")
data <- fev_data %>%
  mutate(FEV1_CHG = FEV1 - FEV1_BL, USUBJID = as.character(USUBJID)) %>%
  select(-FEV1) %>%
  group_by(USUBJID) %>%
  complete(AVISIT) %>%
  # Fill subject-level covariates within each subject. Passing a single
  # `fill = as.list(.[1L, ...])` to complete() would instead reuse the first
  # row of the *entire* data set for every subject, because `.` is the whole
  # piped data frame and `fill` is evaluated only once.
  tidyr::fill(ARMCD, FEV1_BL, RACE, SEX, WEIGHT, .direction = "downup") %>%
  ungroup() %>%
  arrange(USUBJID, AVISIT) %>%
  group_by(USUBJID) %>%
  # na.rm = FALSE keeps the vector length unchanged so mutate() can assign it.
  mutate(FEV1_CHG = na.locf(FEV1_CHG, na.rm = FALSE)) %>%
  mutate(FEV1_CHG = na.locf(FEV1_CHG, na.rm = FALSE, fromLast = TRUE)) %>%
  ungroup() %>%
  filter(!is.na(FEV1_CHG))
# Approach 1: raw data summaries. For each ARMCD x AVISIT cell, report the
# sample mean of FEV1_CHG together with a normal-approximation 95% interval,
# i.e. mean -/+ qnorm(0.975) * sd / sqrt(n).
summary_data <- data %>%
  group_by(ARMCD, AVISIT) %>%
  summarize(
    source = "1_data",
    half_width = qnorm(0.975) * sd(FEV1_CHG) / sqrt(n()),
    mean = mean(FEV1_CHG),
    lower = mean - half_width,
    upper = mean + half_width,
    .groups = "drop"
  ) %>%
  # half_width is only a helper column; drop it from the result.
  select(-half_width)

# Formula shared by all the models: change from baseline (FEV1_CHG) modeled
# on baseline FEV1, treatment arm (ARMCD), and visit (AVISIT), including the
# baseline-by-visit and arm-by-visit interactions, plus RACE, SEX, and WEIGHT
# as additive covariates.
formula <- FEV1_CHG ~ FEV1_BL + FEV1_BL:AVISIT + ARMCD + ARMCD:AVISIT +
  AVISIT + RACE + SEX + WEIGHT

# Approach 2: lm() + emmeans.
# Fit the shared formula by ordinary least squares, then ask emmeans for the
# marginal mean of every ARMCD x AVISIT combination. USUBJID, RACE, and SEX
# are declared nuisance factors and averaged with "proportional" weights.
model <- lm(formula = formula, data = data)
summary_lm_emmeans <- emmeans(
  model,
  specs = ~ARMCD:AVISIT,
  nuisance = c("USUBJID", "RACE", "SEX"),
  wt.nuis = "proportional"
) %>%
  as.data.frame() %>%
  as_tibble() %>%
  # Keep the estimate and its confidence limits under the common
  # mean/lower/upper column names used by the other summaries.
  select(ARMCD, AVISIT, mean = emmean, lower = lower.CL, upper = upper.CL) %>%
  mutate(source = "2_lm_emmeans")

# Approach 4: brms + emmeans.
# Fit the same formula with brm(), then pass the fitted brms model to
# emmeans() with exactly the same arguments as the lm() version. For this fit
# the interval columns are named lower.HPD / upper.HPD rather than
# lower.CL / upper.CL.
model <- brm(data = data, formula = brmsformula(formula))
summary_brms_emmeans <- emmeans(
  model,
  specs = ~ARMCD:AVISIT,
  nuisance = c("USUBJID", "RACE", "SEX"),
  wt.nuis = "proportional"
) %>%
  as.data.frame() %>%
  as_tibble() %>%
  # Same mean/lower/upper layout as the other summaries.
  select(ARMCD, AVISIT, mean = emmean, lower = lower.HPD, upper = upper.HPD) %>%
  mutate(source = "4_brms_emmeans")

# Approach 3: custom marginal means from brms draws using a custom mapping
# from model parameters to marginal means. This is what I think
# emmeans *should* be doing to a brms model, based on investigations
# using lm() (cf. https://github.com/openpharma/brms.mmrm/issues/53)

# Column means of the design matrix for `0 + SEX + RACE`, as a 1-row matrix.
# NOTE(review): presumably these are the observed proportions of each SEX and
# RACE level (assuming both are factors coded as indicator columns), i.e. the
# "proportional" weights for the nuisance factors.
proportional_factors <- brmsformula(FEV1_CHG ~ 0 + SEX + RACE) %>%
  make_standata(data = data) %>%
  .subset2("X") %>%
  colMeans() %>%
  t()
# Reference grid: one row per distinct ARMCD x AVISIT combination, with the
# continuous covariates FEV1_BL and WEIGHT held at their overall means.
# FEV1_CHG is set to 0 only as a placeholder response for make_standata().
grid <- data %>%
  mutate(FEV1_BL = mean(FEV1_BL), FEV1_CHG = 0, WEIGHT = mean(WEIGHT)) %>%
  distinct(ARMCD, AVISIT, FEV1_BL, WEIGHT, FEV1_CHG)
# Posterior draws of the columns whose names start with "b_", excluding any
# that start with "b_sigma".
draws_parameters <- model %>%
  as_draws_df() %>%
  as_tibble() %>%
  select(starts_with("b_"), -starts_with("b_sigma"))
# Design matrix of the model terms other than RACE and SEX, evaluated on the
# grid, with the SEX/RACE column means appended to every row. Columns are
# renamed with a "b_" prefix so they can be matched by name against the
# columns of draws_parameters.
mapping <- brmsformula(
    FEV1_CHG ~ FEV1_BL + FEV1_BL:AVISIT + ARMCD + ARMCD:AVISIT + AVISIT + WEIGHT
  ) %>%
  make_standata(data = grid) %>%
  .subset2("X") %>%
  bind_cols(proportional_factors) %>%
  setNames(paste0("b_", colnames(.)))
# Every selected parameter must have a matching column in the mapping.
stopifnot(all(colnames(draws_parameters) %in% colnames(mapping)))
# Put the mapping columns in the same order as the draws; mapping columns
# with no corresponding parameter are dropped here.
mapping <- as.matrix(mapping)[, colnames(draws_parameters)]
# Label each grid row as "ARMCD|AVISIT".
rownames(mapping) <- paste(grid$ARMCD, grid$AVISIT, sep = "|")
# (draws x parameters) %*% (parameters x grid rows): one column of posterior
# draws per marginal mean. %*% and %>% have equal precedence and associate
# left to right, so the product is computed before it is piped onward.
draws_custom <- as.matrix(draws_parameters) %*% t(mapping) %>%
  as.data.frame() %>%
  as_tibble()
# Posterior mean and equal-tailed 95% interval (2.5% and 97.5% quantiles) of
# each marginal mean. separate() is called without `sep`, so it relies on its
# default separator to split the "ARMCD|AVISIT" labels back into two columns.
summary_brms_custom <- draws_custom %>%
  pivot_longer(everything()) %>%
  separate("name", c("ARMCD", "AVISIT")) %>%
  group_by(ARMCD, AVISIT) %>%
  summarize(
    source = "3_brms_custom",
    mean = mean(value),
    lower = quantile(value, 0.025),
    upper = quantile(value, 0.975),
    .groups = "drop"
  )

# Compare results: stack the four summaries and draw, for every
# ARMCD (rows) x AVISIT (columns) panel, each method's point estimate with
# its interval as an error bar.
summary <- list(
  summary_data,
  summary_lm_emmeans,
  summary_brms_custom,
  summary_brms_emmeans
) %>%
  bind_rows()
ggplot(summary, aes(x = source, color = source)) +
  geom_point(aes(y = mean)) +
  geom_errorbar(aes(ymin = lower, ymax = upper)) +
  facet_grid(rows = vars(ARMCD), cols = vars(AVISIT)) +
  theme_gray(base_size = 16) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1)) +
  labs(y = "FEV1_CHG")

Related: unless I am missing something, the X matrix of brms:::emm_basis.brmsfit() should map model parameters to marginal means. (At least emmeans:::emm_basis.lm() works this way.) But when I try it, I get the identity matrix:

# Call emm_basis() on the fitted brms model by hand and inspect the X matrix
# it returns.
# Levels of each categorical predictor, as character vectors.
custom_xlev <- data %>%
  select(AVISIT, ARMCD, SEX, RACE) %>%
  lapply(function(column) unique(as.character(column)))
# NOTE(review): WEIGHT appears in the model formula but not in these terms;
# confirm that this omission is intentional.
custom_trms <- terms(
  x = ~ FEV1_BL + FEV1_BL:AVISIT + ARMCD + ARMCD:AVISIT + AVISIT + RACE + SEX,
  data = data
)
custom_grid <- ref_grid(
  model,
  specs = ~ARMCD:AVISIT,
  nuisance = c("RACE", "SEX")
)
emm_basis(
  object = model,
  trms = custom_trms,
  xlev = custom_xlev,
  grid = custom_grid, # I get the same result if I supply custom_grid@grid.
  resp = "FEV1_CHG"
)$X
#>      [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8]
#> [1,]    1    0    0    0    0    0    0    0
#> [2,]    0    1    0    0    0    0    0    0
#> [3,]    0    0    1    0    0    0    0    0
#> [4,]    0    0    0    1    0    0    0    0
#> [5,]    0    0    0    0    1    0    0    0
#> [6,]    0    0    0    0    0    1    0    0
#> [7,]    0    0    0    0    0    0    1    0
#> [8,]    0    0    0    0    0    0    0    1

Session info:

R version 4.3.2 (2023-10-31)
Platform: aarch64-apple-darwin20 (64-bit)
Running under: macOS Sonoma 14.4.1

Matrix products: default
BLAS:   /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib 
LAPACK: /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.11.0

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

time zone: America/Indiana/Indianapolis
tzcode source: internal

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] zoo_1.8-12      lubridate_1.9.3 forcats_1.0.0   stringr_1.5.1   dplyr_1.1.4     purrr_1.0.2    
 [7] readr_2.1.5     tidyr_1.3.1     tibble_3.2.1    ggplot2_3.4.4   tidyverse_2.0.0 posterior_1.5.0
[13] mmrm_0.3.11     emmeans_1.10.0  coda_0.19-4.1   brms_2.21.0     Rcpp_1.0.12    

loaded via a namespace (and not attached):
 [1] gtable_0.3.4         tensorA_0.36.2.1     QuickJSR_1.1.3       processx_3.8.3      
 [5] inline_0.3.19        lattice_0.22-5       callr_3.7.3          tzdb_0.4.0          
 [9] ps_1.7.6             vctrs_0.6.5          tools_4.3.2          Rdpack_2.6          
[13] generics_0.1.3       stats4_4.3.2         curl_5.2.0           parallel_4.3.2      
[17] sandwich_3.1-0       fansi_1.0.6          pkgconfig_2.0.3      Matrix_1.6-5        
[21] checkmate_2.3.1      distributional_0.4.0 RcppParallel_5.1.7   lifecycle_1.0.4     
[25] farver_2.1.1         compiler_4.3.2       textshaping_0.3.7    Brobdingnag_1.2-9   
[29] munsell_0.5.0        codetools_0.2-19     bayesplot_1.11.0     pillar_1.9.0        
[33] MASS_7.3-60.0.1      StanHeaders_2.32.6   bridgesampling_1.1-2 abind_1.4-5         
[37] multcomp_1.4-25      nlme_3.1-164         rstan_2.32.6         tidyselect_1.2.0    
[41] mvtnorm_1.2-4        stringi_1.8.3        labeling_0.4.3       splines_4.3.2       
[45] grid_4.3.2           colorspace_2.1-0     cli_3.6.2            magrittr_2.0.3      
[49] loo_2.6.0            survival_3.5-7       pkgbuild_1.4.3       utf8_1.2.4          
[53] TH.data_1.1-2        withr_3.0.0          scales_1.3.0         backports_1.4.1     
[57] timechange_0.3.0     estimability_1.4.1   matrixStats_1.2.0    gridExtra_2.3       
[61] ragg_1.2.7           hms_1.1.3            rbibutils_2.2.16     V8_4.4.1            
[65] rstantools_2.4.0     rlang_1.1.3          xtable_1.8-4         glue_1.7.0          
[69] rstudioapi_0.15.0    jsonlite_1.8.8       R6_2.5.1             systemfonts_1.0.5 
2 Likes