Within-chain parallelization idea (maybe crazy)

The recent discussion about setting expectations for reduce_sum got me thinking. Is there any chance that a more consistent speedup is available by computing the forward and backward parts of the HMC trajectory in parallel, so that whenever the tree expansion proceeds in the direction where the trajectory is so-far shorter, the work has already been done? On the negative side, this approach would throw out a non-trivial amount of integration in one direction or the other at every iteration. Nevertheless, this approach might yield worthwhile speedup, provided that the information necessary to evaluate the stopping criterion can be efficiently copied/shared between processes.
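A minimal sketch of the idea, under toy assumptions: a 1-D standard normal target, a fixed step size, doubling directions given up front, and no stopping criterion. All function names are hypothetical and are not part of Stan. The sketch shows that the forward and backward integrations depend only on the initial state. Each side can therefore run on its own worker, and the tree builder only has to read off how far it has gone on each side. The result matches the serial computation exactly.

```python
from concurrent.futures import ThreadPoolExecutor

# Hypothetical helper names for illustration; this is not Stan's API.


def leapfrog(q, p, eps):
    """One leapfrog step for a 1-D standard normal target (gradient of log density is -q)."""
    p = p - 0.5 * eps * q
    q = q + eps * p
    p = p - 0.5 * eps * q
    return q, p


def integrate(q, p, eps, n_steps):
    """Return the n_steps successive states reached from (q, p); eps < 0 runs backward."""
    states = []
    for _ in range(n_steps):
        q, p = leapfrog(q, p, eps)
        states.append((q, p))
    return states


def steps_taken(directions):
    """Leapfrog steps on each side; doubling j adds 2**j steps in direction directions[j]."""
    n_fwd = sum(2 ** j for j, d in enumerate(directions) if d > 0)
    n_bwd = sum(2 ** j for j, d in enumerate(directions) if d < 0)
    return n_fwd, n_bwd


def speculative_endpoints(q0, p0, eps, directions):
    """Integrate both sides concurrently, then pick out the tree's two endpoints."""
    # Worst-case budget per side. A real implementation would cancel the
    # workers as soon as the stopping criterion ends the tree.
    budget = 2 ** len(directions) - 1
    with ThreadPoolExecutor(max_workers=2) as pool:
        fwd_job = pool.submit(integrate, q0, p0, eps, budget)
        bwd_job = pool.submit(integrate, q0, p0, -eps, budget)
        fwd, bwd = fwd_job.result(), bwd_job.result()
    n_fwd, n_bwd = steps_taken(directions)
    plus = fwd[n_fwd - 1] if n_fwd else (q0, p0)
    minus = bwd[n_bwd - 1] if n_bwd else (q0, p0)
    return minus, plus


def serial_endpoints(q0, p0, eps, directions):
    """Reference version: extend whichever end each doubling chooses, one step at a time."""
    minus = plus = (q0, p0)
    for j, d in enumerate(directions):
        for _ in range(2 ** j):
            if d > 0:
                plus = leapfrog(*plus, eps)
            else:
                minus = leapfrog(*minus, -eps)
    return minus, plus
```

Take the directions [1, -1, -1, 1, -1] as an example. The tree takes 9 forward and 22 backward steps, and both functions return the same endpoints. With one worker per side, the wall-clock cost is governed by the longer side rather than by the sum. In CPython the GIL means this pure-Python toy will not actually run faster; it only shows the data flow.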

Provided that the stopping criterion can be evaluated efficiently, it seems like there would be lots of cases where this approach yields useful speedup even though reduce_sum cannot (e.g., when the gradient evals are fast but the treedepths are large). It also seems like it could be implemented entirely on the backend, with no changes required to a user’s Stan code. And with more cores available, it could be combined with reduce_sum for additional speedup.

Just a thought. Has this already been discussed and ruled out? Is it worth thinking about further?

@betanalpha @nhuurre

Tagging @wds15, who did this a while ago! To my knowledge it was something like a 15% speedup?


FWIW, there are scenarios where I would love to get my hands on a 15% speedup if I can do it without modifying my Stan code.


The conclusion was that one would get about a 35% speedup when doubling the number of CPUs. The user would not need to change anything in the Stan model, and the optimization would apply to any Stan model.

While I personally thought this would be a good thing, the final decision was not to implement it, given the increased code complexity and the rather small gain in speed relative to the resources required (doubling the number of CPU cores).

Maybe this is worth reassessing, given that we are getting more and more CPU cores, even on laptops? To me, the key argument for this optimization is that it applies to all models and nothing needs to be changed (other than throwing in more cores).

Edit: there is a fully working (if dirty) implementation of this


I strongly suspect that this efficiency gain per additional resource is on par with or exceeds quite a few applications of reduce_sum even for the first doubling of resources. But importantly, the relevant question isn’t whether 35% gain is competitive with reduce_sum for the first doubling of resources, but rather whether it’s competitive with reduce_sum for the last doubling. If I have 8 cores per chain to play with, and reduce_sum gives me some gain at 4 cores per chain, then the question is how often I would expect reduce_sum to be competitive with a 35% gain associated with increasing from 4 to 8 cores per chain. In my own work, I have literally never seen a 35% gain from 4 to 8 with reduce_sum, let alone from 8 to 16 or beyond.

I could try to bring the prototype up to date with the current 2.28, and then you could try a few models? Maybe more convincing power-user examples would make it worthwhile to reconsider this and go for a community vote. If I recall correctly, we did not have a vote on this back then because we lacked a process for it.

… but others would be really welcome in helping to make progress (@nhuurre ?)…


In my own work, I have literally never seen a 35% gain from 4 to 8 with reduce_sum, let alone from 8 to 16 or beyond.

This is actually to be expected. Likelihood parallelization only starts to pay off when (1) the likelihood is a product over a huge number of data points and (2) the Hamiltonian trajectory is rather short. “Huge” is the key word here, because evaluating the individual data points is often really cheap. As a result, the constant cost of parallelizing anything dominates in most of the models people want to run on laptops. Amdahl’s law hits these models hard.
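To put numbers on the Amdahl's law point, here is an idealized sketch that assumes perfectly balanced work and zero per-task overhead. The function names are hypothetical. It asks what fraction p of the per-gradient work reduce_sum would need to parallelize for the step from 4 to 8 cores to give the same 35% gain discussed above.

```python
# Hypothetical helper names; idealized Amdahl's law with no scheduling overhead.


def amdahl_speedup(p, n):
    """Speedup on n cores when a fraction p of the work parallelizes perfectly."""
    return 1.0 / ((1.0 - p) + p / n)


def required_parallel_fraction(gain, n_from, n_to):
    """Parallel fraction p at which going from n_from to n_to cores yields `gain`.

    Solves amdahl_speedup(p, n_to) / amdahl_speedup(p, n_from) == gain for p.
    Only meaningful for 1 <= gain <= n_to / n_from.
    """
    return (gain - 1.0) / (gain - 1.0 + 1.0 / n_from - gain / n_to)


p = required_parallel_fraction(1.35, 4, 8)
print(f"parallel fraction needed for a 35% gain from 4 to 8 cores: {p:.3f}")
```

With these inputs, p comes out at about 0.81, so more than 80% of the gradient work would already have to be parallel. Any fixed cost of splitting up the work only raises that requirement.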

At the moment, there is no genuine HMC-compatible solution that achieves decent speedup through parallelization.


If it’s worth considering a non-laptop use case, I would definitely be interested. We are trying to scale a model to at least 1M parameters, where the gradient evaluations are done on 100+ GPUs (e.g. with 200 GPUs, a gradient eval takes ~170ms). Just running the sampler in Stan on a separate many-core node, we run into sampler overhead (~170ms per leapfrog step with just model { }). While some cache effects are bound to bite into it, it’d be great to trade cores for a bit of speed. Which version of CmdStan would the above base_nuts work with? I’d like to take it for a spin. Would the 35% improvement compound with increasing core counts (modulo cache effects)?


There would be no speedup from this beyond a doubling of the core count, though it could be combined with reduce_sum for further speedup. The maximum theoretical speedup from doubling the core count would be determined by the expected fraction of leapfrog steps that take place on the shorter side of the tree. Overhead is then an additional penalty on top of that.
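To make this concrete, here is an idealized calculation. It is a sketch with hypothetical function names, built on these assumptions:

  • every doubling direction is a fair coin flip;
  • the tree always completes `depth` doublings (2**depth - 1 leapfrog steps);
  • each side has its own worker advancing at the same rate;
  • overhead is zero.

The wall-clock cost is then the number of steps on the longer side. So the speedup for one tree is total / max(forward, backward), which equals 1 / (1 - fraction of steps on the shorter side).

```python
from itertools import product

# Hypothetical helper names; idealized model with zero overhead.


def steps_per_side(directions):
    """Leapfrog steps forward and backward; doubling j adds 2**j steps."""
    n_fwd = sum(2 ** j for j, d in enumerate(directions) if d > 0)
    n_bwd = sum(2 ** j for j, d in enumerate(directions) if d < 0)
    return n_fwd, n_bwd


def speculative_speedup(directions):
    """Serial step count divided by the step count on the longer side."""
    n_fwd, n_bwd = steps_per_side(directions)
    return (n_fwd + n_bwd) / max(n_fwd, n_bwd)


def expected_speedup(depth):
    """Average over all 2**depth equally likely direction sequences (depth >= 1)."""
    sequences = list(product((1, -1), repeat=depth))
    return sum(speculative_speedup(s) for s in sequences) / len(sequences)


for depth in range(1, 11):
    print(depth, round(expected_speedup(depth), 3))
```

Under these assumptions, the average rises with depth and levels off just under 2 ln 2 ≈ 1.39. The reason is that the final doubling alone accounts for about half of all steps and always lands on the longer side. That ceiling is plausibly consistent with the roughly 35% reported earlier in the thread once real overhead is included.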


@wds15 I wonder if it would be easy to just make a parallel_base_nuts class and then make a service route for it / connect it up to cmdstan?

Yeah, I think we should make this feature available to power users. To start to scope it, it would be really cool to have a downloadable cmdstan repo available so that we can collect some feedback on a prototype version.

I haven’t really thought about how we set this up, but you are right that a parallel_base_nuts class is a good idea, which would then be wired into cmdstan as an option to hmc. We need to think about how to do that, though, since this parallel approach is usable with all variants of hmc. Let’s start with the one where we do the adaptation.

I can bring the old version in line with the current develop, and then you jump in? Any help would be much appreciated given my lack of time.

Note that the version I wrote a while back uses the less sophisticated stopping criterion. It was not clear to me how to parallelise the most recent version (which looks more often at the “other side” of the tree), but others with a deeper understanding of NUTS said that we can get it to work with the parallel pattern as well. For now, I’d say we start with what we have. Plan?
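For reference, here is a sketch of the simplest form of the check: the endpoint criterion from the original NUTS paper (Hoffman and Gelman). Whether this is exactly the criterion used in the prototype is an assumption, and `no_u_turn` is a hypothetical name. What matters for parallelization is that this check needs only the two outermost states. In the speculative scheme, each worker would therefore only have to publish its current endpoint. As noted above, the newer criterion looks across the tree more often and may need more shared state.

```python
# Hypothetical helper names; the classic endpoint U-turn check, not Stan's code.


def dot(u, v):
    """Euclidean inner product of two equal-length sequences."""
    return sum(x * y for x, y in zip(u, v))


def no_u_turn(q_minus, p_minus, q_plus, p_plus):
    """Return True while the trajectory is still expanding at both ends.

    (q_minus, p_minus) is the backward-most state of the tree and
    (q_plus, p_plus) the forward-most one. Expansion stops once the momentum
    at either end points against the vector spanning the trajectory.
    """
    span = [a - b for a, b in zip(q_plus, q_minus)]
    return dot(span, p_minus) >= 0 and dot(span, p_plus) >= 0


print(no_u_turn([0.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0]))   # True
print(no_u_turn([0.0, 0.0], [1.0, 0.0], [1.0, 0.0], [-1.0, 0.0]))  # False
```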

Yes all sounds good to me!

Hi @stevebronder !

That was faster and easier than I thought! Here is the Stan branch which currently replaces base_nuts with the speculative NUTS algorithm:

This branch is based directly on the current develop. I have to say that I did only minimal testing, on the bernoulli example, where I get the same results. Things left to be done, and notes:

  • This is the “older” iteration of NUTS, which uses a less sophisticated stopping criterion. So don’t expect results to match current develop!
  • Ideally, this first gets applied to a problem where one knows the solution, before venturing into the wild.
  • The speculative NUTS works by running the forward and the backward sweep at the same time. This is relatively wasteful and leads to a maximal speedup of 35%, which, as I recall, you roughly reach at a treedepth of 7. There are some simulations of mine on this somewhere on the forum.

I would actually suggest making a prototype of this available right away. If it’s something people want, we then build a cleaner prototype in which parallel_nuts is split out and made into an option that cmdstan users can choose. This prototype can then get broader testing, and after that we decide whether to turn it into production code.

I hope others can really help make this a reality. I am more than happy to advise on things, but I am too short on time to do it myself in a reasonable timeframe.



Awesome I’ll take a look!

Had a first pass at this. I’m not compiling everything yet, just trying to sort out a few things. Instead of having a mutex around the uniform rng sampler, I just have one uniform rng sampler per thread. I also pulled this out into its own class (base_parallel_nuts) and am writing the service API route right now.

@maedoc once we have the cmdstan API layer working for this, could I ping you to try it out? (Also, if you can use a lot of vectors/matrices and do lots of vectorized ops in your model code, turning on --O1 in the latest Stan compiler may speed things up for you as well.) Do you know your average tree depth per transition? The scheme here mostly helps when the tree is deep-ish (depth > 6). Running multiple chains within one Stan program, using the cmdstan num_chains argument, may help as well, since it lets the gradient calculations across all the chains share data.

@wds15 once I get the below passing some unit tests I’ll put up a PR and then make the cmdstan API

Sure, I’d be up for trying it out. Thanks for the tip.

Yes, 8 to 10.


I remember that I struggled with the rng. I came to the conclusion that the mutex does not hurt performance, since we very rarely draw random numbers during sampling. Having one rng per thread is not ideal in terms of reproducibility. Maybe sampling the left/right turns in advance is a better choice here… but we can revisit that at a later stage. How about adding a note to the source saying that we need to think about RNG handling a bit more?
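Here is a minimal sketch of the "sample the left/right turns in advance" suggestion, in the same toy setting as the other examples. `draw_directions` is a hypothetical helper, and it covers only the direction draws. Any other random draws made while building the tree (for example, when selecting the next state) would still need their own reproducible handling.

```python
import random

# Hypothetical helper; covers only the per-doubling direction draws.


def draw_directions(seed, max_depth):
    """Pre-draw one direction per possible doubling from a single seeded stream.

    +1 means extend forward, -1 means extend backward. Because the draws happen
    on the main thread before any worker starts, the sequence cannot depend on
    how the worker threads are scheduled.
    """
    rng = random.Random(seed)
    return [1 if rng.random() > 0.5 else -1 for _ in range(max_depth)]


directions = draw_directions(seed=1234, max_depth=10)
print(directions)
```

Unused entries are simply discarded when the tree terminates early. The cost is max_depth uniform draws per transition, which is negligible next to the gradient evaluations.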

Really cool to see this moving!

If you want a model (+R code to generate data) that meets these criteria, see here (it’s fundamentally the SUG1.13 model, but highly optimized)
