Help needed with CmdStanR parallel chains

Hi all, I hope you can help me understand how parallel chains work in CmdStanR. I’m getting results that seem counterintuitive, probably because I don’t understand what is happening.

I’m running a very basic Bayesian linear model of

y_i \sim N(\beta_0 + \beta_1 z_i, \sigma^2)

where z_i is an indicator variable. I am using simulated data and Stan's default priors for the parameters (improper uniform over each parameter's support, since the model block declares none).

My Stan code is as follows


data {
  
  // Number of participants in the analysis
  int<lower=0> int_N;
  
  // Treatment arm dummy indicator
  vector<lower=0, upper=1>[int_N] vec_z;
  
  // Outcome vector
  vector[int_N] vec_y;
  
}


parameters {
  
  real b0;
  real b1;
  real<lower=0> sigma;
  
}

model {
  
  vec_y ~ normal(b0 + b1* vec_z, sigma);
  
}

My simulation code is as follows


test.stanmodel <- cmdstanr::cmdstan_model("testmcmcmodel.stan")

rm(temp.stanmodel.mcmc)
set.seed(1234)

starttime <- Sys.time()

# Pass the data over to Stan
temp.N <- 120
temp.z <- c(rep(0,temp.N/2), rep(1, temp.N/2))
temp.y <- rnorm(temp.N, 0, 1)

temp.stanmodel.data <- list(
  int_N = temp.N,
  vec_z = temp.z,
  vec_y = temp.y
)

# Time is measured around this call; switch parallel_chains between 1 and 5 to compare
temp.stanmodel.mcmc <- test.stanmodel$sample(
  data = temp.stanmodel.data,
  chains = 5,
  parallel_chains = 1,
  iter_warmup = 500,
  iter_sampling = 1e3,
  show_messages = FALSE
)

endtime <- Sys.time()
runtime <- endtime - starttime
runtime

Queries

My main concern is runtime.

  1. When I run this code with parallel_chains = 1, I get a runtime of 59.6 seconds. When I change it to parallel_chains = 5 and re-run the code with the same seed, I get a runtime of 58.5 seconds. This is counterintuitive to me because I thought parallelisation would give me a runtime that’s orders of magnitude shorter, but this is clearly wrong. What is going on here with the parallel_chains argument?
  2. On a separate note, I don’t understand the relationship between the number of chains, iter_sampling, and runtime. Suppose I want a total of 5000 samples from the joint posterior. I compared chains = 5, iter_sampling = 1000 versus chains = 2, iter_sampling = 2500. For some reason, running fewer chains with a higher iter_sampling gives a faster runtime, regardless of whether parallel_chains is set to the number of chains or to 1.
  3. What can I do to reduce my runtime?

Thanks for reading my post.

Edit: 58.5 seconds, not minutes

The details of your hardware matter here. My guess is that either you are encountering memory pressure when running multiple chains in parallel that grinds everything to a halt, or that some of the chains are being forced onto much slower cores.
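To narrow down which of these it is, something like the following might help. It is only a sketch: it assumes the fitted object is still called temp.stanmodel.mcmc as in the post above, and it skips the timing report if that object does not exist. parallel::detectCores() reports logical cores by default, which can overstate how many chains your machine can really run at full speed.

```r
# How many cores R can see (logical cores by default; may be NA on some platforms)
n_cores <- parallel::detectCores()
print(n_cores)

# Per-chain warmup/sampling times as recorded by CmdStan.
# Assumes the fit from the post above is in the workspace under this name.
if (exists("temp.stanmodel.mcmc")) {
  print(temp.stanmodel.mcmc$time())
}
```

If the per-chain times reported by $time() are tiny compared with the 59 seconds measured with Sys.time(), most of the wall-clock time is being spent outside of sampling, and parallel_chains will not change much.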

When parallel_chains is set to 1, fewer chains with more sampling iterations will give a faster runtime because the cost of warmup is paid fewer times.
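As a rough illustration using the settings from the question (iter_warmup = 500 in both cases), here is the total number of iterations each configuration has to run. The helper total_iterations is made up for this example, and the comparison assumes every iteration costs about the same, which is only approximately true since warmup iterations can be more or less expensive than sampling ones.

```r
# Hypothetical helper: total iterations run across all chains,
# counting warmup and sampling for each chain.
total_iterations <- function(chains, iter_warmup, iter_sampling) {
  chains * (iter_warmup + iter_sampling)
}

# Both configurations keep 5000 post-warmup draws in total
five_chains <- total_iterations(chains = 5, iter_warmup = 500, iter_sampling = 1000)
two_chains  <- total_iterations(chains = 2, iter_warmup = 500, iter_sampling = 2500)

print(c(five_chains = five_chains, two_chains = two_chains))

# Iterations spent on warmup alone in each configuration
print(c(five_chains = 5 * 500, two_chains = 2 * 500))
```

Run one after another, the five-chain setup does 7500 iterations against 6000 for the two-chain setup, and the whole difference is the extra 1500 warmup iterations.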