Select chains and iterations post-sampling in brms


#1

Can you specify burn-in per chain after sampling in brms? I am running an analysis where some chains take longer than the warmup to reach convergence. I’d like to keep the samples from the point where each chain converges onwards, i.e. a different number of samples from each chain. The model is quite computationally heavy (around 1 minute per sample per chain), so I’d really rather not just set a longer warmup and waste “good sampling time”.

Here is a caterpillar plot of 6 chains. It is clear when each individual chain reaches the region of convergence:

[Figure: caterpillar plot of the 6 chains]

Windows 10, brms 2.6.3


#2

No


#3

I managed to solve this problem using the function I wrote below.

Before processing, the sigma traceplot looked like this:
[Figure: sigma traceplot before processing]

The resulting traceplot looks like this:
[Figure: sigma traceplot after processing]

The strategy was rather simple: I merged the “good” samples (after convergence) from all chains into one chain and tricked Stan/brms into thinking that the fit was a single-chain run to begin with. Stan and brms cannot handle chains of differing length, so this step is necessary.
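The merge step itself can be sketched in base R on a plain draws array, without touching the stanfit internals. This is a minimal sketch under stated assumptions: `merge_converged` is a hypothetical helper name, the draws are assumed to be shaped iterations × chains × parameters (the layout `as.array()` returns for a stanfit, post-warmup only), and the simulated draws stand in for a real fit.

```r
# Hypothetical helper (not part of brms or rstan): drop a chain-specific
# number of initial draws and stack the remainder into one matrix.
# `draws` is assumed to be an [iterations, chains, parameters] array,
# e.g. the post-warmup array returned by as.array() on a stanfit.
# `first_kept[c]` is the first iteration of chain c to keep.
merge_converged = function(draws, first_kept) {
  n_iter = dim(draws)[1]
  n_chains = dim(draws)[2]
  n_par = dim(draws)[3]
  stopifnot(length(first_kept) == n_chains,
            all(first_kept >= 1), all(first_kept <= n_iter))
  kept = lapply(seq_len(n_chains), function(ch) {
    # matrix() restores the iterations x parameters shape even when
    # subsetting dropped a length-1 dimension
    matrix(draws[first_kept[ch]:n_iter, ch, ], ncol = n_par)
  })
  merged = do.call(rbind, kept)  # chains are stacked in order 1, 2, ...
  colnames(merged) = dimnames(draws)[[3]]
  merged
}

# Simulated stand-in for a real fit: 500 iterations, 6 chains, 2 parameters
set.seed(1)
draws = array(rnorm(500 * 6 * 2), dim = c(500, 6, 2),
              dimnames = list(NULL, NULL, c("b_Intercept", "sigma")))
merged = merge_converged(draws, first_kept = c(1, 1, 1, 1, 200, 250))
dim(merged)  # 2552 2  (4 * 500 + 301 + 251 rows)
```

Because this works on a copy of the draws, the original fit and its per-chain diagnostics stay intact. The trade-off is that summaries have to be computed from `merged` directly, and the stacked draws should not be treated as one long Markov chain (e.g. for effective sample size).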

Here is a function to do it. It’s written for brms objects, but since brm_fit$fit is the underlying stanfit, only minor adaptations are needed to make it work with raw Stan (rstan) fits.

library(brms)
library(tidyverse)

# the function!
remove_unconverged = function(brm_fit, convergence_start) {
  # brm_fit is the output of brms::brm
  # convergence_start is a vector of length [N_chains]: for each chain, the first post-warmup iteration to keep
  
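  # For each chain, cut away the saved warmup and the first post-warmup draws...
  sim = brm_fit$fit@sim  # Handy shortcut
  for(chain in 1:sim$chains) {
    # warmup2 is the number of *saved* warmup draws per chain (0 if save_warmup = FALSE)
    first_kept = sim$warmup2[chain] + convergence_start[chain]
    # check.names = FALSE keeps names such as "r_subject[1,Intercept]" unmangled
    sim$samples[[chain]] = data.frame(sim$samples[[chain]], check.names = FALSE) %>%
      slice(first_kept:n())
  }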
  
  # Merge all samples into one chain
  # (per-chain attributes such as sampler_params, used for divergence checks, are not carried over)
  sim$samples = list(bind_rows(sim$samples))
  
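  # Update the meta-info (set chains first so the per-chain fields below have length 1)
  sim$chains = 1
  sim$warmup = 0
  sim$warmup2 = 0  # 0 saved warmup draws for the single merged chain
  sim$thin = 1  # The merged draws are treated as unthinned
  sim$iter = nrow(sim$samples[[1]])  # Total number of kept iterations
  sim$n_save = sim$iter
  sim$permutation = list(sample.int(sim$iter))  # One random permutation per chain, used by rstan::extract()
  
  stan_args = brm_fit$fit@stan_args[[1]]  # Handy shortcut
  stan_args$warmup = 0
  stan_args$thin = 1
  stan_args$iter = sim$iter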
  
  # Put the modified sim and stan_args back into the fit
  brm_fit$fit@sim = sim
  brm_fit$fit@stan_args = list(stan_args)
  
  brm_fit
}

And here’s the example that generated the plots above:

x = brm(file = 'fit_rt_exp_full')  # Load the saved fit (brm() returns the model stored in fit_rt_exp_full.rds if that file exists)
plot(x, pars='sigma')  # Plot it to determine where each chain reaches convergence
x = remove_unconverged(x, c(200, 100))  # Remove unwanted initial samples
plot(x, pars='sigma')  # See the result

#4

I wouldn’t do this. For starters, judging convergence by looking at traceplots is precarious. If you merge chains while they are still in the warmup phase, the result is not a valid Markov chain. If you merge chains after the warmup phase, you are throwing away draws. If you merge chains right as they transition from the warmup phase to the post-warmup phase, you get rstan’s default behavior, except that you have undermined its convergence diagnostics, which rely on comparing separate chains.
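To illustrate why separate chains matter for the diagnostics, here is a minimal base-R sketch of the classic (non-split) Gelman-Rubin R-hat, which compares between-chain and within-chain variance. It is a simplified stand-in for rstan’s own diagnostic (which uses split chains and, in newer versions, rank normalization), `rhat_classic` is a hypothetical name, and the chains are simulated, made-up data.

```r
# Classic (non-split) Gelman-Rubin R-hat for one parameter.
# `x` is an iterations x chains matrix of draws.
rhat_classic = function(x) {
  n = nrow(x)
  m = ncol(x)
  stopifnot(n >= 2, m >= 2)  # needs at least two chains to compare
  chain_means = colMeans(x)
  B = n / (m - 1) * sum((chain_means - mean(chain_means))^2)  # between-chain variance
  W = mean(apply(x, 2, var))                                   # mean within-chain variance
  var_plus = (n - 1) / n * W + B / n                          # pooled variance estimate
  sqrt(var_plus / W)
}

# Simulated chains (made-up data): six chains targeting the same N(0, 1)
set.seed(42)
mixed = matrix(rnorm(1000 * 6), nrow = 1000, ncol = 6)
# Same draws, but chain 6 sits in a different region
stuck = mixed
stuck[, 6] = stuck[, 6] + 3

rhat_classic(mixed)  # close to 1
rhat_classic(stuck)  # clearly above 1.1
```

Concatenating the columns into a single chain removes the between-chain term B, which is exactly the signal that flags the stuck chain in this example.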


#5

Thanks for the clarification. remove_unconverged throws away draws in the post-warmup phase. I agree that judging convergence can be dodgy for drifting chains or with only 2-3 chains where you have little knowledge about which one is the “right” one, such as chain 2 above.

However, my problem at hand is the sampling below. It took ~80 hours to run after warmup, and I feel pretty confident identifying when chains 5 and 6 began sampling probable values (conditioned on the rest of the model).


#6

You mentioned it took ~80 hours. Once you used this merging technique, what was the run time?


#7

This is done post-sampling, so it doesn’t affect run time. But it spared me from deleting, in every chain, all iterations before the point where the last chain converged (around iteration 250). In other words, I kept 11% more samples than I otherwise would have, which corresponds to throwing away about 5 hours less computation.
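The saving comes from applying a separate cutoff to each chain instead of one common cutoff at the latest convergence point. Here is a minimal sketch of that count, using a hypothetical helper `kept_draws` and made-up numbers (1000 post-warmup iterations, six chains, two late-converging chains) rather than the actual run above:

```r
# Hypothetical helper: count draws kept when each chain c keeps iterations
# first_kept[c]..n_iter, versus cutting every chain at max(first_kept).
kept_draws = function(n_iter, first_kept) {
  per_chain = sum(n_iter - first_kept + 1)
  common = length(first_kept) * (n_iter - max(first_kept) + 1)
  c(per_chain = per_chain, common = common,
    extra_fraction = per_chain / common - 1)
}

res = kept_draws(n_iter = 1000, first_kept = c(1, 1, 1, 1, 200, 250))
res  # per_chain = 5552, common = 4506, i.e. about 23% more draws kept
```

The gain grows with the number of chains that converge early and shrinks as the late chains’ cutoffs approach the common one.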