Cross-chain warmup adaptation using MPI

Thanks for joining the testing.

This was expected behavior at this stage. I now realize we should have warned about multimodal posteriors.

The plan is to diagnose multimodality, then advise the user to disable cross-chain communication and warn about the difficulties of sampling multimodal distributions. I added `maxrhat = max(mon[,'Rhat'])` to the script. You could also add Rhats to your report.

Thanks. This will be useful when we start testing the multimodality diagnostic.

EDIT: fixed the script and removed discussion of the problem here.

Thanks for taking this slowly! I’m very excited about making adaptation faster, but we can’t break things when we do it.

Is it possible to correct the signs of parameters in the model to make it identifiable? That’s a much simpler approach and tends to be better behaved computationally because it doesn’t run the risk of a single chain crossing modes (which can’t be easily sign corrected after the fact).

@andrjohns the stepsizes are getting really small there. I wonder why that is. Thanks for sharing the model.

Right now the MPI warmup is terminating warmup based on the Rhat/ESS of lp__.
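
A minimal sketch of what that kind of stopping rule looks like, assuming a plain split-Rhat and a crude pooled-ESS estimate; the function names, thresholds, and estimators here are illustrative and not the branch's actual code:

```python
import numpy as np

def split_rhat(x):
    """Plain split-Rhat for draws of shape (iterations, chains)."""
    h = x.shape[0] // 2
    sub = np.concatenate([x[:h], x[h:2 * h]], axis=1)  # split each chain in half
    w = sub.var(axis=0, ddof=1).mean()                 # within-chain variance
    b = h * sub.mean(axis=0).var(ddof=1)               # between-chain variance
    return np.sqrt(((h - 1) / h * w + b / h) / w)

def crude_ess(x):
    """Rough pooled ESS from the initial positive autocorrelations."""
    z = (x - x.mean()) / x.std()
    total = 0.0
    for lag in range(1, x.shape[0] // 2):
        rho = np.mean(z[:-lag] * z[lag:])
        if rho < 0:
            break
        total += rho
    return x.size / (1 + 2 * total)

def warmup_should_stop(lp_draws, target_ess=200, rhat_tol=1.05):
    """Checked after each adaptation window: stop warmup once the
    pooled lp__ draws mix across chains and hit the ESS target."""
    return split_rhat(lp_draws) < rhat_tol and crude_ess(lp_draws) >= target_ess
```

The real implementation uses the rank-normalized diagnostics; this is just the shape of the decision.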

But yeah, what Aki said. Our hope is to print a warning (much like the divergence/treedepth ones) so that if this adaptation fails in a detectable way, it's easy to switch back to a fixed-length warmup.

Yup. And so we test!

Yes very much so.

By hurdle rate I meant (2), i.e. running 100 adaptation iterations, turning off adaptation and sampling for some N, calculating ESS, then turning adaptation back on for another 100, and looking at how ESS changes with adaptation size. That feels like the right metric. Or you can look at Rhat or whatever you like with the same scheme.
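
That alternating scheme can be sketched with a toy harness. `ToySampler` below is a hypothetical stand-in (an AR(1) process whose autocorrelation shrinks as it "adapts"), just to show the shape of the measurement loop, not any real sampler:

```python
import numpy as np

class ToySampler:
    """Stand-in for a sampler: an AR(1) process whose draw-to-draw
    correlation shrinks as the adaptation budget grows."""
    def __init__(self, seed=0):
        self.rng = np.random.default_rng(seed)
        self.rho = 0.95

    def adapt(self, n_iter):
        # Pretend each 100 adaptation iterations halve the autocorrelation.
        self.rho *= 0.5 ** (n_iter / 100)

    def sample(self, n_iter):
        x = np.empty(n_iter)
        x[0] = self.rng.normal()
        for i in range(1, n_iter):
            x[i] = self.rho * x[i - 1] + np.sqrt(1 - self.rho**2) * self.rng.normal()
        return x

def ess_ar1(x):
    """ESS proxy from the lag-1 autocorrelation (exact for AR(1))."""
    z = (x - x.mean()) / x.std()
    rho1 = float(np.mean(z[:-1] * z[1:]))
    return len(x) * (1 - rho1) / (1 + rho1)

def hurdle_curve(sampler, windows=4, adapt_len=100, sample_len=500):
    """Alternate adaptation windows with frozen-adaptation sampling runs,
    recording ESS as a function of total adaptation spent."""
    curve = []
    for w in range(1, windows + 1):
        sampler.adapt(adapt_len)                      # 100 more adaptation iters
        draws = sampler.sample(sample_len)            # sample with adaptation off
        curve.append((w * adapt_len, ess_ar1(draws)))
    return curve
```

The resulting (adaptation budget, ESS) curve is the "hurdle rate" plot: where it flattens out is where further adaptation stops paying.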

Those are reasonable. My intent with suggesting the hurdle rate is to get an aggregate of those measures, since it sounds like the goal here is to need fewer adaptation steps overall.

Summary (based on @avehtari’s script) for some models from posteriordb, all run with target ESS=200 and the same RNG seed.

Could you add digits=2 or something for easier reading? I should add that to my script, too, but I’ve been a bit busy with other things.

Skimming over this, these numbers look very promising.

@yizhang How do you define “classic”? How many chains are being run?

Updated with plots.

Changed the wording to “regular”: 4 adaptive NUTS chains.

The plots for the arK-arK model raise some flags on ESS (it could be that the min op picked a different parameter, but anyway), so I increased the target ESS:

To see the benefit of this parallel setup I took it for a ride. Still with the arK-arK model, I ran MPI with num_chains=4, 16, 32, as well as a regular 4-chain NUTS run. In the MPI runs I’m asking a lot from cross-chain warmup by setting target ESS=1000. The figure below shows that cross-chain runs achieve similar or better sampling performance with far fewer warmup iterations. With 32 chains we need only 250 warmup iterations, even with this particularly large target ESS (with a lower target ESS=500, 32 chains need only 100 warmup iterations per chain).
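
The scaling makes back-of-envelope sense: the ESS target is pooled across chains, so the warmup length needed to certify the target shrinks roughly linearly in the number of chains. The per-chain efficiency figure below (0.4 effective draws per iteration per chain) is an illustrative assumption, not a measured value:

```python
def warmup_iters_to_target(target_ess, num_chains, ess_per_iter_per_chain=0.4):
    """Back-of-envelope: iterations until the ESS pooled over all
    chains reaches the target, assuming equal per-chain efficiency."""
    return target_ess / (num_chains * ess_per_iter_per_chain)

# More chains reach the same pooled target in fewer iterations each.
for m in (4, 16, 32):
    print(m, warmup_iters_to_target(1000, m))
```

The real runs won't follow this linearly (adaptation quality itself changes with warmup length), but it explains why the required warmup drops so sharply with chain count.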

In terms of absolute (wall) time savings (in seconds), parallel warmup works best for models with expensive iterations. This is demonstrated on the two models below. Model Radon has a large number of parameters, and model SIR involves ODE integration. All results are based on 4-chain runs in parallel (for the regular run that’s just embarrassingly parallel) with the same RNG seed. The cross-chain runs use the default target ESS=200.

Since all cross-chain runs spent ≤ 400 iterations on warmup, significant time improvement is natural, with one exception: the cross-chain run is actually slower in the dense_e metric run of SIR. To look closer, I printed the time each chain spent on cross-chain adaptation (Rhat, ESS, and covariance calculation, as well as data communication):

Iter   chain 1   chain 2   chain 3   chain 4
100    2.4       0.0       0.9       5.8
200    0.0       5.3       4.2       5.0
300    0.0       0.2       0.2       0.2

Recall that cross-chain adaptation happens (by default) every 100 iterations, so the table above shows that at every adaptation step some chains have to wait for the others to finish, whenever the slowest chain’s ODE gets a parameter sample that stresses the numerical integrator. This is a common cause of blocking in parallel ODE runs. A solution would be to make the runs asymmetric: chains that finish keep running until all chains are ready for adaptation, but this would give the chains different numbers of warmup iterations and break our interface.
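
The idle time implied by the table is just each chain's gap to the slowest chain in its window; a quick sketch (units as in the table):

```python
def window_idle_time(chain_times):
    """Total time chains sit idle in one adaptation window, waiting
    for the slowest chain before the cross-chain step can run."""
    slowest = max(chain_times)
    return sum(slowest - t for t in chain_times)

# Per-chain timings from the table above.
windows = {100: [2.4, 0.0, 0.9, 5.8],
           200: [0.0, 5.3, 4.2, 5.0],
           300: [0.0, 0.2, 0.2, 0.2]}
for it, times in windows.items():
    print(it, round(window_idle_time(times), 1))
```

Most of the blocking cost is concentrated in the first two windows, where the spread across chains is largest.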

Which parameter is it that has the lowest bulk_ess? If it isn’t lp__, is lp__ much larger?

I didn’t check, but my guess is that most likely lp__ got picked.

The reason I was asking is that I’m curious whether this is why a low adaptation N_eff target didn’t work.

If lp__ had a higher N_eff than something else, then we’d expect to have to set the target N_eff higher to get to the point where things worked.

Getting back to more testing of the algorithm on Torsten. One can get the branch here:

git clone --recursive --branch cross_chain_warmup https://github.com/metrumresearchgroup/cmdstan.git

Below is the performance summary of a simple PK model using cross-chain warmup, with target ESS=400. For this model, current default target ESS=200 isn’t always sufficient.

With cross-chain warmup on top of Torsten’s parallel functions, I’m able to do two-level parallelism: chains communicating during warmup, and within-chain parallel solutions. Here I’m showing the chemical reactions model’s performance (all runs with 4 chains), solved by

  • regular Stan run (4 independent chains),
  • 4-core cross-chain run (each chain solved by 1 core),
  • 8-core cross-chain run (each chain solved by 2 cores),
  • 16-core cross-chain run (each chain solved by 4 cores), and
  • 32-core cross-chain run (each chain solved by 8 cores).

Since the model involves a population of size 8, the within-chain parallelization evenly distributes the 8 subjects to 1, 2, 4, or 8 cores. This setup improves speed at two levels:

  • cross-chain warmup automatically terminates at num_warmup=350. Below is the ESS performance summary.
                          MPI nproc=4     regular
warmup.leapfrogs          1.222100e+04    2.959900e+04
leapfrogs                 1.362400e+04    1.407600e+04
mean.warmup.leapfrogs     3.491714e+01    2.959900e+01
mean.leapfrogs            2.724800e+01    2.815200e+01
min(bulk_ess/iter)        1.708000e+00    1.452000e+00
min(tail_ess/iter)        2.184000e+00    2.276000e+00
min(bulk_ess/leapfrog)    6.268350e-02    5.157715e-02
min(tail_ess/leapfrog)    8.015267e-02    8.084683e-02
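
The core layouts in the run list above follow simple arithmetic: 4 chains × k cores per chain, with the 8 subjects split across each chain's cores. A quick sketch:

```python
def layout(cores_per_chain, num_chains=4, num_subjects=8):
    """Total MPI cores and subjects handled per core within one chain."""
    total_cores = num_chains * cores_per_chain
    subjects_per_core = num_subjects // cores_per_chain
    return total_cores, subjects_per_core

for k in (1, 2, 4, 8):
    print(layout(k))   # (4, 8), (8, 4), (16, 2), (32, 1)
```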

@avehtari @Bob_Carpenter @billg @bbbales2

This is all very nice. Fewer leapfrogs for warmup, sampling just as efficient, and scaling across computers.

For the bulk_ess/iter numbers, how are they calculated? 1.7 effective samples per MCMC draw seems too high. Does that need to be divided by the number of chains?

Yes.
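
Concretely, the tabulated min(bulk_ess/iter) pools ESS over all 4 chains, so dividing by the chain count gives the per-chain efficiency:

```python
# min(bulk_ess/iter) values from the table above, pooled over 4 chains.
num_chains = 4
for label, pooled in [("MPI nproc=4", 1.708), ("regular", 1.452)]:
    print(label, pooled / num_chains)   # per-chain effective draws per iteration
```

That puts both runs well under 1 effective draw per iteration per chain, which is the expected range.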

Indeed! Cutting the number of leapfrog steps in half during warmup essentially doubles its speed (or at least cuts its resource usage in half). But what’s more amazing is that we seem to be getting better adaptation, because the speedup is more than you’d get from just doubling warmup speed, right? Or are these models heavily dominated by warmup?

Not really. For this model, post-warmup sampling takes approximately the same amount of time for regular and cross-chain runs (regular vs. nproc=4 in the plot above). With additional cores, the benefit of within-chain parallelization kicks in and run time is further reduced for both warmup and sampling (nproc=8, 16, and 32 in the plot above).