Parallelization in Stan models

Hello everyone,

I have a few general questions about parallelization in Stan models that I would appreciate help with; these have probably come up, or will come up, for many people. If anyone has information on this topic, even beyond the questions I raise below, please share it in this post so that everyone can benefit.

I am aware that the model-fitting functions themselves (such as stan() in the rstan package or mod$sample() in the cmdstanr package) can parallelize via an option such as cores = getOption("mc.cores", 20).

  1. But is parallelization also possible beyond that, within the model itself, with functions such as reduce_sum()?

  2. Is there much difference in execution speed between that kind of parallelization alone and combining it with functions such as reduce_sum()?

  3. And if we want to set up parallelization with such functions or packages in Stan, how do we do it? (I mean a clear, step-by-step guide on how to perform parallelization.)

Thank you in advance for your response.

Hey Mohammad, I’ll give it a crack. The parallelisation arguments such as cores refer to the number of cores used to fit the model, but without within-chain parallelisation options you’re essentially capped by the number of chains. For instance, if you set chains = 4, parallel_chains = 4, cmdstanr runs 4 chains in parallel. More chains obviously means more posterior draws, and you can check for convergence with different initial values.
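For example, a run with between-chain parallelisation only might look like this in cmdstanr (the model file name and data list are placeholders, not from a real project):

```r
library(cmdstanr)

# Compile the model; "my_model.stan" is a placeholder for your own file
mod <- cmdstan_model("my_model.stan")

# 4 chains, each running on its own core: between-chain parallelisation only
fit <- mod$sample(
  data = my_data_list,  # placeholder for your data as a named list
  chains = 4,
  parallel_chains = 4
)
```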

Within-chain parallelisation can split the work of a single chain across multiple cores, so if you have lots of observations you can spread, for instance, the log likelihood computation over your observations across multiple cores. For example, if you have a reduce_sum() in your Stan program and you set chains = 4, parallel_chains = 4, threads_per_chain = 5 (assuming you have 20 cores), each chain is able to use 5 cores for its computations. There’s a trade-off, as there’s overhead involved with reduce_sum(), but you should see substantial speed improvements if you have lots of observations.
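Here’s a rough sketch of what that could look like, along the lines of the examples in the User’s Guide; the logistic regression, the variable names, and the data list are all placeholders you’d swap for your own model:

```r
library(cmdstanr)

# A sketch of a logistic regression whose likelihood is split into
# partial sums that reduce_sum() can evaluate on separate threads.
stan_code <- "
functions {
  // log likelihood of one slice of the data, from index start to end
  real partial_sum(array[] int y_slice, int start, int end,
                   matrix x, real alpha, vector beta) {
    return bernoulli_logit_lpmf(y_slice | alpha + x[start:end] * beta);
  }
}
data {
  int<lower=0> N;
  int<lower=0> K;
  matrix[N, K] x;
  array[N] int<lower=0, upper=1> y;
  int<lower=1> grainsize;  // grainsize = 1 leaves the slicing to the scheduler
}
parameters {
  real alpha;
  vector[K] beta;
}
model {
  alpha ~ std_normal();
  beta ~ std_normal();
  target += reduce_sum(partial_sum, y, grainsize, x, alpha, beta);
}
"

# Threading must be switched on when the model is compiled
mod <- cmdstan_model(write_stan_file(stan_code),
                     cpp_options = list(stan_threads = TRUE))

# 4 chains x 5 threads per chain = 20 cores in total
fit <- mod$sample(
  data = list(N = N, K = K, x = x, y = y, grainsize = 1),  # placeholder data
  chains = 4,
  parallel_chains = 4,
  threads_per_chain = 5
)
```

As far as I know, a grainsize of 1 lets the scheduler pick the chunk size automatically, which is a reasonable starting point before tuning it by hand.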

As for examples, there are some in the User’s Guide.
