Trouble validating a Bayesian MMRM implemented with {brms}

In the openstatsware group (formerly the ASA Biopharmaceutical Software Engineering Working Group), we are developing brms.mmrm (openpharma/brms.mmrm on GitHub), an R package to fit Bayesian mixed models of repeated measures (MMRMs). The package uses brms as its backend, and its goal is to simplify MMRM modeling through an interface tailored to the needs of clinical trial data analysis.

We are trying to validate our implementation using simulation-based calibration (SBC). Initially we used brms::brm(sample_prior = "only") to simulate from the prior, and the rank statistics all looked uniform as we had hoped. But we do not want to stop there because it is arguably circular to simulate using brms and then analyze with brms. For extra assurance, and to convince ourselves that we understand what brms is doing, we are developing custom R code to simulate from the prior. And that is where things are not going smoothly.
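For readers less familiar with SBC, here is a minimal base-R sketch of the idea on a toy conjugate model where the exact posterior is known, so no MCMC is needed. It is not our MMRM: the model, the sample sizes, and the helper name sbc_rank() are all made up for illustration. Each replication draws a true parameter from the prior, simulates data given that parameter, draws from the posterior, and records the rank of the truth among the posterior draws. If simulation and inference agree, the ranks should look uniform.

```r
# Toy SBC for a conjugate model with a known exact posterior (illustrative only):
#   theta ~ Normal(0, 1),  y_i | theta ~ Normal(theta, 1),  i = 1, ..., n_obs
set.seed(0)
n_rep <- 500L    # number of SBC replications
n_obs <- 10L     # observations per simulated dataset
n_draws <- 99L   # posterior draws per replication, so ranks fall in 0..99

sbc_rank <- function() {
  theta <- rnorm(n = 1L, mean = 0, sd = 1)       # true value drawn from the prior
  y <- rnorm(n = n_obs, mean = theta, sd = 1)    # data simulated given the truth
  post_mean <- sum(y) / (n_obs + 1)              # exact conjugate posterior mean
  post_sd <- sqrt(1 / (n_obs + 1))               # exact conjugate posterior sd
  draws <- rnorm(n = n_draws, mean = post_mean, sd = post_sd)
  sum(draws < theta)                             # rank of the truth among the draws
}

ranks <- vapply(seq_len(n_rep), function(index) sbc_rank(), integer(1L))

# Bin the ranks into 10 equal-width bins; the counts should be roughly equal.
counts <- table(cut(ranks, breaks = seq(0, n_draws + 1, by = 10), right = FALSE))
print(counts)
```

A mismatch between the simulator and the fitted model (for example, parameters paired with the wrong rows of data) would show up as clearly non-uniform bin counts.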

The reprex below runs an MMRM tailored to clinical trial data. We have independent patients randomized to 3 treatment groups, and each patient is observed at 4 discrete time points (scheduled study visits). The model formula uses treatment group levels as fixed effects (with no intercept), and the residuals within each patient have an unstructured (fully parameterized) correlation matrix. The covariance is modeled with a separation strategy: time-specific standard deviations on the log scale, combined with a correlation matrix that has an LKJ prior.

brmsformula(
  formula = response ~ 0 + group + unstr(time = time, gr = patient),
  sigma ~ 0 + time
)
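To spell out what the separation strategy means here, the sketch below builds a within-patient covariance matrix from time-specific standard deviations and a correlation matrix, the same way the simulation code further down does. The numeric values are arbitrary placeholders, not estimates from any real model.

```r
# Separation strategy: covariance = diag(sigma) %*% correlation %*% diag(sigma).
# All numbers below are arbitrary and only for illustration.
b_sigma <- c(0.2, -0.1, 0.4, 0)   # log-scale standard deviation per time point
sigma <- exp(b_sigma)             # standard deviations on the natural scale

correlation <- matrix(0.3, nrow = 4L, ncol = 4L)
diag(correlation) <- 1            # a valid correlation matrix for the example

covariance <- diag(sigma) %*% correlation %*% diag(sigma)

# The decomposition is recoverable: the diagonal gives the variances,
# and rescaling the covariance gives back the correlation matrix.
stopifnot(isTRUE(all.equal(sqrt(diag(covariance)), sigma)))
stopifnot(isTRUE(all.equal(cov2cor(covariance), correlation)))
```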

Oddly enough, the intercept-only equivalent appears well calibrated (response ~ unstr(time = time, gr = patient)), and so does a model with independent residuals (response ~ 0 + group). That alone makes me think the culprit is some sort of non-identifiability, but to me all the parameters seem like they should be easy to estimate.

Anyway, here is the code. The simulations unfortunately take a long time to run, especially for a reprex. I am using brms 2.20.4, rstan 2.32.3, and StanHeaders 2.26.28. See the very bottom for the rank statistics histograms and session info.

library(brms)
library(dplyr)
library(tibble)
library(tidyr)

#############
# FUNCTIONS #
#############

one_sbc_replication <- function(
  chains = 4L,
  warmup = 2000L,
  iter = 4000L
) {
  prior <- set_prior("lkj_corr_cholesky(1)", class = "Lcortime") +
     set_prior("normal(0, 1)", class = "b", coef = "groupgroup_1") +
     set_prior("normal(0, 1)", class = "b", coef = "groupgroup_2") +
     set_prior("normal(0, 1)", class = "b", coef = "groupgroup_3") +
     set_prior("normal(0, 1)", class = "b", dpar = "sigma", coef = "timetime_1") +
     set_prior("normal(0, 1)", class = "b", dpar = "sigma", coef = "timetime_2") +
     set_prior("normal(0, 1)", class = "b", dpar = "sigma", coef = "timetime_3") +
     set_prior("normal(0, 1)", class = "b", dpar = "sigma", coef = "timetime_4")
  formula <- brmsformula(
    formula = response ~ 0 + group + unstr(time = time, gr = patient),
    sigma ~ 0 + time
  )
  simulation <- simulate_data(formula = formula, prior = prior)
  options(brms.backend = "rstan")
  model <- brm(
    data = simulation$data,
    formula = formula,
    prior = prior,
    chains = chains,
    cores = chains,
    iter = iter,
    warmup = warmup
  )
  get_sbc_ranks(model, simulation)
}

simulate_data <- function(formula, prior) {
  n_group <- 3L
  n_patient <- 100L
  n_time <- 4L
  patients <- tibble(
    group = paste0("group_", rep(seq_len(n_group), each = n_patient)),
    patient = paste0("patient_", seq_len(n_group * n_patient))
  )
  data <- expand_grid(patients, time = paste0("time_", seq_len(n_time)))
  data$response <- 0
  x <- make_standata(formula, data, prior = prior)$X
  beta <- rnorm(n = n_group, mean = 0, sd = 1)
  names(beta) <- paste0("b_", colnames(x))
  b_sigma <- rnorm(n = n_time, mean = 0, sd = 1)
  names(b_sigma) <- paste0("b_sigma_timetime_", seq_len(n_time))
  sigma <- exp(b_sigma)
  correlation <- trialr::rlkjcorr(n = 1L, K = n_time, eta = 1)
  i <- rep(seq_len(n_time), each = n_time)
  j <- rep(seq_len(n_time), times = n_time)
  cortime <- as.numeric(correlation)[j > i]
  names(cortime) <- sprintf("cortime__time_%s__time_%s", i[j > i], j[j > i])
  covariance <- diag(sigma) %*% correlation %*% diag(sigma)
  data <- data |>
    mutate(mu = as.numeric(x %*% beta)) |>
    mutate(index_patient = rep(seq_len(n_patient * n_group), each = n_time)) |>
    group_by(index_patient) |>
    mutate(response = MASS::mvrnorm(mu = mu, Sigma = covariance)) |>
    ungroup() |>
    select(-index_patient, -mu)
  parameters <- c(beta, b_sigma, cortime)
  stopifnot(!anyDuplicated(names(parameters)))
  list(data = data, parameters = parameters)
}

get_sbc_ranks <- function(model, simulation) {
  draws <- posterior::as_draws_matrix(model)
  draws <- draws[, setdiff(colnames(draws), c("lprior", "lp__"))]
  truth <- simulation$parameters
  stopifnot(all(sort(names(truth)) == sort(colnames(draws))))
  draws <- draws[, names(truth)]
  ranks <- SBC::calculate_ranks_draws_matrix(variables = truth, dm = draws)
  tibble::as_tibble(as.list(ranks))
}

##############
# SIMULATION #
##############

# I used an SGE cluster, so this is how I set up the crew controller:
controller <- crew.cluster::crew_controller_sge(
  name = "brms-mmrm-sbc",
  workers = 100L,
  seconds_idle = 30,
  seconds_launch = 1800,
  launch_max = 3L,
  script_lines = "module load R/4.2.2",
  sge_cores = 4L
)

# But if you have different resources, you may want to choose
# a different crew launcher plugin, e.g.:
# controller <- crew::crew_controller_local()

# Run the simulations:
controller$start()
tasks <- controller$map(
  command = one_sbc_replication(chains = 4L, warmup = 2000L, iter = 4000L),
  iterate = list(index = seq_len(100L)),
  globals = as.list(globalenv()),
  packages = c("brms", "dplyr", "tibble", "tidyr")
)
controller$terminate()
simulations <- bind_rows(tasks$result)

###########
# RESULTS #
###########

library(tidyr)
results <- pivot_longer(
  simulations,
  cols = everything(),
  names_to = "parameter",
  values_to = "rank"
)

library(ggplot2)
plot <- ggplot(results) +
  geom_histogram(
    aes(x = rank),
    breaks = seq(from = 0, to = max(results$rank), length.out = 10)
  ) +
  facet_wrap(~parameter) +
  theme_gray(16)
ggsave("plot.png", plot, width = 12, height = 10)

The rank statistic histograms in the plot below are far from uniform.

Session info:

R version 4.2.2 Patched (2022-11-30 r83413)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Red Hat Enterprise Linux

Matrix products: default
BLAS/LAPACK: /CENSORED/intel-2020/compilers_and_libraries_2020.0.166/linux/mkl/lib/intel64_lin/libmkl_gf_lp64.so

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] ggplot2_3.4.0 tidyr_1.2.1   tibble_3.2.1  dplyr_1.1.3   brms_2.20.4  
[6] Rcpp_1.0.9   

loaded via a namespace (and not attached):
  [1] nlme_3.1-160         matrixStats_0.63.0   xts_0.12.2          
  [4] threejs_0.3.3        rstan_2.32.3         tensorA_0.36.2      
  [7] tools_4.2.2          backports_1.4.1      utf8_1.2.3          
 [10] R6_2.5.1             DT_0.26              ggdist_3.2.0        
 [13] colorspace_2.0-3     withr_2.5.1          tidyselect_1.2.0    
 [16] gridExtra_2.3        prettyunits_1.1.1    processx_3.8.2      
 [19] Brobdingnag_1.2-9    emmeans_1.8.8        curl_4.3.3          
 [22] compiler_4.2.2       cli_3.6.1            binom_1.1-1.1       
 [25] arrayhelpers_1.1-0   shinyjs_2.1.0        sandwich_3.0-2      
 [28] colourpicker_1.2.0   posterior_1.4.1      scales_1.2.1        
 [31] dygraphs_1.1.1.6     checkmate_2.1.0      mvtnorm_1.1-3       
 [34] callr_3.7.3          QuickJSR_1.0.5       stringr_1.5.0       
 [37] digest_0.6.33        StanHeaders_2.26.28  base64enc_0.1-3     
 [40] pkgconfig_2.0.3      htmltools_0.5.4      fastmap_1.1.0       
 [43] htmlwidgets_1.5.4    rlang_1.1.1          shiny_1.7.3         
 [46] svUnit_1.0.6         farver_2.1.1         generics_0.1.3      
 [49] zoo_1.8-11           jsonlite_1.8.4       crosstalk_1.2.0     
 [52] gtools_3.9.4         distributional_0.3.1 inline_0.3.19       
 [55] magrittr_2.0.3       loo_2.5.1            bayesplot_1.10.0    
 [58] Matrix_1.5-3         munsell_0.5.0        fansi_1.0.5         
 [61] abind_1.4-5          lifecycle_1.0.3      stringi_1.7.8       
 [64] multcomp_1.4-20      MASS_7.3-58.1        pkgbuild_1.4.0      
 [67] plyr_1.8.8           grid_4.2.2           parallel_4.2.2      
 [70] promises_1.2.0.1     crayon_1.5.2         miniUI_0.1.1.1      
 [73] lattice_0.20-45      splines_4.2.2        SBC_0.2.0.9000      
 [76] ps_1.7.5             pillar_1.9.0         trialr_0.1.5        
 [79] igraph_1.5.1         markdown_1.7         estimability_1.4.1  
 [82] shinystan_2.6.0      reshape2_1.4.4       codetools_0.2-18    
 [85] stats4_4.2.2         rstantools_2.3.1.1   glue_1.6.2          
 [88] V8_4.2.2             RcppParallel_5.1.5   vctrs_0.6.4         
 [91] httpuv_1.6.6         purrr_1.0.2          gtable_0.3.1        
 [94] cachem_1.0.6         mime_0.12            xtable_1.8-4        
 [97] coda_0.19-4          later_1.3.0          survival_3.4-0      
[100] shinythemes_1.2.0    memoise_2.0.1        tidybayes_3.0.2     
[103] TH.data_1.1-1        ellipsis_0.3.2       bridgesampling_1.1-2

I’ve been working on an R package that can help with this kind of thing: StanEstimators. You could narrow down the issue and the points of difference by comparing a known-good R implementation of the likelihood with the brms results, without having to use a custom sampler or simulation.


StanEstimators looks interesting. I was about to learn more about how to use it, then your comment made me realize I should probably interrogate a single dataset at a time. I picked a seed which gave terrible rank statistics in the simulation, and I plotted the marginal posteriors (red) against the true parameters (blue). It looks like the treatment group labels in the data are getting switched around. I see this same pattern for multiple seeds. Is there a reason brms might reorder character labels? Should I use factors instead of character vectors for discrete variables?

Here is a reprex. This one only takes a couple minutes to run.

library(brms)
library(dplyr)
library(ggplot2)
library(posterior)
library(tibble)
library(tidyr)

# Define the model.
prior <- set_prior("lkj_corr_cholesky(1)", class = "Lcortime") +
  set_prior("normal(0, 1)", class = "b", coef = "groupgroup_1") +
  set_prior("normal(0, 1)", class = "b", coef = "groupgroup_2") +
  set_prior("normal(0, 1)", class = "b", coef = "groupgroup_3") +
  set_prior("normal(0, 1)", class = "b", dpar = "sigma", coef = "timetime_1") +
  set_prior("normal(0, 1)", class = "b", dpar = "sigma", coef = "timetime_2") +
  set_prior("normal(0, 1)", class = "b", dpar = "sigma", coef = "timetime_3") +
  set_prior("normal(0, 1)", class = "b", dpar = "sigma", coef = "timetime_4")
formula <- brmsformula(
  formula = response ~ 0 + group + unstr(time = time, gr = patient),
  sigma ~ 0 + time
)

# Simulate a dataset from the prior.
set.seed(seed = 8L, kind = "Mersenne-Twister")
n_group <- 3L
n_patient <- 100L
n_time <- 4L
patients <- tibble(
  group = paste0("group_", rep(seq_len(n_group), each = n_patient)),
  patient = paste0("patient_", seq_len(n_group * n_patient))
)
data <- expand_grid(patients, time = paste0("time_", seq_len(n_time)))
data$response <- 0
x <- make_standata(formula, data, prior = prior)$X
beta <- rnorm(n = n_group, mean = 0, sd = 1)
names(beta) <- paste0("b_", colnames(x))
b_sigma <- rnorm(n = n_time, mean = 0, sd = 1)
names(b_sigma) <- paste0("b_sigma_timetime_", seq_len(n_time))
sigma <- exp(b_sigma)
correlation <- trialr::rlkjcorr(n = 1L, K = n_time, eta = 1)
i <- rep(seq_len(n_time), each = n_time)
j <- rep(seq_len(n_time), times = n_time)
cortime <- as.numeric(correlation)[j > i]
names(cortime) <- sprintf("cortime__time_%s__time_%s", i[j > i], j[j > i])
covariance <- diag(sigma) %*% correlation %*% diag(sigma)
data <- data |>
  mutate(mu = as.numeric(x %*% beta)) |>
  mutate(index_patient = rep(seq_len(n_patient * n_group), each = n_time)) |>
  group_by(index_patient) |>
  mutate(response = MASS::mvrnorm(mu = mu, Sigma = covariance)) |>
  ungroup() |>
  select(-index_patient, -mu)
parameters <- c(beta, b_sigma, cortime)

# Run the model.
options(brms.backend = "rstan")
model <- brm(
  data = data,
  formula = formula,
  prior = prior,
  seed = 1L,
  chains = 4L,
  iter = 4000L,
  warmup = 2000L,
  refresh = 10L
)

# Visualize the fixed effect marginal posteriors against the data.
summary_model <- summarize_draws(model)
summary_fixed_model <- summary_model |>
  select(variable, mean, q5, q95) |>
  filter(grepl("group", variable))
z <- qnorm(p = 0.95) # Normal-approximation 5% and 95% points, to match q5 and q95.
summary_fixed_data <- data |>
  group_by(group) |>
  summarize(
    mean = mean(response),
    q5 = mean - z * sd(response) / sqrt(n()),
    q95 = mean + z * sd(response) / sqrt(n()),
    .groups = "drop"
  ) |>
  rename(variable = group) |>
  mutate(variable = paste0("b_group", variable))
summary_fixed <- dplyr::bind_rows(
  model = summary_fixed_model,
  data = summary_fixed_data,
  .id = "source"
)
summary_parameters <- tibble(
  variable = names(parameters),
  value = unname(parameters)
)
summary_parameters_fixed <- summary_parameters |>
  filter(grepl("group", variable))

ggplot(summary_fixed_data) +
  geom_point(
    aes(x = variable, y = mean),
    color = "red",
    position = position_dodge(width = 0.5)
  ) +
  geom_errorbar(
    aes(x = variable, ymin = q5, ymax = q95),
    color = "red",
    position = position_dodge(width = 0.5)
  ) +
  geom_point(
    data = summary_parameters_fixed,
    mapping = aes(x = variable, y = value),
    color = "blue",
    position = position_dodge(width = 0.5)
  )

So it looks like brms is reordering the rows of the model matrix! The first 20 rows of the data are all from group 1, but the brms model matrix switches to group 2 at row 13. Any tips on why brms reorders the rows and how to suppress this behavior?

Data:

> head(data, n = 20)
# A tibble: 20 × 4
   group   patient   time   response
   <chr>   <chr>     <chr>     <dbl>
 1 group_1 patient_1 time_1    0.663
 2 group_1 patient_1 time_2   -2.99 
 3 group_1 patient_1 time_3   -2.01 
 4 group_1 patient_1 time_4   -0.225
 5 group_1 patient_2 time_1    0.337
 6 group_1 patient_2 time_2   -2.81 
 7 group_1 patient_2 time_3   -0.687
 8 group_1 patient_2 time_4   -0.733
 9 group_1 patient_3 time_1    0.299
10 group_1 patient_3 time_2   -4.15 
11 group_1 patient_3 time_3   -0.370
12 group_1 patient_3 time_4   -0.211
13 group_1 patient_4 time_1    0.795
14 group_1 patient_4 time_2   -0.621
15 group_1 patient_4 time_3   -0.234
16 group_1 patient_4 time_4    2.36 
17 group_1 patient_5 time_1    0.270
18 group_1 patient_5 time_2    4.09 
19 group_1 patient_5 time_3    1.20 
20 group_1 patient_5 time_4    2.37 

brms model matrix:

> head(make_standata(formula = formula, data = data, prior = prior)$X, n = 20)
    groupgroup_1 groupgroup_2 groupgroup_3
1              1            0            0
2              1            0            0
3              1            0            0
4              1            0            0
37             1            0            0
38             1            0            0
39             1            0            0
40             1            0            0
397            1            0            0
398            1            0            0
399            1            0            0
400            1            0            0
401            0            1            0
402            0            1            0
403            0            1            0
404            0            1            0
405            0            1            0
406            0            1            0
407            0            1            0
408            0            1            0
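The row names in the output above (1-4, 37-40, 397-400, 401-408) are consistent with rows being sorted alphabetically by the character patient label and then by time: "patient_1" is followed by "patient_10" (rows 37-40), "patient_100" (rows 397-400), and "patient_101" (rows 401-404, the first patient in group 2). That is an inference from the row names, not a confirmed description of brms internals. The base-R sketch below reproduces the same row order with a plain alphabetical sort on the labels used in the reprex.

```r
# Rebuild the patient and time labels from the reprex (300 patients, 4 visits).
n_patient_total <- 300L
n_time <- 4L
patient <- rep(paste0("patient_", seq_len(n_patient_total)), each = n_time)
time <- rep(paste0("time_", seq_len(n_time)), times = n_patient_total)

# Alphabetical ordering by patient, then time. method = "radix" sorts
# character vectors in the C locale, so the result does not depend on
# the session's locale settings.
row_order <- order(patient, time, method = "radix")
print(head(row_order, n = 20))
# → 1 2 3 4 37 38 39 40 397 398 399 400 401 402 403 404 405 406 407 408
```

If that reading is right, zero-padded labels (e.g. "patient_001") or a factor with levels in the intended order would be worth trying, though I have not verified how brms treats either.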

I guess sorting the rows isn’t a problem for normal usage, but I do use make_standata() to simulate data for the simulation study, so it surprised me there. I modified the code to create a model matrix that respects the original row order:

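# Recover the row permutation applied by make_standata() by matching on
# the response values. This only works if the responses are all distinct.
stan_data <- make_standata(formula, data, prior = prior)
stopifnot(!anyDuplicated(data$response))
undo_brms_permutation <- match(data$response, stan_data$Y)
stopifnot(all(stan_data$Y[undo_brms_permutation] == data$response))
# Reorder the model matrix back to the original row order of the data.
model_matrix <- stan_data$X[undo_brms_permutation, ]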

And then I got results that look much more reasonable:

I think this should fix the original SBC study.
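One caveat on the match() trick, shown as a small base-R sketch with made-up data: it recovers the permutation only when the matching key is unique per row. The simulation code earlier sets the placeholder response to 0 in every row before calling make_standata(), so for that use the key would need to be something distinct, such as a row index. The data frame, the permutation, and the column names below are hypothetical stand-ins.

```r
# A small data frame with a unique numeric key per row (hypothetical example).
original <- data.frame(
  group = rep(c("a", "b", "c"), each = 2L),
  key = as.numeric(seq_len(6L))
)

# Stand-in for a reordering done by some other function.
permutation <- c(1L, 2L, 5L, 6L, 3L, 4L)
reordered <- original[permutation, ]

# match() finds, for each original row, its position in the reordered data.
undo <- match(original$key, reordered$key)
stopifnot(identical(reordered$group[undo], original$group))

# With a constant key, match() always returns the first position,
# so the permutation cannot be recovered.
constant_key <- rep(0, 6L)
print(match(constant_key, constant_key))
# → 1 1 1 1 1 1
```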