CmdStanPy, MPI speedup

Which functions within an iteration can I parallelize for Stan?

Saying “found a way” makes it seem both harder and easier than it was. Both PyMC and NumPyro can generate JAX output, which can be run on CPUs, GPUs, TPUs. To perform inference in JAX, you’d probably need to use something like Blackjax rather than, say, the nutpie implementation in Rust that is PyMC’s go-to sampler these days (it also runs Stan models).

Coding Stan to produce JAX output would be a lot of work on our side, and would also limit what could be written in Stan, because JAX output is less flexible. I’m also not convinced a language like Stan would be the best way to write JAX code because we don’t have the parallelization built into the language in the way you can control it in JAX directly.

There are many functions that have GPU support in Stan. But if you’re talking about parallelizing something like a loop, you can use map_rect or reduce_sum to do that. That will parallelize over threads on a single machine and can also be set up to parallelize using MPI over multiple machines. It tends not to be worth doing this unless the amount of computation being parallelized dwarfs the communication cost.
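In case it helps to see the idea outside of Stan: here is a minimal pure-Python sketch of the partial-sum decomposition that reduce_sum exploits. The data and the normal log density are made up for illustration; the point is only that the log density of conditionally independent terms is a sum, so slices can be evaluated on separate workers and the partial sums added.

```python
import math
from concurrent.futures import ThreadPoolExecutor

# Toy data: iid normal(mu = 2, sigma = 1) observations (made up).
y = [1.8, 2.1, 2.4, 1.9, 2.2, 2.0, 1.7, 2.3]
mu, sigma = 2.0, 1.0

def partial_lp(chunk):
    # Log-density contribution of one slice of the data -- the role
    # played by the user-supplied partial-sum function in reduce_sum.
    return sum(
        -0.5 * ((v - mu) / sigma) ** 2
        - math.log(sigma)
        - 0.5 * math.log(2 * math.pi)
        for v in chunk
    )

# Split the data into independent slices, evaluate each slice on its
# own thread, then add the partial sums. This is valid because the
# joint log density of conditionally independent terms is additive.
chunks = [y[i:i + 2] for i in range(0, len(y), 2)]
with ThreadPoolExecutor(max_workers=4) as ex:
    parallel_lp = sum(ex.map(partial_lp, chunks))

serial_lp = partial_lp(y)
assert abs(parallel_lp - serial_lp) < 1e-12
```

The same caveat from above applies here: each slice has to do enough work to amortize the overhead of dispatching it, which is why Stan's reduce_sum lets you tune the grainsize.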

I use map_rect() already, and performance peaks at fewer than 5 MPI threads. I have many more threads available but no way to use them to speed up the MCMC sampling within Stan. The gradient evaluation takes a few seconds, but each iteration (as defined by num_samples) takes many minutes, and there are thousands of iterations. If I could bring an iteration down to a few seconds by parallelizing the sampling method, that would be great. I imagine an iteration is just a bunch of gradient evaluations. This sampling logic is fundamental to Stan (not my model), so I am asking which functions are parallelizable in order to make an iteration (the sampling itself) faster.

GPU support was mentioned, but no specifics were given. Besides, what if one wanted to use MPI and CPUs? That does not seem to be a viable route for Stan even if I knew which functions had GPU support. Still, it would be interesting to know which functions do have GPU support; maybe the sampling could be sped up there if I made a fork of the project. I have yet to see any documentation on this topic (parallelization of the sampling itself, as opposed to parallelization to minimize gradient-evaluation time).

Regardless, the PyMC/JAX/Blackjax setup somehow seems to be able to offload more work per thread in the MCMC sampling, as seen in the article above.

We find that people generally generate more iterations than they need. What's your target effective sample size? We usually recommend around 100, because that reduces the standard error of the posterior-mean estimate to about 1/10th of the posterior standard deviation, and the latter is the real uncertainty that won't go away with more sampling.
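The arithmetic behind that recommendation, as a quick sketch (the posterior standard deviation here is an arbitrary placeholder value):

```python
import math

posterior_sd = 1.0  # illustrative posterior standard deviation

def mcse(sd, ess):
    # Monte Carlo standard error of a posterior-mean estimate:
    # sd / sqrt(ESS).
    return sd / math.sqrt(ess)

# With ESS = 100, the Monte Carlo error is already only a tenth of the
# posterior sd, which is the irreducible uncertainty in the answer.
assert mcse(posterior_sd, 100) == posterior_sd / 10

# Quadrupling the ESS only halves the Monte Carlo error, so piling on
# iterations past ESS ~ 100 buys very little.
assert math.isclose(mcse(posterior_sd, 400), posterior_sd / 20)
```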

The iterations use the leapfrog integrator to solve the Hamiltonian dynamics, and each leapfrog step requires a gradient of the target density. We cap the number of leapfrog steps at 2^10 = 1024 by default by setting max_treedepth=10 (the 1024 steps are organized into a 10-deep balanced binary tree).
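To make the structure concrete, here is a toy leapfrog integrator in pure Python for a 1D standard-normal target (Stan's actual implementation uses the model's autodiff gradient and a mass matrix; this is just a sketch). The key point for the parallelization question is visible in the loop: each step consumes the position produced by the previous step.

```python
import math

def grad_neg_log_density(q):
    # Gradient of -log p(q) for a standard normal target. In Stan this
    # is the (expensive) autodiff gradient of the model's log density.
    return q

def leapfrog(q, p, eps, n_steps):
    # Standard leapfrog: half momentum step, alternating full steps,
    # closing half momentum step. Every step depends on the previous
    # one's output, so the loop cannot be parallelized across steps.
    p -= 0.5 * eps * grad_neg_log_density(q)
    for _ in range(n_steps - 1):
        q += eps * p
        p -= eps * grad_neg_log_density(q)
    q += eps * p
    p -= 0.5 * eps * grad_neg_log_density(q)
    return q, p

# One gradient evaluation per step, and up to 2**max_treedepth = 1024
# steps per iteration at the default setting.
q, p = leapfrog(1.0, 0.5, 0.01, 1024)

# Leapfrog is symplectic, so the Hamiltonian H = p^2/2 + q^2/2 is
# approximately conserved along the trajectory.
h0 = 0.5 * 0.5**2 + 0.5 * 1.0**2
assert abs(0.5 * p * p + 0.5 * q * q - h0) < 1e-3
```

This is why an iteration that takes minutes with a few-second gradient is just many gradient evaluations chained end to end, not something a second thread can bite into.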

None. The leapfrog integrator is intrinsically serial. That’s why you haven’t seen any discussion of parallelizing it. The best you can do is run multiple chains. So if you are only using 4 chains now with 1000 sampling iterations each, the logical limit is 1000 chains with 1 sampling iteration each. Then it’s just a matter of how short warmup can be. And of course there are intermediate solutions. The thing to watch out for is memory bottlenecks here.
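A toy illustration of the many-short-chains trade-off, with a throwaway random-walk Metropolis sampler standing in for NUTS (everything here is made up for illustration). Chains share no state, so they parallelize freely; the catch, as noted above, is that every chain pays its own warmup cost.

```python
import math
import random
from concurrent.futures import ThreadPoolExecutor

def log_density(q):
    # Standard-normal target (toy stand-in for a Stan model's
    # log density).
    return -0.5 * q * q

def run_chain(seed, n_warmup, n_sampling):
    # One independent toy Metropolis chain. In Stan each chain would
    # be a full NUTS run; the embarrassingly-parallel structure is
    # the same.
    rng = random.Random(seed)
    q, draws = 0.0, []
    for i in range(n_warmup + n_sampling):
        prop = q + rng.gauss(0.0, 1.0)
        if math.log(rng.random()) < log_density(prop) - log_density(q):
            q = prop
        if i >= n_warmup:  # warmup draws are discarded, per chain
            draws.append(q)
    return draws

# Many short chains: 100 chains x 10 post-warmup draws each, run on a
# thread pool. Note the 50 warmup iterations are paid 100 times over,
# which is exactly what caps how far this strategy can be pushed.
with ThreadPoolExecutor(max_workers=8) as ex:
    chains = list(ex.map(lambda s: run_chain(s, 50, 10), range(100)))

draws = [d for chain in chains for d in chain]
assert len(draws) == 1000
```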

Otherwise, the only thing to do is make the log density and gradient evaluation faster.

Absolutely, which is why I've been learning how to code models directly in JAX. NumPyro can also target JAX on the back end, but I've found it easier to just write densities Stan-style in JAX. I find the graphical model orientation of both NumPyro and PyMC distracting (though there are workflow advantages to sticking with graphical models). Here's an example of what I'm working on now; I'm working in JAX because I want to evaluate normalizing flows.

You see the same kind of organization within the Inference Gym project from TensorFlow Probability.

You mean you get the best performance with fewer than 5 MPI instances? MPI is usually implemented with multiple processes rather than with threads. The message passing in the name is usually between processes.

You mentioned something very interesting here.

  1. You mean instead of running 100 iterations for 1 chain, I can run 1 iteration for each of 100 chains and get the same solution?

  2. Can I similarly run 100 chains with 1 warmup iteration, instead of 1 chain and 100 warmup iterations?

This opens the door to parallelization, so I can utilize the many threads I have available: I can run 1 chain on 1 thread.

  1. Yes. It won’t be the exact same solution, but it’ll have roughly the same effective sample size if the autocorrelation time of the chains is one iteration (in the usual case of positive autocorrelation, running 100 chains for 1 iteration each will actually be better). The other practical problem is that with 100 chains, there’s more of a chance of one of them getting “stuck” in a nasty region of the posterior, though that’s arguably something you want to find out about.

  2. No. You have to run warmup long enough that each chain reaches the stationary distribution; only then does the equivalence in (1) apply. I’d suggest looking at this paper by Matt Hoffman et al., which discusses the issue:

It’s going to be a chapter in the 2nd edition of the MCMC Handbook (from CRC). It’s also a fun tutorial on JAX :-).

Great article; it mentions one chain per thread, which is exactly how I am set up. Good to know. Maybe Stan is fine if I take this approach. So you can parallelize the sampling :).

Regarding the warmup, though: I am trying to figure out how to speed up the warmup. My gradient evaluation is already optimized using map_rect; I am looking at 100 iterations of warmup. Any undocumented tricks/hints are welcome. Thanks,