Thank you @wds15.
I implemented both the multinomial model with reduce_sum() and the equivalent non-parallelized model. Whatever parameter I tweak, all the parallelized runs take longer than the non-parallelized ones. I tried:
- setting threads_per_chain to 1 → the highest computation time, far higher than the non-parallelized model
- increasing threads_per_chain from 2 to 10 → computation time increases; every run is slower than the non-parallelized model but faster than the run with threads_per_chain = 1
- testing grainsize, starting at K/N_cores (where K is the number of terms in the log-likelihood sum) and then repeatedly dividing it by 2 → computation time increases as grainsize gets smaller
- increasing K → computation time increases
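The grainsize search in the list above (start at K/N_cores, then halve) can be written as a small R helper so the sequence is reproducible. This is a sketch assuming integer division at each step; `grainsize_schedule()` is a hypothetical name, not part of cmdstanr.

```r
# Hypothetical helper: start at K / n_cores (rounded down, at least 1)
# and keep halving with integer division until the grainsize reaches 1.
grainsize_schedule <- function(K, n_cores) {
  g <- max(1L, as.integer(K %/% n_cores))
  out <- g
  while (g > 1L) {
    g <- g %/% 2L
    out <- c(out, g)
  }
  out
}

grainsize_schedule(500, 4)
#> [1] 125  62  31  15   7   3   1
```

Each value can then be assigned to `stan_data$grainsize` before sampling.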
To check that this wasn't a problem with my computing system, I ran the binomial example from the "Reduce Sum: A Minimal Example" case study (mc-stan.org) and got results similar to the ones reported there.
Could there be something about the multinomial distribution that makes parallelization gains impossible?
Cheers
Code:
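```stan
// simplest_multinomial.stan
// (array[...] syntax, available since Stan 2.26)
data {
  int<lower=0> D;               // number of categories
  int<lower=0> K;               // number of rows (log-likelihood terms)
  array[K, D] int<lower=0> Y;   // counts, one row per observation
}
parameters {
  array[K] simplex[D] theta;    // one probability simplex per row
}
model {
  for (k in 1:K) {
    target += multinomial_lpmf(Y[k] | theta[k]);
  }
}
```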
Multithreaded version:
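```stan
// simplest_multinomial_redsum.stan
functions {
  // Log-likelihood of rows start..end of Y.
  // y_slice holds only those rows, so y_slice[i] pairs with theta[start + i - 1].
  real partial_sum(array[,] int y_slice,
                   int start,
                   int end,
                   array[] vector theta) {
    real interm_sum = 0;
    for (i in 1:(end - start + 1)) {
      interm_sum += multinomial_lpmf(y_slice[i] | theta[start + i - 1]);
    }
    return interm_sum;
  }
}
data {
  int<lower=0> D;
  int<lower=0> K;
  array[K, D] int<lower=0> Y;
  int<lower=1> grainsize;
}
parameters {
  array[K] simplex[D] theta;
}
model {
  // Y is the sliced argument; theta is passed whole as a shared argument.
  target += reduce_sum(partial_sum, Y, grainsize, theta);
}
```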
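As a plain-R illustration (not Stan) of what the two models compute: the loop version adds one multinomial log-probability per row, and reduce_sum() adds partial sums over contiguous slices of rows, so the total should not depend on grainsize. This sketch assumes theta is stored as a K x D matrix whose rows are simplexes, and that Stan's `multinomial_lpmf` (normalizing constant included) matches R's `dmultinom(log = TRUE)`. The helpers `loglik_full()`, `partial_sum_r()` and `reduce_sum_r()` are hypothetical, and the fixed equal-size chunking is a simplification: Stan's reduce_sum treats grainsize as a suggestion and lets the scheduler choose the actual split.

```r
# Full log-likelihood: one dmultinom term per row of Y.
loglik_full <- function(Y, theta) {
  sum(vapply(seq_len(nrow(Y)),
             function(k) dmultinom(Y[k, ], prob = theta[k, ], log = TRUE),
             numeric(1)))
}

# R mirror of partial_sum(): log-likelihood of rows start..end only.
partial_sum_r <- function(Y, theta, start, end) {
  loglik_full(Y[start:end, , drop = FALSE], theta[start:end, , drop = FALSE])
}

# Simplified reduce_sum(): split rows 1..K into chunks of at most
# `grainsize` rows and add up the partial sums sequentially.
reduce_sum_r <- function(Y, theta, grainsize) {
  K <- nrow(Y)
  total <- 0
  for (s in seq(1, K, by = grainsize)) {
    e <- min(s + grainsize - 1, K)
    total <- total + partial_sum_r(Y, theta, s, e)
  }
  total
}

# Fake data shaped like the dummy example.
set.seed(1)
K <- 500
Y <- t(rmultinom(K, 100, c(0.8, 0.1, 0.1)))
theta <- matrix(c(0.8, 0.1, 0.1), nrow = K, ncol = 3, byrow = TRUE)

all.equal(reduce_sum_r(Y, theta, 1), loglik_full(Y, theta))
all.equal(reduce_sum_r(Y, theta, 125), loglik_full(Y, theta))
```

The arithmetic is identical for every grainsize; only the scheduling and bookkeeping around each chunk change.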
R code to test a dummy example:
```r
library(cmdstanr)
library(posterior)

ncores <- 4
options(mc.cores = ncores)

# create fake data: K rows of 100 draws over 3 categories
K <- 500
Y <- t(rmultinom(K, 100, c(0.8, 0.1, 0.1)))
stan_data <- list(D = 3, K = K, Y = Y)

# basic (non-parallelized) model
mod0 <- cmdstan_model("./simplest_multinomial.stan")
time0 <- system.time(
  fit0 <- mod0$sample(data = stan_data,
                      seed = 5446,
                      chains = 2,
                      parallel_chains = 2,
                      iter_warmup = 500,
                      iter_sampling = 500,
                      refresh = 100)
)

# multithreaded model (compiled with threading enabled)
stan_data$grainsize <- 1
mod1 <- cmdstan_model("./simplest_multinomial_redsum.stan",
                      cpp_options = list(stan_threads = TRUE))
time1 <- system.time(
  fit1 <- mod1$sample(data = stan_data,
                      seed = 5446,
                      chains = 2,
                      parallel_chains = 2,
                      threads_per_chain = 2,
                      iter_warmup = 500,
                      iter_sampling = 500,
                      refresh = 100)
)

# compare: a ratio above 1 means the multithreaded run was slower
time1[["elapsed"]] / time0[["elapsed"]]
```
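The manual runs described at the top (different threads_per_chain and grainsize values) can be timed in a single loop. A minimal sketch: `time_configs()` is a hypothetical helper, and `mod` is assumed to be a model compiled with `cpp_options = list(stan_threads = TRUE)`, like `mod1`; only standard `$sample()` arguments are used. It is usually worth keeping chains × threads_per_chain at or below the number of physical cores, since extra threads compete for the same cores.

```r
# Hypothetical timing harness: one sampling run per row of `configs`,
# recording elapsed wall-clock seconds for each
# (threads_per_chain, grainsize) pair.
time_configs <- function(mod, stan_data, configs,
                         chains = 2, iter = 500, seed = 5446) {
  configs$elapsed <- NA_real_
  for (i in seq_len(nrow(configs))) {
    stan_data$grainsize <- as.integer(configs$grainsize[i])
    tm <- system.time(
      mod$sample(data = stan_data,
                 seed = seed,
                 chains = chains,
                 parallel_chains = chains,
                 threads_per_chain = configs$threads_per_chain[i],
                 iter_warmup = iter,
                 iter_sampling = iter,
                 refresh = 0)  # suppress progress output
    )
    configs$elapsed[i] <- tm[["elapsed"]]
  }
  configs
}

# Grid of settings to compare (values are illustrative).
configs <- expand.grid(threads_per_chain = c(1, 2, 4),
                       grainsize = c(1, 31, 125))

# results <- time_configs(mod1, stan_data, configs)
# results$ratio <- results$elapsed / time0[["elapsed"]]
```

Dividing by `time0[["elapsed"]]` gives the same ratio as the single comparison above, one row per configuration.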