Hi there!
I have been thinking about parallelizing our dynamic HMC sampler. The obvious strategy for doing that is to
- pre-sample all the fwd and bck turns up to the maximal treedepth
- then expand the fwd and bck trajectories in full independence
- all checks still happen in the same order as before
This scheme works best if we have exactly alternating turns, of course. In practice things are not that ideal, so after discussing with @betanalpha we thought it would be good to first assess the merits of this approach with a simulation. The simulation code I wrote is attached if you want the details. The average speedup (serial walltime over parallel walltime) by treedepth is then:
treedepth  mean  low_5%  median  high_95%
        1  1.00    1.00    1.00      1.00
        2  1.25    1.00    1.50      1.50
        3  1.33    1.00    1.40      1.75
        4  1.36    1.00    1.36      1.88
        5  1.37    1.00    1.35      1.94
        6  1.38    1.02    1.34      1.91
        7  1.38    1.02    1.32      1.90
        8  1.38    1.02    1.34      1.90
        9  1.39    1.02    1.33      1.91
       10  1.39    1.03    1.33      1.91
So on average we get, for a treedepth > 3, roughly a 35% speedup. That's a bit below what I was hoping for, honestly, but this maximal speedup should apply to just about any non-trivial Stan model, and the user would not need to change a single line of Stan code. In terms of efficiency this is not too great, as the resources used are doubled to bring the walltime down from 1 unit to 1/1.35 = 0.74 units.
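For what it's worth, here is a tiny standalone sketch of the cost model behind these numbers as I think about it (this is just my own C++ illustration with made-up names, not the attached R script): doubling j costs 2^j leapfrog steps, the serial sampler pays for every doubling in sequence, and the parallel scheme's walltime is the larger of the fwd and bck workloads, so the speedup is capped at 2.

```cpp
// Hypothetical cost model: serial cost = all leapfrog steps, parallel
// walltime = the slower of the two concurrently expanded trajectories.
#include <algorithm>
#include <iostream>
#include <random>

int main() {
  std::mt19937 rng(1234);
  std::bernoulli_distribution go_fwd(0.5);  // direction of each doubling
  const int num_sims = 100000;
  for (int depth = 1; depth <= 10; ++depth) {
    double mean_speedup = 0.0;
    for (int s = 0; s < num_sims; ++s) {
      double fwd = 0.0, bck = 0.0;
      for (int j = 0; j < depth; ++j) {
        const double subtree = 1 << j;  // doubling j costs 2^j leapfrog steps
        (go_fwd(rng) ? fwd : bck) += subtree;
      }
      const double serial = fwd + bck;             // one worker does everything
      const double parallel = std::max(fwd, bck);  // two workers run concurrently
      mean_speedup += serial / parallel;
    }
    std::cout << "treedepth " << depth << ": mean speedup "
              << mean_speedup / num_sims << "\n";
  }
  return 0;
}
```

Treedepth 1 always gives a speedup of exactly 1 since there is only a single doubling, and the bound of 2 corresponds to perfectly alternating turns where both workers stay busy the whole time.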
The prototype I have coded up already shows promising results. Using the Poisson hierarchical model from the TBB benchmarks (here I use 500 groups and 10 observations per group with a Poisson log-likelihood), I am getting a speedup of
## runtime in seconds
mean(c(83, 80)) / mean(c(66, 68)) = 1.22
The treedepth of the sampler is 6, so of the theoretical 38% speedup we get 22% in practice. Right now the prototype lacks exact reproducibility, since the random number generator is accessed in a non-deterministic order (this can be fixed with some caching), but overall the results are very close (though it's odd to see slightly different stepsizes, which makes the comparison not ideal):
## serial summary output
Warmup took (48) seconds, 48 seconds total
Sampling took (35) seconds, 35 seconds total
                    Mean     MCSE   StdDev       5%      50%      95%   N_Eff  N_Eff/s    R_hat
lp__            -1.2e+04  3.7e+00  2.3e+01 -1.2e+04 -1.2e+04 -1.2e+04      39      1.1  1.0e+00
accept_stat__    9.3e-01  2.5e-03  8.4e-02  7.4e-01  9.6e-01  1.0e+00    1131       32  1.0e+00
stepsize__       5.0e-02  1.7e-16  1.7e-16  5.0e-02  5.0e-02  5.0e-02     1.0    0.029  1.0e+00
treedepth__      6.0e+00  2.0e-16  4.4e-15  6.0e+00  6.0e+00  6.0e+00     500       14  1.0e+00
n_leapfrog__     6.3e+01  6.4e-02  2.0e+00  6.3e+01  6.3e+01  6.3e+01    1004       29  1.0e+00
divergent__      0.0e+00  0.0e+00  0.0e+00  0.0e+00  0.0e+00  0.0e+00     500       14      nan
energy__         1.2e+04  3.6e+00  2.8e+01  1.2e+04  1.2e+04  1.2e+04      58      1.7  1.0e+00
log_lambda       1.6e+00  7.7e-03  4.4e-02  1.5e+00  1.6e+00  1.7e+00      32     0.93  1.0e+00
tau              1.2e+00  7.8e-03  4.0e-02  1.1e+00  1.2e+00  1.2e+00      26     0.75  1.0e+00
eta[1]          -1.4e+00  1.1e-02  2.6e-01 -1.8e+00 -1.3e+00 -9.3e-01     606       17  1.0e+00
eta[2]           6.1e-01  7.5e-03  1.0e-01  4.4e-01  6.1e-01  7.8e-01     177      5.0  1.0e+00
## parallel summary output (note that cmdstan reports total CPU time, not walltime)
Warmup took (72) seconds, 1.2 minutes total
Sampling took (56) seconds, 56 seconds total
                    Mean     MCSE   StdDev       5%      50%      95%   N_Eff  N_Eff/s    R_hat
lp__            -1.2e+04  3.4e+00  2.3e+01 -1.2e+04 -1.2e+04 -1.2e+04      44     0.78  1.0e+00
accept_stat__    9.3e-01  2.3e-03  8.2e-02  7.6e-01  9.7e-01  1.0e+00    1234       22  1.0e+00
stepsize__       6.7e-02  1.9e-16  1.9e-16  6.7e-02  6.7e-02  6.7e-02     1.0    0.018  1.0e+00
treedepth__      5.0e+00  8.3e-16  1.9e-14  5.0e+00  5.0e+00  5.0e+00     500      8.9  1.0e+00
n_leapfrog__     6.3e+01  1.1e-14  2.5e-13  6.3e+01  6.3e+01  6.3e+01     500      8.9  1.0e+00
divergent__      0.0e+00  0.0e+00  0.0e+00  0.0e+00  0.0e+00  0.0e+00     500      8.9      nan
energy__         1.2e+04  3.6e+00  2.8e+01  1.2e+04  1.2e+04  1.2e+04      61      1.1  1.0e+00
log_lambda       1.6e+00  5.8e-03  4.2e-02  1.6e+00  1.6e+00  1.7e+00      54     0.96  1.0e+00
tau              1.2e+00  6.7e-03  3.7e-02  1.1e+00  1.2e+00  1.2e+00      31     0.55  1.0e+00
eta[1]          -1.4e+00  8.9e-03  2.5e-01 -1.8e+00 -1.3e+00 -9.8e-01     760       14  1.0e+00
eta[2]           6.1e-01  6.2e-03  9.2e-02  4.6e-01  6.1e-01  7.6e-01     219      3.9  1.0e+00
The prototype is currently based on the released 2.20 dynamic HMC (so without the newest changes to the sampler), but these results are promising to me. The complexity of the parallelization is manageable given we use the TBB, which lets me nicely abstract away the parallelization through a dependency flow graph, see here. Thus, with some (hopefully) light refactoring of the existing code base we should be able to add this feature in a way which minimizes maintenance… ahh… and this stuff runs without the need to refactor the existing AD; it only requires the uptake of the Intel TBB so that we can use a dependency flow graph which the TBB parallelizes automatically.
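To give an idea of what such a dependency flow graph looks like, here is a toy sketch (placeholder lambdas and node names, not the actual prototype code): the fwd and bck expansions are independent nodes which the TBB is free to run on different threads, and a final node fires only once both are done, which is where the usual checks would be replayed in their original order.

```cpp
// Much-simplified sketch of the flow-graph idea; the lambdas are placeholders
// standing in for the real trajectory expansion and check logic.
#include <tbb/flow_graph.h>
#include <iostream>

int main() {
  namespace flow = tbb::flow;
  flow::graph g;

  // Fan-out node: one message per transition kicks off both expansions.
  flow::broadcast_node<flow::continue_msg> start(g);

  // Speculatively expand the forward trajectory (placeholder work).
  flow::continue_node<flow::continue_msg> expand_fwd(
      g, [](const flow::continue_msg&) {
        std::cout << "expanding fwd trajectory\n";
        return flow::continue_msg();
      });

  // Speculatively expand the backward trajectory (placeholder work).
  flow::continue_node<flow::continue_msg> expand_bck(
      g, [](const flow::continue_msg&) {
        std::cout << "expanding bck trajectory\n";
        return flow::continue_msg();
      });

  // Fires only after both predecessors finished; this is where the usual
  // U-turn/divergence checks would run in their original serial order.
  flow::continue_node<flow::continue_msg> combine(
      g, [](const flow::continue_msg&) {
        std::cout << "combining + checks in serial order\n";
        return flow::continue_msg();
      });

  flow::make_edge(start, expand_fwd);
  flow::make_edge(start, expand_bck);
  flow::make_edge(expand_fwd, combine);
  flow::make_edge(expand_bck, combine);

  start.try_put(flow::continue_msg());
  g.wait_for_all();
  return 0;
}
```

The nice part is that the graph only encodes the dependencies; the TBB scheduler decides at runtime whether the two expansion nodes actually run concurrently.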
Below is also a normalized density histogram of the speedup distribution at each treedepth.
Best,
Sebastian
parallel-dynamic_hmc-maxspeedup.R (1.5 KB)