Multi-chain vs single-chain

I am having some trouble understanding the point of having multiple chains in Stan. From this link: Multithreading comes to Stan, which I found in this thread: Hardware Advice, I got the idea that multi-chain sampling is used for faster convergence and that using more than 4 chains does not yield any significant performance gain.

However, from this post: Multicore Speedups are different between models - #8 by betanalpha, it is stated that multiple chains are used for comparison with R-hat, to see if there are discrepancies among chains, in order to assess the validity of the chains and decide whether the sampler can handle a given model.

So, I want to know why Stan uses 4 chains by default. Is there a benefit to using multiple chains over a single chain?

Thanks

The first link you provide is mostly about multithreading within chains, i.e. using multithreaded processing to speed up the evaluation of a single chain. That is separate from the process of running multiple chains. In particular, the graph in the post is about the number of threads per chain, showing a plateau when the author runs 4 threads per chain for 4 chains; he rightly explains this plateau by the fact that his computer has 16 cores.

There are a number of benefits to running multiple chains, but I’ll just share a primary one:

Multiple chains enable diagnostics such as R-hat, which allow us to probe the validity of the sampler. The goal of any MCMC method is to generate samples from the target distribution \pi(\theta); however, we can’t ever really know in general whether our Markov chain has reached stationarity, i.e. whether the samples are from \pi(\theta). We have some theorems in our back pocket that tell us the samples will be asymptotically valid, but we are never in that asymptotic regime, so we have to use heuristics and diagnostics to assess whether our samples are from \pi(\theta).

One method to do this is to run multiple chains and observe whether or not they have mixed. We initialize multiple chains from different points and if after some time it looks like all of the chains are generating samples from the same distribution, we can take this as a signal that all the chains have reached the same stationary distribution. R-hat measures this mixing behavior, where \hat{R} \approx 1 provides evidence in favor of mixing. Granted, in the presence of multimodal distributions even this is not necessarily a guarantee, but it’s a good start.
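
To make that concrete, here is a rough numpy sketch of the basic split R-hat computation. This is only illustrative; Stan's actual diagnostic adds rank normalization and folding on top of this, so treat it as a toy.

```python
import numpy as np

def split_rhat(draws):
    """Basic split R-hat. draws has shape (n_chains, n_iterations)."""
    n_chains, n_iter = draws.shape
    # split each chain in half, doubling the number of (half-)chains
    halves = draws.reshape(n_chains * 2, n_iter // 2)
    m, n = halves.shape
    chain_means = halves.mean(axis=1)
    chain_vars = halves.var(axis=1, ddof=1)
    between = n * chain_means.var(ddof=1)        # B: between-chain variance
    within = chain_vars.mean()                   # W: within-chain variance
    var_hat = (n - 1) / n * within + between / n
    return np.sqrt(var_hat / within)

rng = np.random.default_rng(0)
good = rng.normal(size=(4, 1000))                # four well-mixed chains
print(split_rhat(good))                          # close to 1.00
bad = good + np.arange(4)[:, None]               # chains stuck at different means
print(split_rhat(bad))                           # much larger than 1
```

The point is simply that the between-chain variance term only becomes available once you run more than one chain.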

So by using multiple chains, Stan gives us additional information that lets us make a better-informed decision about whether to trust the samples.

As to why the default is 4 chains, I imagine there’s some history there I’m unaware of. At the very least, 4-core CPUs are fairly standard in most machines, so it seems like a reasonable baseline.

6 Likes

As the number of chains is available as an argument in cmdstan, the user can specify how many chains to run. How do you know what the optimal number of chains to use is?

Is it also fair to say that the more chains you run, the better your confidence that the sampler can handle the model?

Absent any resource constraints, it is certainly the case that having more chains makes for more accurate/informative diagnostics, but it’s a non-linear function with diminishing returns.

There really isn’t a single answer to this, as what one “should” do really varies by contextual factors such as time/cpu constraints, real-world consequences of inferential errors, model complexity, etc.

When models warm up & sample quickly, there’s little cost to doing more chains; I have a system with 8 physical cores, so I usually do 8 chains, and if warmup on each has achieved similar adaptation and the sample diagnostics all pass, I’m usually more than satisfied.

But occasionally I have more beastly models that take a considerable amount of time to compute per iteration, and I’ve been burned by trying to sample in parallel chains only to discover days later that the diagnostics fail, usually because there’s something subtle in my model specification that is off and needs re-thinking. For these scenarios, I’ve started doing a one-chain-at-a-time approach where I do as much within-chain parallelization as possible and make sure that a single chain can at least yield passing within-chain diagnostics before running a second chain, and so on. I’m additionally working on live diagnostics during warmup & sampling (the latter being more straightforward) to permit even earlier detection of poor performance indicative of model misspecification. Given a series of such serial chains that continue to pass all diagnostics, I’d usually feel comfortable stopping after 8 or so, though that’s not based on any particularly rigorous determination of the time-accuracy trade-off.
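
For what it's worth, here's a hypothetical sketch of that one-chain-at-a-time loop using cmdstanpy; the model file, data file, and the cap of 8 chains are placeholders, and in practice I'd inspect the adaptation and diagnostics by hand before launching the next chain.

```python
from cmdstanpy import CmdStanModel

# 'model.stan' and 'data.json' are placeholders for your own files.
model = CmdStanModel(stan_file='model.stan')

fits = []
for chain_id in range(1, 9):              # up to 8 serial chains
    fit = model.sample(
        data='data.json',
        chains=1,                         # one chain at a time
        chain_ids=[chain_id],
        iter_warmup=1000,
        iter_sampling=1000,
        seed=1234,
        show_progress=False,
    )
    print(fit.diagnose())                 # divergences, treedepth, E-BFMI, ...
    fits.append(fit)
    # Stop here (or revise the model) if the diagnostics look bad;
    # otherwise launch the next chain.
```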

1 Like

Also, on:

Is it also fair to say that the more chains you run, the better your confidence that the sampler can handle the model?

When you use real data, the diagnostics can fail either because the sampler can’t handle the model (ex. multi-modal posteriors in the case of HMC), or because the model as specified doesn’t accurately reflect the essentials of the true process by which the data came to be, or both. To narrow focus to whether the sampler can handle the model, you can use the model to repeatedly generate fake data and sample as if it were real data, a procedure known as Simulation Based Calibration. If SBC passes yet your diagnostics on real data don’t, that’s a strong signal that your model is misspecified.
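
Just to illustrate the rank-statistic logic of SBC, here's a toy in plain numpy (not Stan code): a conjugate normal model where the posterior is exact, so the ranks should come out approximately uniform.

```python
import numpy as np

rng = np.random.default_rng(1)
n_sims, n_draws, n_obs = 1000, 100, 20
sigma, prior_mu, prior_sd = 1.0, 0.0, 1.0

ranks = np.empty(n_sims, dtype=int)
for s in range(n_sims):
    theta_true = rng.normal(prior_mu, prior_sd)          # draw from the prior
    y = rng.normal(theta_true, sigma, size=n_obs)        # simulate fake data
    # exact conjugate posterior for the mean (stand-in for fitting with Stan)
    post_var = 1.0 / (1.0 / prior_sd**2 + n_obs / sigma**2)
    post_mean = post_var * (prior_mu / prior_sd**2 + y.sum() / sigma**2)
    draws = rng.normal(post_mean, np.sqrt(post_var), size=n_draws)
    ranks[s] = np.sum(draws < theta_true)                # SBC rank statistic

# If the sampler/model pair is correct, the ranks are uniform on 0..n_draws,
# so this histogram should be roughly flat.
hist, _ = np.histogram(ranks, bins=10, range=(0, n_draws))
print(hist)
```

In real SBC the exact conjugate posterior above is replaced by fitting the model in Stan to each simulated dataset, and systematic deviations from uniformity point to problems.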

Thanks for the info @mike-lawrence! I just have another question:

but it’s a non-linear function with diminishing returns.

What do you mean by this?

It means that the more chains you use, the less the incremental advantage in evaluating mixing.

1 Like

In order to understand an (approximately) optimal configuration for Markov chain Monte Carlo we really have to understand what Markov chain Monte Carlo is actually doing. Markov chain Monte Carlo uses the sequential states

\{\theta_{0}, \theta_{1}, \ldots, \theta_{n}, \ldots, \theta_{N} \}

within a Markov chain to construct Markov chain Monte Carlo estimators of target expectation values,

E_{\pi}[f] = \int \mathrm{d} \theta \, \pi(\theta) \, f(\theta) \approx \frac{1}{N + 1} \sum_{n = 0}^{N} f(\theta_n) = \hat{f}.

In order to quantify the practical utility of a Markov chain Monte Carlo estimator we need to quantify how well \hat{f} approximates E_{\pi}[f], although this is much easier said than done. Without additional assumptions there’s not much we can do to quantify this error, but under nice conditions where a central limit theorem holds we can quantify the error using the empirical autocorrelations between the evaluations f(\theta_n) and the total length of the Markov chain. In particular, in these specific circumstances, the longer the Markov chain the smaller the estimator error will be – doubling the length of a Markov chain will decrease the error by a factor of \sqrt{2}.

At the same time the error quantification allowed by a central limit theorem allows us to combine Markov chain Monte Carlo estimators derived from separate Markov chains. Provided that each Markov chain is sufficiently well-behaved then combining the estimators from C chains will decrease the estimator error by a factor of \sqrt{C}.
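
A quick idealized illustration of those two scalings, using independent draws in place of Markov chain states so that the central limit theorem behavior is transparent:

```python
import numpy as np

rng = np.random.default_rng(0)
n_reps = 2000                              # replications used to measure estimator error

def rmse(estimates):
    return np.sqrt(np.mean(np.asarray(estimates) ** 2))   # true expectation is 0

one_long = [rng.normal(size=4000).mean() for _ in range(n_reps)]
four_pooled = [np.mean([rng.normal(size=1000).mean() for _ in range(4)])
               for _ in range(n_reps)]
one_doubled = [rng.normal(size=8000).mean() for _ in range(n_reps)]

print(rmse(one_long))      # about 1/sqrt(4000)
print(rmse(four_pooled))   # same total number of draws, essentially the same error
print(rmse(one_doubled))   # smaller by roughly a factor of sqrt(2)
```

With real Markov chains the draws are autocorrelated, so it is the effective sample size rather than the raw number of states that governs the error, but the square-root scaling is the same.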

To review – under ideal conditions we have two ways to decrease the Markov chain Monte Carlo estimator error. We can run one long Markov chain or multiple short Markov chains. Which one is better? There is a long history of people arguing both sides; unfortunately those arguments are largely tainted by people not being clear about their assumptions and hence talking past each other.

Under ideal conditions and a fixed computational resource there shouldn’t be any strong difference between the estimator error from one long Markov chain of length N or C shorter Markov chains of length N / C. In practice, however, there are some complications.

For example if the Markov chains are too short then we also have to account for an initialization bias. This can be avoided by discarding the initial states from each Markov chain at the cost of some computational overhead. The resulting overhead, however, will be much smaller in the one long Markov chain scenario than in the many smaller Markov chains scenario. For example if we have to discard the first W states then the overhead for the long chain will be W / N while the overhead for the ensemble of shorter chains will be C \cdot W / N. In other words the one long Markov chain will be more performant. The overhead becomes even more considerable when we have to take into account adaptation of the Markov transition.
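
To put illustrative numbers on that overhead: if W = 1000 states have to be discarded and the total budget is N = 10000 states, then the one long Markov chain loses W / N = 10% of its states, while an ensemble of C = 4 shorter Markov chains of length N / C = 2500 loses C \cdot W / N = 40%.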

On the other hand, because the shorter Markov chains are independent they can be run in parallel, and hence take advantage of parallel computing resources that the one long Markov chain might not be able to. Even if the overhead is larger, the speedup from parallelization can easily make the ensemble approach more performant. That said, we can take parallelization only so far – each of the individual Markov chains still has to be long enough to avoid an initialization bias, which limits just how much we can distribute the computation.

Of course all of these considerations hold only under ideal circumstances, which can’t be taken for granted in practice. One huge benefit of running multiple Markov chains is that the ensemble is sensitive to violations of those ideal circumstances. If everything is nice enough then the terminal states (i.e. everything but the initial states that we remove per the discussion above) of all of the Markov chains should behave the same, and any inconsistencies identify failures of those ideal circumstances and cast doubt on any empirical error quantification.

So when the ideal circumstances can’t be taken for granted and we have access to parallelization resources then running at least a few Markov chains will tend to lead to more robust if not more performant Markov chain Monte Carlo estimation. With Stan’s default Markov chain configuration running more than ten Markov chains is unlikely to offer much benefit unless the target distribution is particularly nasty. For example if there are more than ten modes then more than ten Markov chains with sufficiently diffuse initializations may be needed to even find all of the modes.

At the beginning of Stan’s development multiple cores were becoming common in commodity hardware, and most newer computers were shipping with at least four physical cores. This is what motivated the four Markov chain default that persists to this day.

3 Likes