Multi-chain vs single-chain

I am having some trouble understanding the point of having multiple chains in Stan. From the post "Multithreading comes to Stan", which I found via the thread "Hardware Advice", I got the idea that multi-chain sampling is used for faster convergence and that using more than 4 chains does not yield any significant performance gain.

However, the post "Multicore Speedups are different between models - #8 by betanalpha" states that multiple chains are compared using R-hat to see if there are divergences among chains, in order to judge the validity of the chains and decide whether the sampler can handle a given model.

So I want to know why Stan uses 4 chains by default. Is there a benefit to using multiple chains over a single chain?

Thanks

The first link you provide is mostly about multithreading within chains, i.e. using several threads to evaluate a single chain. This speeds up individual chains, but it is separate from running multiple chains. In particular, the graph in that post is about the number of threads per chain; it shows a plateau when the author runs 4 threads per chain for 4 chains, which he rightly attributes to his computer having 16 cores.
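
The arithmetic behind that plateau can be sketched in a few lines. This is only an illustration: the function names are hypothetical, and it assumes each thread keeps one core busy, which real workloads only approximate.

```python
def total_threads(chains: int, threads_per_chain: int) -> int:
    """Threads running at once if every chain runs in parallel."""
    return chains * threads_per_chain


def is_oversubscribed(chains: int, threads_per_chain: int, cores: int) -> bool:
    """True when more threads are requested than there are cores (assumed 1 thread per core)."""
    return total_threads(chains, threads_per_chain) > cores


if __name__ == "__main__":
    cores = 16  # the machine described in the linked post
    for tpc in (1, 2, 4, 8):
        used = total_threads(4, tpc)
        print(f"4 chains x {tpc} threads = {used} threads; "
              f"oversubscribed: {is_oversubscribed(4, tpc, cores)}")
```

With 4 chains, 4 threads per chain already occupies all 16 cores, so adding more threads per chain has nowhere left to run.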

There are a number of benefits to running multiple chains, but I’ll just share a primary one:

Multiple chains enable diagnostics such as R-hat to allow us to probe the validity of the sampler. The goal of any MCMC method is to generate samples from the target distribution \pi(\theta); however, we can’t ever really know in general whether our Markov chain has reached stationarity, i.e. the samples are from \pi(\theta). We have some theorems in our back pocket that tell us the samples will be asymptotically valid, but we are never in that asymptotic regime so we have to use heuristics and diagnostics to justify whether our samples are from \pi(\theta).

One method to do this is to run multiple chains and observe whether or not they have mixed. We initialize multiple chains from different points and if after some time it looks like all of the chains are generating samples from the same distribution, we can take this as a signal that all the chains have reached the same stationary distribution. R-hat measures this mixing behavior, where \hat{R} \approx 1 provides evidence in favor of mixing. Granted, in the presence of multimodal distributions even this is not necessarily a guarantee, but it’s a good start.
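
To make the idea concrete, here is a minimal sketch of the classic split-R-hat calculation in plain Python. It is not Stan's implementation (which may differ, for example by using rank normalization), and the "chains" are independent normal draws generated by a hypothetical `simulate_chains` helper rather than real Markov chains.

```python
import random
import statistics


def split_rhat(chains):
    """Classic split-R-hat for equal-length chains of one scalar parameter."""
    half = len(chains[0]) // 2
    # Split every chain in two so that drift within a chain also shows up.
    halves = []
    for chain in chains:
        halves.append(chain[:half])
        halves.append(chain[half:2 * half])
    n = half
    means = [statistics.fmean(h) for h in halves]
    within = statistics.fmean([statistics.variance(h) for h in halves])  # W
    between = n * statistics.variance(means)                             # B
    var_plus = (n - 1) / n * within + between / n
    return (var_plus / within) ** 0.5


def simulate_chains(offsets, n_draws=1000, seed=1):
    """Hypothetical stand-in for sampler output: one N(offset, 1) series per chain."""
    rng = random.Random(seed)
    return [[rng.gauss(off, 1.0) for _ in range(n_draws)] for off in offsets]


if __name__ == "__main__":
    agree = simulate_chains([0.0, 0.0, 0.0, 0.0])
    stuck = simulate_chains([0.0, 0.0, 0.0, 3.0])  # one chain sits elsewhere
    print("all chains agree:   ", round(split_rhat(agree), 3))
    print("one chain elsewhere:", round(split_rhat(stuck), 3))
```

When all chains explore the same distribution, the between-chain variance is small relative to the within-chain variance and the ratio sits near 1; a chain stuck somewhere else inflates it. With a single chain, only the two halves of that chain can be compared.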

So by using multiple chains, Stan gives us additional information that lets us make a better decision about whether to trust the samples.

As to why the default is 4 chains, I imagine there’s some history there I’m unaware of. At the very least, 4-core CPUs are fairly standard in most machines, so it seems like a reasonable baseline.

As the number of chains is available as an argument in cmdstan, the user can specify how many chains are desirable. How do you know what the optimal number of chains is?

Is it also fair to say that the more chains you run, the better your confidence that the sampler can handle the model?

Absent any resource constraints, it is certainly the case that having more chains makes for more accurate/informative diagnostics, but it’s a non-linear function with diminishing returns.

There really isn’t a single answer to this, as what one “should” do really varies by contextual factors such as time/cpu constraints, real-world consequences of inferential errors, model complexity, etc.

When models warm up and sample quickly, there’s little cost to running more chains; I have a system with 8 physical cores, so I usually run 8 chains, and if warmup on each has achieved similar adaptation and the sample diagnostics all pass, I’m usually more than satisfied.

But occasionally I have more beastly models that take a considerable amount of time to compute per iteration, and I’ve been burned by trying to sample in parallel chains only to discover days later that the diagnostics fail, usually because there’s something subtle in my model specification that is off and needs re-thinking. For these scenarios, I’ve started doing a one-chain-at-a-time approach where I do as much within-chain parallelization as possible and make sure that a single chain can at least yield passing within-chain diagnostics before running a second chain, etc. I’m additionally working on live diagnostics during warmup & sampling (the latter being more straightforward) to permit even earlier detection of poor performance indicative of model misspecification. But given a series of such serial chains continuing to pass all diagnostics, I’d usually feel comfortable stopping after 8 or so chains, but that’s not based on any particularly rigorous determinations of time-accuracy trade-offs.

Also, on whether more chains give better confidence that the sampler can handle the model:

When you use real data, the diagnostics can fail either because the sampler can’t handle the model (e.g. multi-modal posteriors in the case of HMC), or because the model as specified doesn’t accurately reflect the essentials of the true process by which the data came to be, or both. To narrow the focus to whether the sampler can handle the model, you can use the model to repeatedly generate fake data and sample as if it were real data, a procedure known as Simulation-Based Calibration (SBC). If SBC passes yet your diagnostics on real data don’t, that’s a strong signal that your model is misspecified.
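
The SBC loop can be sketched on a toy model where the posterior is known exactly. Everything here is an assumption made for illustration: the model (mu ~ Normal(0, 1), y ~ Normal(mu, 1)), the function names, and the use of exact conjugate posterior draws in place of an actual Stan fit. The `sd_scale` parameter mimics a faulty sampler whose draws are too narrow.

```python
import random


def sbc_ranks(n_sims=1000, n_obs=10, n_draws=20, sd_scale=1.0, seed=0):
    """Rank of the true parameter among posterior draws, once per simulated dataset."""
    rng = random.Random(seed)
    ranks = []
    for _ in range(n_sims):
        mu_true = rng.gauss(0.0, 1.0)                        # draw from the prior
        y = [rng.gauss(mu_true, 1.0) for _ in range(n_obs)]  # simulate fake data
        # Conjugate posterior for this toy model: Normal(sum(y)/(n+1), 1/(n+1)).
        post_mean = sum(y) / (n_obs + 1)
        post_sd = (1.0 / (n_obs + 1)) ** 0.5
        # Stand-in for the sampler; sd_scale != 1 makes it deliberately wrong.
        draws = [rng.gauss(post_mean, sd_scale * post_sd) for _ in range(n_draws)]
        ranks.append(sum(d < mu_true for d in draws))
    return ranks


def extreme_fraction(ranks, n_draws=20):
    """Share of ranks at either end (0 or n_draws) of the possible range."""
    return sum(r in (0, n_draws) for r in ranks) / len(ranks)


if __name__ == "__main__":
    print("correct sampler:", extreme_fraction(sbc_ranks(sd_scale=1.0)))
    print("too-narrow sampler:", extreme_fraction(sbc_ranks(sd_scale=0.5)))
```

If the sampler is correct, the ranks are uniform over 0..n_draws; a sampler that understates posterior spread piles ranks up at the two ends. In practice the exact draws would be replaced by a fit of the actual model to each simulated dataset.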

Thanks for the info, @mike-lawrence! I just have another question:

“but it’s a non-linear function with diminishing returns.”

What do you mean by this?

It means that the more chains you use, the smaller the incremental advantage each additional chain brings in evaluating mixing.
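
One toy way to see this, which is not from the thread and rests on a strong simplifying assumption: suppose each chain independently has probability `p_hit` of landing somewhere that exposes a problem (for example, a second mode). The chance that at least one of the chains exposes it is then 1 - (1 - p_hit)^n, and the gain from each extra chain shrinks geometrically. The function names and the value 0.2 are purely illustrative.

```python
def detection_probability(n_chains: int, p_hit: float) -> float:
    """P(at least one of n independent chains exposes the problem)."""
    return 1.0 - (1.0 - p_hit) ** n_chains


def marginal_gain(n_chains: int, p_hit: float) -> float:
    """Extra detection probability contributed by the n-th chain."""
    return detection_probability(n_chains, p_hit) - detection_probability(n_chains - 1, p_hit)


if __name__ == "__main__":
    p_hit = 0.2  # illustrative value, not an estimate for any real model
    for n in (1, 2, 4, 8, 16):
        print(f"{n:2d} chains: detect {detection_probability(n, p_hit):.3f}, "
              f"gain from last chain {marginal_gain(n, p_hit):.3f}")
```

Under this assumption the first few chains buy most of the benefit, which is one way to read "diminishing returns".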