NUTS variation

Hi there!

I was thinking about ways to speed up NUTS through parallelisation, and here is a possible strategy on which I would appreciate some feedback.

The parallelisation possibilities are very limited by the nature of the algorithm, of course. The obvious things to parallelise are the forward and backward sweeps of NUTS: whenever a forward and a backward sweep happen in sequence, they can be run in parallel. However, the issue is that we increase the tree depth in every iteration, which severely limits the parallel runtime… but is this really needed?

So my question is: Can we run the NUTS loop in the usual way, but increase the tree depth only every other loop iteration?

If that were possible, then we could sample the directions of the first and the second sweep at the beginning of the loop. In 50% of the cases the two sweeps go in opposite directions and can be run in parallel; in the other 50% we go twice in the same direction and can’t parallelise. The speedup is 2x in the first case and nothing (so 1x) in the second, so the total speedup can be up to 1.5x.
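The arithmetic above can be checked with a small sketch. It assumes that both sweeps cost the same and that parallel execution has no overhead; the function name and parameters are made up for illustration. It reports two different averages: the mean of the per-case speedups, and the speedup implied by the mean wall-clock time.

```python
def naive_speedups(p_parallel=0.5, parallel_factor=2.0):
    """Return (mean of per-case speedups, speedup of mean wall-clock time).

    Assumptions: two equally expensive sweeps, a fraction `p_parallel` of
    sweep pairs can run concurrently, and concurrency has zero overhead.
    """
    # Average the speedup factors themselves: 0.5 * 2 + 0.5 * 1.
    mean_of_speedups = p_parallel * parallel_factor + (1.0 - p_parallel) * 1.0
    # Average the relative run times instead: 0.5 * (1/2) + 0.5 * 1.
    expected_time = p_parallel / parallel_factor + (1.0 - p_parallel) * 1.0
    return mean_of_speedups, 1.0 / expected_time


mean_speedup, throughput_speedup = naive_speedups()
print(mean_speedup, throughput_speedup)
```

Under these assumptions, averaging the speedup factors gives the 1.5x figure, while averaging the run times gives 4/3, so the wall-clock gain over many transitions would be closer to 1.33x even in this idealised setting.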

Is there anything obvious which makes people raise red flags here?

Ideally we could then, in the future, use 3 cores to calculate 2 chains (up to) about 50% faster.



Naming convention point: we are long past the “NUTS” algorithm. There will be a renaming around Stan3 but for now just refer to it as “dynamic HMC”.

In any case, it’s better to think about the possible parallelization speedup this way: first speculatively integrate forwards and backwards in time, and then run the sampler to consume those speculative trajectories. If you fully consume one of the trajectories, then you can expand the speculation and continue.

Because the expansion is multiplicative, however, you’re likely to end up wasting a bunch of that speculative computation on both sides, yielding a much smaller speedup than 1.5 on average. Moreover, you’ll have to keep all of those speculative states in memory, which will be a significant burden for higher-dimensional problems. One of the big advantages of multiplicative expansion is that only a logarithmic number of states is needed in memory at any given time.
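A rough toy model can put numbers on both points. This is only a sketch under stated assumptions, not how the Stan sampler is implemented: doubling j adds 2**j leapfrog steps in a uniformly random direction, every step costs the same, the transition always runs to the full depth, and the speculation is perfect (no wasted work). The function names and the one-state-per-level constant are hypothetical.

```python
from itertools import product


def ideal_two_worker_speedup(depth):
    """Best-case speedup from integrating forwards and backwards on two workers.

    Toy model (assumptions): doubling j adds 2**j leapfrog steps in a
    uniformly random direction, every step costs the same, the transition
    always runs to `depth` doublings, and no speculative work is wasted.
    """
    serial = 0    # total steps over all direction sequences, one worker
    parallel = 0  # total wall-clock steps, one worker per direction
    for directions in product((False, True), repeat=depth):
        forward = sum(2**j for j, fwd in enumerate(directions) if fwd)
        backward = sum(2**j for j, fwd in enumerate(directions) if not fwd)
        serial += forward + backward
        parallel += max(forward, backward)
    return serial / parallel


def states_in_memory(depth):
    """States held for a trajectory of 2**depth states.

    Returns (speculative_buffer, recursive_builder). The second value is only
    an order of magnitude: one state per tree level is an assumed constant.
    """
    return 2**depth, depth + 1


for depth in (2, 6, 10):
    print(depth, round(ideal_two_worker_speedup(depth), 3), states_in_memory(depth))
```

In this model the final doubling always lands on one side and accounts for about half of all steps, so the ratio stays below 4/3 even with perfect speculation. Any wasted speculative work would lower it further, while the buffered trajectory grows exponentially in the depth rather than linearly.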

Parallelizable resources are much better spent on speeding up the gradient evaluation or running multiple chains in memory to pool adaptation information.


Just to be clear on what I am suggesting to change… I would like to deviate a little from the usual multiplicative expansion. Instead of increasing the tree depth by one at each iteration (which is what is done now), I am suggesting to increase the tree depth only at every other iteration. Would that violate detailed balance in an obvious way?

I know that the ideal average speedup is 1.5x and that I have to expect less, but this appears to me to be an easy thing to try out.

Pooling adaptation info also sounds very attractive, I agree; gradient evaluation parallelisation is sort of already there and we are improving it.


EDIT: “iteration” above refers to the loop iterations performed during one dynamic HMC transition.

Each iteration in the current version of dynamic HMC is defined by a tree depth increase. Any additional states added that don’t come from a tree depth increase significantly complicate the termination checks.

Any change to the sampler requires significant overhead, and a change requiring threading through threading functionality (sorry, couldn’t help myself) is all the more onerous. There’s not much to squeeze out here so I would be very hesitant to move in that direction in the immediate future.

Then again the code is open for anyone to experiment with and report results!