NUTS variation

Hi there!

I was thinking about ways to speed up NUTS through parallelisation, and here is a possible strategy on which I would appreciate some feedback.

The possibilities for parallelisation are, of course, very limited by the nature of the algorithm. The obvious candidates are the forward and backward sweeps of NUTS: whenever a sweep in one direction is followed by a sweep in the other direction, the two can be run in parallel. However, the issue is that we increase the tree depth in every iteration, which severely limits the gain in parallel runtime… but is this really needed?

So my question is: Can we run the NUTS loop in the usual way, but increase the tree depth only every other loop iteration?

If that were possible, then we could sample the directions of the first and second sweeps at the beginning of the loop. In 50% of the cases the two sweeps go in opposite directions and can be run in parallel; in the other 50% we go twice in the same direction and can’t parallelise. The speedup is 2x in the first case and nothing (so 1x) in the second, such that the total speedup can be up to 1.5x.
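A quick way to sanity-check that arithmetic, as a sketch under simplifying assumptions: every leapfrog step costs the same, threading overhead is ignored, and the two directions are drawn independently, so they differ half the time. Note that it averages wall-clock time rather than speedup factors. Averaging the factors 2x and 1x gives the 1.5x above, while averaging the times gives 2 / 1.5 ≈ 1.33x for two equal-length sweeps. The function name is made up for illustration.

```python
from fractions import Fraction


def expected_pair_speedup(first: int, second: int) -> Fraction:
    """Expected speedup for two consecutive sweeps of `first` and `second`
    leapfrog steps that run concurrently only when their directions differ.

    Assumes unit cost per leapfrog step, no threading overhead, and
    independent, uniformly random directions (so they differ half the time).
    """
    sequential_time = first + second
    # Opposite directions: both sweeps run at once, so the longer one dominates.
    parallel_time = max(first, second)
    expected_time = Fraction(1, 2) * parallel_time + Fraction(1, 2) * sequential_time
    return Fraction(sequential_time) / expected_time


# Two equal sweeps (tree depth increased only every other iteration).
print(expected_pair_speedup(1, 1))  # 4/3
# A pair of consecutive depths in the doubling scheme (sizes 2^j and 2^(j+1)).
print(expected_pair_speedup(4, 8))  # 6/5
```

For consecutive depths of the doubling scheme the ratio does not depend on j, since both sweep lengths scale together.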

Is there anything obvious here that would make people raise red flags?

Ideally, we could then in the future use 3 cores to run 2 chains about (up to) 50% faster at the end of the day.



Naming convention point: we are long past the “NUTS” algorithm. There will be a renaming around Stan3 but for now just refer to it as “dynamic HMC”.

In any case, it’s better to think about the possible parallelization speedup this way: first speculatively integrate forwards and backwards in time, and then run the sampler to consume those speculative trajectories. If you fully consume one of the trajectories, then you can expand the speculation and continue.
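To make the speculative idea concrete, here is a toy sketch, not Stan’s implementation: both sweeps start from the same initial state, the backward one simply uses a negative step size, and each returns its buffer of states for the sampler to consume later. The standard normal target behind the hypothetical `grad_log_density` is an assumption for illustration, and with Python threads on such tiny arrays you should not expect a real speedup; the point is only the structure.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def grad_log_density(q: np.ndarray) -> np.ndarray:
    # Hypothetical toy target: standard normal, so grad log p(q) = -q.
    return -q


def leapfrog_sweep(q, p, step_size, n_steps):
    """Run n_steps leapfrog steps and return every visited (q, p) state.
    A negative step_size integrates backwards in time."""
    q, p = np.array(q, dtype=float), np.array(p, dtype=float)
    states = []
    for _ in range(n_steps):
        p = p + 0.5 * step_size * grad_log_density(q)
        q = q + step_size * p
        p = p + 0.5 * step_size * grad_log_density(q)
        states.append((q.copy(), p.copy()))
    return states


def speculative_sweeps(q0, p0, step_size, n_steps):
    """Integrate forwards and backwards from (q0, p0) concurrently.
    Anything the sampler later does not consume is wasted work."""
    with ThreadPoolExecutor(max_workers=2) as pool:
        forward = pool.submit(leapfrog_sweep, q0, p0, step_size, n_steps)
        backward = pool.submit(leapfrog_sweep, q0, p0, -step_size, n_steps)
        return forward.result(), backward.result()


fwd, bwd = speculative_sweeps(np.array([1.0, -0.5]), np.array([0.3, 0.2]), 0.1, 8)
print(len(fwd), len(bwd))  # 8 8
```

Because the leapfrog integrator is time-reversible, integrating the last forward state backwards for the same number of steps returns to the start, up to floating-point error.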

Because the expansion is multiplicative, however, you’re likely to end up wasting a bunch of that speculative computation on both sides, yielding a much smaller speedup than 1.5 on average. Moreover, you’ll have to keep all of those speculative states in memory, which will be a significant burden for higher-dimensional problems. One of the big advantages of multiplicative expansion is that only a logarithmic number of states (in the trajectory length) needs to be held in memory at any given time.
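Some rough arithmetic for the memory point. This is a back-of-the-envelope sketch assuming, hypothetically, that every buffered state keeps its position, momentum and gradient as double-precision vectors of the model dimension; a real implementation may store more or less per state.

```python
def trajectory_buffer_bytes(dim: int, depth: int,
                            vectors_per_state: int = 3,
                            bytes_per_double: int = 8) -> int:
    """Bytes needed to keep all 2**depth states of a depth-`depth` trajectory,
    each holding `vectors_per_state` double vectors of length `dim`."""
    n_states = 2 ** depth
    return n_states * vectors_per_state * dim * bytes_per_double


# 100,000 parameters at tree depth 10: 1024 states, about 2.5 GB per chain.
print(trajectory_buffer_bytes(100_000, 10))  # 2457600000
```

By contrast, the memory of the multiplicative scheme grows with the depth (10 here) rather than with the 1024 states of the trajectory.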

Parallelizable resources are much better spent on speeding up the gradient evaluation or on running multiple chains in memory to pool adaptation information.


Just to be clear on the change I am suggesting: I would like to deviate a little from the usual multiplicative expansion. Instead of increasing the tree depth by one at each iteration (which is what is done now), I am suggesting to increase it only at every other iteration. Would that violate detailed balance in an obvious way?

I know that the ideal average speedup is 1.5x and that I have to expect less, but this appears to me an easy thing to try out.

Pooling adaptation info also sounds very attractive, I agree; parallelisation of the gradient evaluation is sort of already there, and we are improving it.


EDIT: “iteration” above refers to the loop iterations that are done during one dynamic HMC transition.

Each iteration in the current version of dynamic HMC is defined by a tree depth increase. Any additional states added that don’t come from a tree depth increase significantly complicate the termination checks.

Any change to the sampler requires significant overhead, and a change requiring threading through threading functionality (sorry, couldn’t help myself) is all the more onerous. There’s not much to squeeze out here, so I would be very hesitant to move in that direction in the immediate future.

Then again the code is open for anyone to experiment with and report results!

Any schedule should be OK. All we need is reversibility so that there’s the same chance of building the tree going the other way.

We could generate a longer sequence of direction decisions than just two up front. We can then start going forward and backward asynchronously, evaluating the U-turn conditions as we hit the right spots, and terminate the builds.
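One way to get a feel for the best case of this asynchronous scheme: draw the whole direction sequence up front, give each end of the trajectory its own worker, and pretend that no U-turn ever terminates the build early, so no speculative work is wasted. Under these optimistic, hypothetical assumptions and unit cost per leapfrog step, the wall-clock time is set by whichever end has more steps to do. The sketch enumerates all direction sequences exactly; the function name is made up.

```python
from fractions import Fraction
from itertools import product


def two_ended_speedup(depth: int) -> Fraction:
    """Expected speedup of building a depth-`depth` trajectory with one worker
    per end, assuming all directions are drawn up front, no early termination,
    unit cost per leapfrog step and no synchronisation overhead."""
    sizes = [2 ** j for j in range(depth)]  # subtree sizes 1, 2, 4, ...
    sequential_time = sum(sizes)  # 2**depth - 1 leapfrog steps
    total_wall_time = 0
    for directions in product((True, False), repeat=depth):
        forward_steps = sum(s for s, fwd in zip(sizes, directions) if fwd)
        backward_steps = sequential_time - forward_steps
        # Both ends run concurrently, so the busier end sets the wall time.
        total_wall_time += max(forward_steps, backward_steps)
    expected_wall_time = Fraction(total_wall_time, 2 ** depth)
    return Fraction(sequential_time) / expected_wall_time


for d in (2, 3, 10):
    print(d, two_ended_speedup(d))
```

Because the subset sums of 1, 2, ..., 2^(d-1) hit every integer from 0 to 2^d - 1 exactly once, this works out to 4(2^d - 1)/(3·2^d - 2), which stays below 4/3 for every depth: the largest subtree dominates the wall time. The model ignores early U-turns, the checks themselves and synchronisation, so it is only a rough guide.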

The real question is whether it gives a better speedup than just running twice as many chains. Are you trying to speed up the end-to-end wall-clock time for a single chain, or the time to a given ESS?

My goal is to get the same number of effective samples in less wall clock time at the expense of using more cores.

So the usual 4-chain run on 4 cores would instead be run on 6 cores, and you would get the result 1.5x faster.

…btw, why do we have to randomly choose the direction every time we grow deeper? Why not just randomly choose the first direction when the transition starts and then always alternate directions? Then the speedup could reach 2x!

@betanalpha can you point me to some scripts which check detailed balance or any other checks needed to gain trust?

Because then we can’t dynamically determine when to stop. You can either determine up front how far to go in each direction for a fixed total integration time, or you can dynamically expand one tree depth at a time, randomly varying the direction at each step. In order to preserve the target distribution with dynamic checks, you have to have trajectories integrating in random directions.

Ultimately we have a non-obvious choice. We can try to expand for fixed integration times using parallelization in each direction, hoping that the speculative computation isn’t mostly wasted. Or we can proceed as is and focus parallelization resources on gradient parallelization or chain parallelization. I strongly prefer the latter.

For verification, the first, very weak, level is to run the models and ensure that the mean +/- se is close to the true expectation value for each parameter and generated quantity. You’ll want to run longer chains, at least 100,000 iterations, to be confident.
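A minimal sketch of that first-level check for a single quantity, assuming you already have draws and a known true expectation. The Monte Carlo standard error here comes from simple non-overlapping batch means, a stand-in rather than the ESS-based estimator Stan reports; the function names and the |z| < 4 threshold are arbitrary choices for illustration. With many parameters and generated quantities, an occasional |z| around 3 is expected by chance.

```python
import numpy as np


def batch_means_mcse(draws, n_batches: int = 100) -> float:
    """Monte Carlo standard error of the mean from non-overlapping batch means."""
    draws = np.asarray(draws, dtype=float)
    batch_size = len(draws) // n_batches
    if batch_size < 2:
        raise ValueError("need at least 2 draws per batch")
    trimmed = draws[: batch_size * n_batches]
    batch_means = trimmed.reshape(n_batches, batch_size).mean(axis=1)
    return float(batch_means.std(ddof=1) / np.sqrt(n_batches))


def check_expectation(draws, true_value: float, z_threshold: float = 4.0):
    """Return (z, passed) where z = (estimate - truth) / MCSE."""
    z = (float(np.mean(draws)) - true_value) / batch_means_mcse(draws)
    return z, abs(z) < z_threshold


# Stand-in for sampler output: independent N(0, 1) draws, true mean 0.
rng = np.random.default_rng(1)
draws = rng.normal(size=100_000)
print(check_expectation(draws, 0.0))
```

In practice you would feed in the draws of each parameter and generated quantity from the modified sampler; a biased sampler shows up as a large |z|.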

The trick is thinking about the reversibility of the algorithm: can you get from A to B, and back from B to A, making the same random tree choices?

If you pack multiple decisions together, they have to resolve to something reversible.
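One necessary piece of that reversibility can be checked by brute force: for a fixed schedule of subtree sizes, enumerate the direction sequences and record where the initial state lands inside the final trajectory. For the simple multinomial selection of a state to work, every state in a trajectory should be equally likely to have generated that same trajectory, i.e. the landing position should be uniform. This sketch ignores U-turn termination and the sub-tree checks entirely, the every-other-iteration schedule is assumed here to mean subtree sizes 1, 1, 2, 2, and the helper names are made up.

```python
from collections import Counter
from fractions import Fraction
from itertools import product


def start_position_distribution(sizes, sequences):
    """Distribution of the initial state's offset from the left end of the final
    trajectory, assuming each direction sequence is equally likely.
    True means extend forwards (right), False means backwards (left)."""
    counts = Counter()
    for directions in sequences:
        # Offset from the left end = total number of steps taken backwards.
        offset = sum(size for size, fwd in zip(sizes, directions) if not fwd)
        counts[offset] += 1
    n = len(sequences)
    return {offset: Fraction(c, n) for offset, c in sorted(counts.items())}


def all_direction_sequences(n):
    # Independent uniform direction at every step, as in the current scheme.
    return list(product((True, False), repeat=n))


def alternating_sequences(n):
    # Random first direction, then strictly alternating.
    return [tuple((i % 2 == 0) == first for i in range(n)) for first in (True, False)]


def is_uniform(sizes, dist):
    n_states = sum(sizes) + 1  # the initial state plus every added state
    return len(dist) == n_states and all(p == Fraction(1, n_states) for p in dist.values())


doubling = [1, 2, 4]
every_other = [1, 1, 2, 2]
print(is_uniform(doubling, start_position_distribution(doubling, all_direction_sequences(3))))  # True
print(start_position_distribution(every_other, all_direction_sequences(4)))
print(start_position_distribution(doubling, alternating_sequences(3)))
```

Doubling with random directions passes. With sizes 1, 1, 2, 2 the middle state is four times as likely as either end to have produced the trajectory, and with strictly alternating directions only 2 of the 8 states can produce it at all.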

Now you can generate a bunch of random directions to start with and keep building, but you may do work that gets thrown away if the other side U-turns prematurely, or you may have to wait for it asynchronously.

Thanks. I did sit down with @bbbales2 during StanCon and he showed me the logic of how to check this. Apparently it does not work out the way I wished.

Thus, we must stick to the simpler approach of parallelising the current scheme. If we do that in a simple way, by looking at 2 depths at a time, then we can get at most a 25% speedup. If one samples the entire sequence of turns in advance, then there should be potential for more speedup, as the pairs which can be parallelised overlap… but getting the checks right in this logic is more involved; let’s see if I find a nice way to code it.