I was thinking about ways to speed up NUTS through parallelisation, and here is a possible strategy on which I would appreciate some feedback.
The parallelisation possibilities are, of course, very limited by the nature of the algorithm. The obvious candidates for parallelisation are the forward and backward sweeps of NUTS: whenever a forward and a backward sweep happen one after the other, they can be run in parallel. However, the issue is that we increase the tree depth in every iteration, which severely limits how much work can actually run in parallel… but is this really needed?
So my question is: Can we run the NUTS loop in the usual way, but increase the tree depth only every other loop iteration?
If that were possible, then we could sample the directions of the first and the second sweep at the beginning of the loop. In 50% of the cases the two directions differ and the sweeps can be run in parallel; in the other 50% we go twice in the same direction and can’t parallelise. The speedup is 2x in the first case and nothing (so 1x) in the second case, such that the total speedup can be up to 1.5x.
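To make the arithmetic concrete, here is a rough sketch of the bookkeeping. It is not a NUTS implementation. It assumes that both sweeps in a pair cost the same, that directions are drawn independently with probability 0.5 each, and that running two sweeps in parallel has no overhead. The function names are made up for illustration. Under these assumptions, averaging the two speedup factors gives 1.5x, while comparing total run times gives 2 / 1.5, roughly 1.33x, so 1.5x is best read as an upper bound.

```python
import random


def expected_speedup(p_opposite=0.5, cost=1.0):
    """Ratio of serial to parallel run time for one pair of sweeps.

    Assumption: each sweep costs `cost`, and a pair going in opposite
    directions runs fully in parallel with zero overhead.
    """
    serial_time = 2 * cost
    # Opposite directions: both sweeps overlap, so the pair takes `cost`.
    # Same direction: the sweeps must run one after the other.
    parallel_time = p_opposite * cost + (1 - p_opposite) * 2 * cost
    return serial_time / parallel_time


def simulate_speedup(n_pairs=100_000, seed=0):
    """Monte Carlo check of the same quantity with sampled directions."""
    rng = random.Random(seed)
    serial_time = 0.0
    parallel_time = 0.0
    for _ in range(n_pairs):
        first = rng.choice((-1, 1))   # direction of the first sweep
        second = rng.choice((-1, 1))  # direction of the second sweep
        serial_time += 2.0
        parallel_time += 1.0 if first != second else 2.0
    return serial_time / parallel_time


if __name__ == "__main__":
    print(f"analytic : {expected_speedup():.3f}")
    print(f"simulated: {simulate_speedup():.3f}")
```

If the two sweeps in a pair do not cost the same, or if threading overhead is noticeable, the achievable speedup would be lower than this.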
Is there anything obvious here that raises red flags?
Ideally, we could then in the future use 3 cores to run 2 chains up to about 50% faster.