Can you specify burnin per chain after sampling in brms? I am running an analysis where some chains take longer than the warmup to reach convergence. I’d like to select the samples from when they do and onwards, i.e. a non-equal number of samples from different chains. The model is quite computationally heavy (around 1 minute per sample per chain), so I’d really like to not just set a longer warmup and skip “good sampling time”.
Here is a caterpillar plot of 6 chains. It is clear when each individual chain reaches the convergence-region:
Windows 10, brms 2.6.3
I managed to solve this problem using the function I wrote below.
Before processing, the sigma traceplot looked like this:
The resulting traceplot looks like this:
The strategy was rather simple: I merged the “good” samples (after convergence) from the chains into one chain and tricked stan/brms into thinking that this fit was a one-chain sampling to begin with. stan and brms cannot handle chains of differing length, so this is necessary.
Here is a function to do it. It’s for brms objects, but since brm_fit$fit is the stanfit, only minor adaptations are needed to make it work with “raw” stan.
library(brms)
library(tidyverse)
# the function!
remove_unconverged = function(brm_fit, convergence_start) {
  # brm_fit is the output of brms::brm
  # convergence_start is a vector of length [N_chains] with the iteration number to skip until
  sim = brm_fit$fit@sim  # Handy shortcut
  stopifnot(length(convergence_start) == sim$chains)

  # For each chain, cut away the first N samples...
  for(chain in seq_len(sim$chains)) {
    # check.names = FALSE keeps parameter names such as "b[1]" unchanged
    sim$samples[[chain]] = data.frame(sim$samples[[chain]], check.names = FALSE) %>%
      slice((sim$warmup + convergence_start[chain]):n())  # Also remove warmup
  }

  # Merge all samples into one chain
  sim$samples = list(bind_rows(sim$samples))

  # Update the meta-info
  sim$warmup = 0
  sim$warmup2 = rep(0, sim$chains)  # 0 warmup for each chain
  sim$chains = 1
  sim$iter = nrow(sim$samples[[1]])  # Total number of "accepted" iterations
  sim$n_save = sim$iter
  stan_args = brm_fit$fit@stan_args[[1]]  # Handy shortcut
  stan_args$warmup = 0
  stan_args$iter = sim$iter

  # Add the modified sim back to brm_fit
  brm_fit$fit@sim = sim
  brm_fit$fit@stan_args = list(stan_args)
  brm_fit
}
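To make the trimming and merging step easier to follow without a fitted model, here is a minimal base R sketch on simulated data. Everything in it is hypothetical: make_chain, the drift pattern and the cut-off points are made up for illustration, and the chains are plain numeric vectors rather than the stanfit internals the function above works on.

```r
# Toy illustration in base R; no brms or Stan needed.
set.seed(1)
n_iter <- 500

# Simulated "chains" (hypothetical data): each one drifts towards its
# stationary region around 0 and reaches it at a different iteration.
make_chain <- function(n_iter, settle_at) {
  drift <- pmax(0, 5 * (1 - seq_len(n_iter) / settle_at))
  drift + rnorm(n_iter)
}
chains <- list(make_chain(n_iter, 200), make_chain(n_iter, 100))

# First iteration to keep in each chain (assumed, chosen by eye).
convergence_start <- c(200, 100)

# Drop the leading draws of each chain, then concatenate what is left.
trimmed <- Map(function(draws, start) draws[start:length(draws)],
               chains, convergence_start)
merged <- unlist(trimmed)

lengths(trimmed)  # 301 401: draws kept per chain
length(merged)    # 702: total draws in the merged pseudo-chain
```

The merged vector has unequal contributions from the two chains, which is exactly why the real function has to present the result as a single chain.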
And here’s the example that generated the plots above:
x = brm(file = 'fit_rt_exp_full') # Load the saved data
plot(x, pars='sigma') # Plot it to determine where each chain reaches convergence
x = remove_unconverged(x, c(200, 100)) # Remove unwanted initial samples
plot(x, pars='sigma') # See the result
I wouldn’t do this. For starters, judging convergence by looking at traceplots is precarious. If you merge chains together while they are still in the warmup phase, they are not valid Markov chains. If you merge chains together after the warmup phase, then you are throwing away draws. If you merge the chains together right as they transition from the warmup phase to the post-warmup phase, then you have the default behavior of rstan, except that you have messed up its good convergence diagnostics.
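As a rough illustration of a numeric check that does not rely on eyeballing a traceplot, here is a simplified split R-hat for a single parameter in base R. This is a sketch under stated assumptions, not rstan's own implementation: split_rhat is a hypothetical helper, the draws are simulated, and the input is assumed to be an iterations-by-chains matrix.

```r
# Simplified split R-hat for one parameter.
# `draws` is assumed to be an iterations x chains numeric matrix.
split_rhat <- function(draws) {
  n_half <- floor(nrow(draws) / 2)
  # Split every chain into a first and a second half, so that a trend
  # inside a chain shows up as disagreement between the halves.
  halves <- cbind(draws[seq_len(n_half), , drop = FALSE],
                  draws[n_half + seq_len(n_half), , drop = FALSE])
  W <- mean(apply(halves, 2, var))      # mean within-chain variance
  B <- n_half * var(colMeans(halves))   # between-chain variance
  var_plus <- (n_half - 1) / n_half * W + B / n_half
  sqrt(var_plus / W)
}

set.seed(1)
mixed <- matrix(rnorm(4000), ncol = 4)  # four chains on the same target
stuck <- mixed
stuck[, 4] <- stuck[, 4] + 3            # one chain sits somewhere else

split_rhat(mixed)  # close to 1
split_rhat(stuck)  # clearly above 1
```

Once chains of different lengths are merged into one, a between-chain comparison like this is no longer possible on the merged object, which is the diagnostic loss referred to above.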
Thanks for the clarification. remove_unconverged throws away draws in the post-warmup phase. I agree that judging convergence can be dodgy for drifting chains or with only 2-3 chains where you have little knowledge about which one is the “right” one, such as chain 2 above.
However, my problem at hand is the sampling below. This took ~80 hours to run after warmup, and I feel pretty confident identifying when chains 5 and 6 began sampling probable values (conditioned on the rest of the model).
You mentioned it took ~80 hours. Once you used this merging technique, what was the run time?
This is post-sampling, so it doesn’t affect run time. But I avoided having to delete every iteration before the point where all chains had converged, around iteration 250. In other words, I kept 11% more samples than I otherwise would have, which corresponds to discarding 5 fewer hours of computation.
In the same vein, here is a function to drop certain chains:
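remove_chains = function(brm_fit, chains_to_drop) {
  # brm_fit is the output of brms::brm
  # chains_to_drop is a vector of chain numbers to remove
  sim = brm_fit$fit@sim  # Handy shortcut
  sim$samples <- sim$samples[-chains_to_drop]

  # Update the meta-info (every per-chain entry has to shrink as well)
  sim$chains = sim$chains - length(chains_to_drop)
  sim$warmup2 = sim$warmup2[-chains_to_drop]
  sim$n_save = sim$n_save[-chains_to_drop]
  sim$permutation = sim$permutation[-chains_to_drop]

  # Add the modified sim back to brm_fit
  brm_fit$fit@sim = sim
  brm_fit
}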
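For anyone who wants to see the bookkeeping without a fitted model at hand, here is a small sketch that applies the same idea to a mock list. It is only an illustration: drop_from_sim and mock_sim are hypothetical, the numbers are made up, and the mock carries only a few of the fields a real stanfit sim list has (it assumes per-chain warmup2 and n_save vectors).

```r
# Hypothetical helper: the same bookkeeping as remove_chains, but on a
# plain list instead of the @sim slot of a stanfit.
drop_from_sim <- function(sim, chains_to_drop) {
  keep <- setdiff(seq_len(sim$chains), chains_to_drop)
  sim$samples <- sim$samples[keep]
  sim$warmup2 <- sim$warmup2[keep]
  sim$n_save <- sim$n_save[keep]
  sim$chains <- length(keep)
  sim
}

# Mock object with 6 chains of 10 draws each (made-up numbers, not a real fit).
set.seed(1)
mock_sim <- list(
  samples = replicate(6, list(sigma = rnorm(10)), simplify = FALSE),
  chains = 6,
  warmup2 = rep(0, 6),
  n_save = rep(10, 6)
)

reduced <- drop_from_sim(mock_sim, c(5, 6))
reduced$chains           # 4
length(reduced$samples)  # 4
```

With a real fit, the call would look like x = remove_chains(x, c(5, 6)), where x is the brmsfit from the example earlier in the thread.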