Parallelization in cmdstanr: using reduce_sum() on a multinomial likelihood

Thank you @wds15.

I implemented the multinomial model both with reduce_sum() and as an equivalent non-parallelized model. Whatever I tweak, every parallelized run takes longer than the non-parallelized one. I tried:

  • setting threads_per_chain to 1 → the highest computation time, much higher than the non-parallelized model
  • increasing threads_per_chain from 2 to 10 → computation time increases. All of these runs are slower than the non-parallelized model but faster than the run with threads_per_chain = 1
  • tuning grainsize, starting at K / N_cores (where K is the number of log-likelihood terms being summed and N_cores the number of cores) and repeatedly halving it → computation time increases
  • increasing K → computation time increases
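For reference, the sweep described in the list above can be sketched in R as follows. This is a minimal sketch under stated assumptions: grainsize_schedule() and time_sweep() are hypothetical helper names, and time_sweep() assumes a cmdstanr model compiled with stan_threads = TRUE whose data block declares a grainsize entry. It times each (threads_per_chain, grainsize) pair with the same wall-clock approach as the script further down.

```r
# Hypothetical helper: grainsize schedule starting at K / n_cores,
# halved (integer division) until it reaches 1.
grainsize_schedule <- function(K, n_cores) {
  g <- max(1, K %/% n_cores)
  out <- g
  while (g > 1) {
    g <- max(1, g %/% 2)
    out <- c(out, g)
  }
  out
}

# Hypothetical helper: wall-clock time of one fit per (threads, grainsize) pair.
# `mod` is assumed to be a CmdStanModel compiled with stan_threads = TRUE.
time_sweep <- function(mod, stan_data, threads, grainsizes) {
  grid <- expand.grid(threads = threads, grainsize = grainsizes)
  grid$elapsed <- NA_real_
  for (i in seq_len(nrow(grid))) {
    stan_data$grainsize <- as.integer(grid$grainsize[i])
    grid$elapsed[i] <- system.time(
      mod$sample(data = stan_data, seed = 5446, chains = 2,
                 parallel_chains = 2,
                 threads_per_chain = as.integer(grid$threads[i]),
                 iter_warmup = 500, iter_sampling = 500, refresh = 0)
    )[["elapsed"]]
  }
  grid
}
```

With K = 500 and 4 cores, grainsize_schedule(500, 4) gives 125, 62, 31, 15, 7, 3, 1; passing that and threads = 2:10 to time_sweep() reproduces the grid of runs in one data frame.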

To check that this wasn't a problem with my system, I ran the binomial example from Reduce Sum: A Minimal Example (mc-stan.org) and got results similar to the ones reported there.

Could there be something about the multinomial distribution that makes parallelization gains impossible?

Cheers

Code:

// simplest_multinomial.stan

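data {
  int<lower=0> D;                 // number of categories
  int<lower=0> K;                 // number of rows (log-likelihood terms)
  array[K, D] int<lower=0> Y;     // counts, one multinomial draw per row
}

parameters {
  array[K] simplex[D] theta;      // one probability simplex per row
}

model {
  // sum the K multinomial log-likelihood terms serially
  for (k in 1:K) {
    target += multinomial_lpmf(Y[k] | theta[k]);
  }
}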

Multithreaded version:

// simplest_multinomial_redsum.stan

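functions {
  // partial sum over rows start..end; y_slice holds Y[start:end]
  real partial_sum(array[,] int y_slice,
                   int start,
                   int end,
                   array[] vector theta) {
    real interm_sum = 0;
    for (i in 1:(end - start + 1)) {
      // y_slice is re-indexed from 1, theta is indexed globally
      interm_sum += multinomial_lpmf(y_slice[i] | theta[start + i - 1]);
    }
    return interm_sum;
  }
}

data {
  int<lower=0> D;
  int<lower=0> K;
  array[K, D] int<lower=0> Y;
  int<lower=1> grainsize;
}

parameters {
  array[K] simplex[D] theta;
}

model {
  // slice over the data Y; theta is passed whole as a shared argument
  target += reduce_sum(partial_sum, Y, grainsize, theta);
}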
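To make explicit what the two Stan programs compute, here is a serial R sketch with no threading (partial_sum_r() and reduce_sum_r() are hypothetical names). The K log-likelihood terms are cut into slices of grainsize rows, each slice is summed by a partial-sum function that uses the same start/end indexing as partial_sum() above, and the partial sums are added. Stan's reduce_sum() treats grainsize only as a suggested slice size and lets the scheduler pick the actual slices, but the total is the same up to floating-point rounding. As in the Stan code, every partial-sum call here receives the whole theta array, not only theta[start:end].

```r
# Hypothetical R analogue of the Stan partial_sum(): rows start..end of Y
# arrive as y_slice (re-indexed from 1), theta is indexed globally.
partial_sum_r <- function(y_slice, start, end, theta) {
  s <- 0
  for (i in seq_len(end - start + 1)) {
    s <- s + dmultinom(y_slice[i, ], prob = theta[start + i - 1, ], log = TRUE)
  }
  s
}

# Hypothetical serial stand-in for reduce_sum(): fixed slices of `grainsize`
# rows (grainsize >= 1), partial sums added in order.
reduce_sum_r <- function(Y, grainsize, theta) {
  K <- nrow(Y)
  total <- 0
  for (start in seq(1, K, by = grainsize)) {
    end <- min(start + grainsize - 1, K)
    total <- total + partial_sum_r(Y[start:end, , drop = FALSE], start, end, theta)
  }
  total
}
```

Whatever the grainsize, reduce_sum_r() returns the same value as summing dmultinom(Y[k, ], prob = theta[k, ], log = TRUE) over all k, which is what the serial model's loop adds to target.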

R code to test a dummy example:

library(cmdstanr)
library(posterior)

ncores = 4
options(mc.cores = ncores)

# create fake data
K <- 500
Y <- t(rmultinom(K, 100, c(0.8,0.1,0.1)))

stan_data <- list(D = 3,
                  K = K,
                  Y = Y)

# basic
mod0 <- cmdstan_model("./simplest_multinomial.stan")

time0 = system.time(
  fit0 <- mod0$sample(data = stan_data,
                      seed = 5446,
                      chains = 2,
                      parallel_chains = 2,
                      iter_warmup = 500,
                      iter_sampling = 500,
                      refresh = 100)
)

# multithread
stan_data$grainsize <- 1

mod1 <- cmdstan_model("./simplest_multinomial_redsum.stan",
                      cpp_options = list(stan_threads = TRUE))


time1 = system.time(
  fit1 <- mod1$sample(data = stan_data,
                      seed = 5446,
                      chains = 2,
                      parallel_chains = 2,
                      threads_per_chain = 2,
                      iter_warmup = 500,
                      iter_sampling = 500,
                      refresh = 100)
)

# compare
time1[["elapsed"]] / time0[["elapsed"]]
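system.time() above measures the wall clock of the whole $sample() call, which also includes launching CmdStan and writing output. cmdstanr fits also expose per-chain timings through the $time() method, whose chains element is a data frame with chain_id, warmup, sampling and total columns. A minimal sketch comparing those per chain, where compare_chain_times() is a hypothetical helper:

```r
# Hypothetical helper: per-chain total time of a serial vs. a threaded fit.
# Assumes both fits come from cmdstanr $sample() with the same number of chains.
compare_chain_times <- function(fit_serial, fit_threaded) {
  t0 <- fit_serial$time()$chains    # columns: chain_id, warmup, sampling, total
  t1 <- fit_threaded$time()$chains
  data.frame(chain_id = t0$chain_id,
             serial_total = t0$total,
             threaded_total = t1$total,
             ratio = t1$total / t0$total)  # > 1 means the threaded chain was slower
}
```

With the objects from the script above, compare_chain_times(fit0, fit1) returns one row per chain, so a slowdown can be checked chain by chain rather than only through the overall elapsed ratio.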