Hi all, I hope you can help me understand how parallel chains work in CmdStanR. I’m getting results that feel counterintuitive, but I suspect that’s because I don’t understand what is happening.
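I’m running a very basic Bayesian linear model,

$$Y = \beta_0 + \beta_1 Z + \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, \sigma^2),$$

where Z is an indicator variable. I am using simulated data and Stan’s default (improper) uniform priors on the parameters.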
My Stan code is as follows:
data {
// Number of participants in the analysis
int<lower=0> int_N;
// Treatment arm dummy indicator
vector<lower=0, upper=1>[int_N] vec_z;
// Outcome vector
vector[int_N] vec_y;
}
parameters {
real b0;
real b1;
real<lower=0> sigma;
}
model {
vec_y ~ normal(b0 + b1 * vec_z, sigma);
}
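For reference: because the `model` block contains no prior statements, Stan uses improper flat priors over each parameter’s declared support. Up to a constant, the posterior this program targets is therefore just the likelihood restricted to $\sigma > 0$. Here $\beta_0, \beta_1$ stand for `b0`, `b1`, $N$ is `int_N`, and the normal is parameterised by its standard deviation, as in Stan:

$$p(\beta_0, \beta_1, \sigma \mid y, z) \;\propto\; \prod_{i=1}^{N} \mathcal{N}\!\left(y_i \mid \beta_0 + \beta_1 z_i,\ \sigma\right), \qquad \sigma > 0.$$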
My simulation code is as follows:
test.stanmodel <- cmdstanr::cmdstan_model("testmcmcmodel.stan")
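# Remove any fit left over from a previous run (avoids a warning on the first run)
if (exists("temp.stanmodel.mcmc")) rm(temp.stanmodel.mcmc)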
set.seed(1234)
starttime <- Sys.time()
# Simulate the data and package it for Stan
temp.N <- 120
temp.z <- c(rep(0,temp.N/2), rep(1, temp.N/2))
temp.y <- rnorm(temp.N, 0, 1)
temp.stanmodel.data <- list(
int_N = temp.N,
vec_z = temp.z,
vec_y = temp.y
)
temp.stanmodel.mcmc <- test.stanmodel$sample(data = temp.stanmodel.data,
chains = 5,
parallel_chains = 1,
iter_warmup = 500,
iter_sampling = 1e3,
show_messages = FALSE)
endtime <- Sys.time()
runtime <- endtime - starttime
runtime
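As a quick sanity check on the simulated data (separate from the runtime question), the same data can be fit by ordinary least squares in base R. With the flat priors used here, the posterior means of `b0` and `b1` should land close to these estimates, up to Monte Carlo error. The names `fit.lm` and `group.means` are hypothetical, added only for illustration:

```r
# Re-create the same simulated data as in the script above
# (same seed, same order of random draws).
set.seed(1234)
temp.N <- 120
temp.z <- c(rep(0, temp.N / 2), rep(1, temp.N / 2))
temp.y <- rnorm(temp.N, 0, 1)

# Ordinary least squares fit of y on the treatment indicator
fit.lm <- lm(temp.y ~ temp.z)
coef(fit.lm)   # intercept corresponds to b0, slope to b1
sigma(fit.lm)  # residual SD, roughly comparable to the posterior for sigma

# With a single 0/1 predictor, OLS reproduces the two group means exactly:
# intercept = mean of the z == 0 group, intercept + slope = mean of z == 1
group.means <- tapply(temp.y, temp.z, mean)
group.means
```

If the `temp.stanmodel.mcmc` fit is available, `temp.stanmodel.mcmc$summary(c("b0", "b1", "sigma"))` gives the posterior summaries to compare against.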
Queries
My main concern is runtime.
- When I run this code with `parallel_chains = 1`, I get a runtime of 59.6 seconds. When I change it to `parallel_chains = 5` and re-run the code with the same seed, I get a runtime of 58.5 seconds. This is counterintuitive to me, because I expected parallelisation to make the runtime many times shorter, and that is clearly not happening. What is the `parallel_chains` argument actually doing here?
- On a separate note, I don’t understand the relationship between the number of chains, `iter_sampling`, and runtime. Suppose I want a total of 5000 post-warmup draws from the joint posterior. I compared `chains = 5, iter_sampling = 1000` against `chains = 2, iter_sampling = 2500`. For some reason, running fewer chains with a higher `iter_sampling` gives a faster runtime, regardless of whether `parallel_chains` is set to the number of chains or to 1.
- What can I do to reduce my runtime?
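One thing that may explain part of the pattern in the second query, at least when chains run one after another: `iter_warmup` is per chain, so every chain repeats its own 500 warmup iterations. Both settings give 5000 post-warmup draws, but not the same total number of iterations. A quick count, assuming `iter_warmup = 500` in both runs (the `configs` data frame is just an illustrative name):

```r
# Every chain runs iter_warmup warmup iterations before its iter_sampling draws.
iter_warmup <- 500
configs <- data.frame(
  chains        = c(5, 2),
  iter_sampling = c(1000, 2500)
)
configs$post_warmup_draws <- configs$chains * configs$iter_sampling
configs$warmup_iters      <- configs$chains * iter_warmup
configs$total_iters       <- configs$chains * (iter_warmup + configs$iter_sampling)
print(configs)
# post_warmup_draws: 5000 and 5000
# warmup_iters:      2500 and 1000
# total_iters:       7500 and 6000
```

This counts iterations only; it says nothing about fixed per-chain overhead or how chains are scheduled across cores, which would also affect wall-clock time.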
Thanks for reading my post.
Edit: 58.5 seconds, not minutes