New preprint: Investigating the efficiency of marginalising over discrete parameters in Bayesian computations

Dear Stan community,

We just released a new preprint: Investigating the efficiency of marginalising over discrete parameters in Bayesian computations.

Folk wisdom in the Stan community suggests that fitting models with marginalisation can improve computational efficiency. In this work we looked for empirical evidence for this idea. Specifically, we:

  • Explored the impact of marginalisation on computational efficiency for Gaussian mixture models and the Dawid–Skene model for categorical ratings
  • Implemented various marginalised and non-marginalised versions of the models in JAGS, and marginalised versions of the models in Stan
  • Compared the computational efficiency of the different inference approaches across several simulation scenarios
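For readers less familiar with the idea, here is a minimal Python sketch (an illustration, not code from the paper) of what marginalising means for a two-component Gaussian mixture. Each observation's discrete indicator z_n is summed out, giving log p(y_n) = log Σ_k λ_k N(y_n | μ_k, σ_k), computed with a log-sum-exp for numerical stability. Stan's `log_sum_exp` and `log_mix` functions serve the same purpose; the function names and parameter values below are hypothetical.

```python
import math

def log_normal_pdf(y, mu, sigma):
    """Log density of Normal(mu, sigma) evaluated at y."""
    return -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * ((y - mu) / sigma) ** 2

def log_sum_exp(xs):
    """Numerically stable log(sum(exp(xs)))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def marginal_log_lik(ys, lam, mu, sigma):
    """Mixture log-likelihood with each discrete indicator z_n summed out:
    sum_n log sum_k lam_k * Normal(y_n | mu_k, sigma_k)."""
    return sum(
        log_sum_exp([math.log(l) + log_normal_pdf(y, m, s)
                     for l, m, s in zip(lam, mu, sigma)])
        for y in ys
    )

# Hypothetical data and parameter values.
ys = [-1.2, 0.3, 2.9, 3.4]
print(marginal_log_lik(ys, lam=[0.7, 0.3], mu=[0.0, 3.0], sigma=[1.0, 1.0]))
```

The same pattern extends to the Dawid–Skene model, where the sum runs over the possible true categories of each item.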

Surprisingly, our results suggest that marginalisation on its own does not necessarily boost computational efficiency. We concluded that there is no simple answer to whether or not marginalisation is helpful.

We hope you enjoy reading this and we are happy to take any comments.

Wen Zhang (together with Jeffrey Pullin, Lyle Gurrin and Damjan Vukcevic)

Thanks for posting, this all looks really interesting!

I think including Stan in the comparisons here might not be entirely helpful for the conclusions of the paper, because too many factors beyond marginalisation differ between the models. While there might be a performance/efficiency difference between the JAGS and Stan marginalised models, it isn’t possible to determine whether this was due to:

  • Parameterisation - the JAGS Dawid–Skene models used a vector of zeros with a Poisson likelihood to enforce a lower bound, whereas Stan implements this natively
  • Sampler - NUTS may just be more efficient with the same continuous likelihood than Gibbs/MH
  • Implementation - The Stan models used vectorised operations and the JAGS models used loops

Without more stringent controls on the individual differences between the models, it’s a bit difficult to draw conclusions around whether the benefits of marginalisation in Stan differ from JAGS.
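The "vector of zeros with a Poisson likelihood" in the first bullet is the standard BUGS/JAGS "zeros trick". JAGS has no direct equivalent of Stan's `target +=`, so an arbitrary log-density term is supplied by observing a 0 under a Poisson(phi) likelihood with phi = -loglik + C. Since log Poisson(0 | phi) = -phi, this contributes loglik - C, i.e. the desired term up to a constant. The Python sketch below only checks that arithmetic; it is not the JAGS code from the paper, and the constant C and example values are arbitrary assumptions.

```python
import math

def poisson_logpmf(k, rate):
    """log Poisson(k | rate) = k*log(rate) - rate - log(k!)."""
    return k * math.log(rate) - rate - math.lgamma(k + 1)

def zeros_trick_term(loglik, C=1e4):
    """Log-density contribution of one observed zero with rate phi = -loglik + C.
    Because log Poisson(0 | phi) = -phi, this equals loglik - C."""
    phi = -loglik + C  # C must be large enough to keep phi > 0
    return poisson_logpmf(0, phi)

# Two hypothetical log-likelihood values at two parameter settings.
a, b = -12.5, -15.0
print(zeros_trick_term(a) - zeros_trick_term(b))  # same difference as a - b
```

Because only differences in log density matter to the sampler, the constant C drops out.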

Also, it looks like the Stan models are computing the log-likelihood twice (once in the model block and again in the generated quantities for post-processing) when JAGS only has to do it once. This might make comparisons in terms of walltime (i.e., using system.time()) a bit biased, since the timing for the Stan models includes more than just the sampling.

Nice preprint!

A pedantic point: the R-hat computed by coda::gelman.diag is not the same as the one computed by rstan::monitor (which, in sufficiently recent versions of rstan, is rstan:::Rhat), so strictly speaking these aren’t comparable (though I’m not sure it makes a huge difference here).
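To illustrate how the definitions can differ, the Python sketch below contrasts a classic Gelman–Rubin potential scale reduction factor with a rank-normalised, split, bulk-and-folded R-hat in the style of posterior::rhat. It is not either package's code: coda::gelman.diag also applies a degrees-of-freedom correction, omitted here. In this hypothetical example the chains agree in location but one has a larger spread, which the classic statistic misses and the folded version picks up.

```python
import numpy as np
from statistics import NormalDist

def classic_rhat(chains):
    """Gelman-Rubin style PSRF on an (m chains, n draws) array, no splitting."""
    m, n = chains.shape
    W = chains.var(axis=1, ddof=1).mean()    # mean within-chain variance
    B = n * chains.mean(axis=1).var(ddof=1)  # between-chain variance
    var_plus = (n - 1) / n * W + B / n
    return float(np.sqrt(var_plus / W))

def split_chains(chains):
    """Split each chain in half (dropping the middle draw if n is odd)."""
    half = chains.shape[1] // 2
    return np.vstack([chains[:, :half], chains[:, -half:]])

def rank_normalize(x):
    """Replace draws by normal scores of their pooled ranks (ties ignored)."""
    ranks = np.argsort(np.argsort(x, axis=None)).reshape(x.shape) + 1
    s = x.size
    return np.vectorize(NormalDist().inv_cdf)((ranks - 3 / 8) / (s + 1 / 4))

def rank_normalized_split_rhat(chains):
    """Max of bulk R-hat and folded (tail) R-hat on rank-normalised split chains."""
    sc = split_chains(chains)
    bulk = classic_rhat(rank_normalize(sc))
    tail = classic_rhat(rank_normalize(np.abs(sc - np.median(sc))))
    return max(bulk, tail)

rng = np.random.default_rng(1)
# Four chains with equal means, but one chain three times as spread out.
chains = np.vstack([rng.normal(0, 1, (3, 1000)), rng.normal(0, 3, (1, 1000))])
print(classic_rhat(chains), rank_normalized_split_rhat(chains))
```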

For consistency of methods and diagnostics between the Stan and JAGS models, the posterior R package (stan-dev/posterior on GitHub) could be useful.

I agree with the others; nice paper!

I’m no expert, but I think it might be useful to take a look at the tail ESS estimators from posterior as well. I’ve heard from @Bob_Carpenter and others that the marginalized parameterization might be particularly advantageous for exploring the tails of the continuous parameters.

Also FWIW, I think it would be useful to confirm explicitly that the marginalized and non-marginalized models are sampling from the same posterior. For example, you could combine chains from both models and confirm that R-hat doesn’t look bad.
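A minimal sketch of that pooling check, assuming draws of the same scalar parameter from each model version are stored as (chains × draws) arrays. It uses a plain split R-hat for brevity (posterior::rhat adds rank normalisation and folding), and the arrays are simulated stand-ins: the "disagreeing" set is shifted by half a posterior standard deviation purely for illustration.

```python
import numpy as np

def split_rhat(chains):
    """Split R-hat (no rank normalisation) for an (m chains, n draws) array."""
    half = chains.shape[1] // 2
    sc = np.vstack([chains[:, :half], chains[:, -half:]])
    n = sc.shape[1]
    W = sc.var(axis=1, ddof=1).mean()
    B = n * sc.mean(axis=1).var(ddof=1)
    return float(np.sqrt(((n - 1) / n * W + B / n) / W))

rng = np.random.default_rng(2)
# Hypothetical draws of one parameter from two model versions, 4 chains each.
marginalised = rng.normal(0.0, 1.0, (4, 1000))
full_agrees = rng.normal(0.0, 1.0, (4, 1000))
full_disagrees = rng.normal(0.5, 1.0, (4, 1000))

print(split_rhat(np.vstack([marginalised, full_agrees])))     # close to 1
print(split_rhat(np.vstack([marginalised, full_disagrees])))  # noticeably above 1
```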

Thanks for sharing, @wenkatezhang. I’m especially happy you made the code repo open access.

I have some alternative interpretations and suggestions for more precise evals.

It’s not just Stan, nor is it just folk wisdom in the larger MCMC community—there’s the Rao-Blackwell theorem.

We know that discrete sampling of small probabilities requires very many draws if you want any kind of reasonable relative error in estimates. For example, try to estimate theta = 0.01 by averaging i.i.d. draws y ~ bernoulli(theta): you’ll need thousands of draws. Absolute error is fine (the answer is near zero), but relative error is going to take a lot of discrete draws. You can expect around 10% relative error with 10K iterations. And that’s only 0.01. Table 2 of the paper shows that the smallest discrete probability being tested is 0.20. This is the easiest possible case for sampling; making those mixture distributions more skewed would make them much harder to fit with discrete sampling.
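The arithmetic behind the 10% figure: the mean of N i.i.d. Bernoulli(θ) draws has standard error sqrt(θ(1-θ)/N), so its relative error is sqrt((1-θ)/(θN)), and reaching a target relative error r needs about (1-θ)/(θr²) draws. The Python sketch below evaluates this for a few θ values (0.2 being the smallest in the paper's Table 2) and checks θ = 0.01 by simulation. It assumes independent draws, which is optimistic for MCMC.

```python
import math
import random
import statistics

def relative_se(theta, n_draws):
    """Relative standard error of the mean of n i.i.d. Bernoulli(theta) draws."""
    return math.sqrt(theta * (1 - theta) / n_draws) / theta

def draws_for_relative_error(theta, target):
    """Approximate number of i.i.d. draws needed for a given relative standard error."""
    return (1 - theta) / (theta * target ** 2)

for theta in (0.2, 0.1, 0.01):
    print(theta, relative_se(theta, 10_000), draws_for_relative_error(theta, 0.1))

# Simulation check: 100 independent estimates of theta = 0.01, each from 10,000 draws.
random.seed(0)
estimates = [sum(random.random() < 0.01 for _ in range(10_000)) / 10_000 for _ in range(100)]
empirical_rel_sd = statistics.stdev(estimates) / 0.01
print("empirical relative error at theta = 0.01:", empirical_rel_sd)
```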

You can see a textbook example of that in the section I wrote on changepoint models in the Stan User’s guide, where it allows you to compute tail statistics in expectation (Rao-Blackwell style) that would be intractable via MCMC.
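A minimal Python sketch of the Rao-Blackwell idea for a mixture, assuming a two-component normal mixture with hypothetical, skewed weights and with the continuous parameters held fixed for simplicity. Instead of averaging sampled indicators z, one averages the conditional probability P(z = 1 | y, θ); in a real fit that probability would be recomputed at each posterior draw of θ and then averaged. Here the probability is about 5e-4, so with 1000 draws the average of sampled indicators is often exactly zero, while the conditional probability is available directly.

```python
import math
import random

def membership_prob(y, lam, mu, sigma):
    """P(z = 1 | y, lam, mu, sigma) for a two-component normal mixture."""
    # Unnormalised log posterior of each component; the common -0.5*log(2*pi) cancels.
    lp = [math.log(l) - math.log(s) - 0.5 * ((y - m) / s) ** 2
          for l, m, s in zip(lam, mu, sigma)]
    top = max(lp)
    w = [math.exp(v - top) for v in lp]
    return w[1] / (w[0] + w[1])

lam, mu, sigma = [0.99, 0.01], [0.0, 3.0], [1.0, 1.0]  # hypothetical values
y = 0.5
p = membership_prob(y, lam, mu, sigma)

random.seed(1)
n_draws = 1000
z_draws = [random.random() < p for _ in range(n_draws)]
print("Rao-Blackwellised estimate:", p)
print("Mean of sampled indicators:", sum(z_draws) / n_draws)
```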

Our main computational concern with marginalizing is that the marginal density will be poorly conditioned in a way that varies around the posterior. We’re specifically worried about introducing stiffness in the Hamiltonian dynamics (varying time scales leading to bad conditioning, and thus requiring small step sizes). We haven’t found that in most problems.

The evals that Zhang et al. do in the paper are very disadvantageous for full sampling. In section 2.3.1, the paper indicates the evals turn off all the conjugate samplers and GLM optimizations, which forces JAGS to use slice sampling for each conditional update. This is a huge hit in efficiency and accuracy.

Table 3 in the paper explains why there can be good ESS estimates for JAGS even when Figures 1 and 3 show R-hat indicating non-convergence. The coda-based ESS is very optimistic and assumes convergence (i.e., R-hat = 1). It calculates an ESS estimate per chain and adds them. This systematically overestimates ESS (as measured by calibration tests of square error) when the chains have not converged (i.e., R-hat >> 1).
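To illustrate the overestimation when chains have not mixed, the Python sketch below compares (a) estimating ESS separately for each chain and summing, as described above for coda, with (b) a simplified version of Stan's multi-chain ESS, which measures autocorrelation against a variance estimate that includes the between-chain spread. Both use the same Geyer-style truncation, with no chain splitting or rank normalisation, so this is not the actual coda or posterior code (coda, for instance, estimates the spectral density with an AR fit). The chains are simulated: independent draws, but stuck around two different means.

```python
import numpy as np

def autocov(x):
    """Biased autocovariance of a 1-D chain at lags 0..n-1, computed via FFT."""
    n = len(x)
    xc = x - x.mean()
    f = np.fft.fft(xc, n=2 * n)
    return np.fft.ifft(f * np.conj(f))[:n].real / n

def ess_from_rho(total_draws, rho):
    """ESS = total_draws / tau, truncating the autocorrelation sum with
    Geyer's initial positive, monotone sequence of paired autocorrelations."""
    tau, prev = -1.0, np.inf
    for k in range(0, len(rho) - 1, 2):
        pair = rho[k] + rho[k + 1]
        if pair < 0:
            break
        pair = min(pair, prev)
        tau += 2.0 * pair
        prev = pair
    return total_draws / tau

def summed_per_chain_ess(chains):
    """ESS per chain, then summed; ignores differences between chains."""
    total = 0.0
    for c in chains:
        acov = autocov(c)
        total += ess_from_rho(len(c), acov / acov[0])
    return total

def multichain_ess(chains):
    """Simplified multi-chain ESS using a variance estimate that includes between-chain spread."""
    m, n = chains.shape
    acovs = np.array([autocov(c) for c in chains])
    mean_var = acovs[:, 0].mean() * n / (n - 1)  # W
    var_plus = mean_var * (n - 1) / n + chains.mean(axis=1).var(ddof=1)
    rho = 1.0 - (mean_var - acovs.mean(axis=0)) / var_plus
    return ess_from_rho(m * n, rho)

rng = np.random.default_rng(3)
chains = rng.normal(0.0, 1.0, (4, 1000)) + np.array([[0.0], [0.0], [3.0], [3.0]])
print(summed_per_chain_ess(chains), multichain_ess(chains))
```

Each chain on its own looks like independent draws, so the summed estimate is close to 4,000, while the multi-chain estimate is only a few.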

Stan’s more conservative ESS estimator (e.g., as found in the arviz package in Python or the posterior package in R) discounts ESS for non-convergence. Our estimators are much better calibrated with square error (the distribution of which has the MCMC standard error as its scale), which is what you want. Also, I believe coda wrongly caps ESS at the number of iterations; in principle, ESS can be higher with antithetic sampling like you find with NUTS. We very often get ESS values that are higher than the number of draws (and yes, it’s the right answer as measured by square error).
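On ESS exceeding the number of draws: when a chain has negative lag-1 autocorrelation, the variance of the sample mean is smaller than for independent draws. The Python sketch below uses a stationary AR(1) chain with φ = -0.5 as a simple stand-in for antithetic behaviour (it is not NUTS). Its theoretical ESS is n(1-φ)/(1+φ) = 3n, which the sketch checks empirically from the spread of the mean across replicated chains.

```python
import numpy as np

def simulate_ar1(phi, n, reps, rng):
    """Simulate `reps` stationary AR(1) chains x_t = phi * x_{t-1} + e_t, e_t ~ N(0, 1)."""
    x = np.empty((reps, n))
    x[:, 0] = rng.normal(0.0, 1.0 / np.sqrt(1.0 - phi ** 2), reps)  # start in stationarity
    for t in range(1, n):
        x[:, t] = phi * x[:, t - 1] + rng.normal(0.0, 1.0, reps)
    return x

phi, n, reps = -0.5, 1000, 2000
rng = np.random.default_rng(4)
chains = simulate_ar1(phi, n, reps, rng)

stationary_var = 1.0 / (1.0 - phi ** 2)
var_of_mean = chains.mean(axis=1).var(ddof=1)   # spread of the estimate across replications
empirical_ess = stationary_var / var_of_mean    # would be about n for i.i.d. draws
theoretical_ess = n * (1 - phi) / (1 + phi)     # 3n for phi = -0.5
print(empirical_ess, theoretical_ess)
```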

What I would conclude from Figures 1 and 3 is that JAGS has a hard time with convergence. This is consistent with my informal observations fitting the Dawid-Skene model with BUGS compared to marginalizing and fitting with HMC: fit time went down from 24 hours in BUGS, with poor stability, to a stable 20 minutes in Stan.

The paper says in the conclusion section (3.3),

The stan model (which is marginalised) was more efficient than the JAGS models, highlighting the difference due to software engineering.

At least 10 years ago, Stan was coded much more efficiently than JAGS (Matt and I originally started rewriting JAGS to vectorize before we started working on Stan). But I think the bigger difference is due to algorithm—Stan uses HMC, which scales well in dimension, whereas JAGS uses Gibbs, which does not scale well in dimension. I think one of the confounding factors of this analysis is that JAGS is probably very inefficient at dealing with marginalizations unless they’ve come up with a better way to do it than the 1s and 0s tricks.

I also didn’t understand the point of computing speed independently of ESS. All that really matters is ESS/second after convergence if you want to know how well a system mixes. I think the even more important measure for most applications is time to min ESS = 100, but that’s even harder to measure stably.
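A small sketch of the two summaries mentioned here, assuming ESS grows linearly with post-convergence run time and ignoring warmup (which is part of why time to min ESS = 100 is hard to measure stably). The parameter names and numbers are hypothetical, chosen to show that a fit that runs four times faster can still end up with about the same min ESS per second.

```python
def efficiency_summary(ess_by_param, seconds):
    """Min ESS per second, and projected seconds to reach min ESS = 100,
    assuming ESS grows linearly with run time after convergence."""
    min_ess = min(ess_by_param.values())
    ess_per_sec = min_ess / seconds
    return ess_per_sec, 100.0 / ess_per_sec

# Hypothetical numbers for two fits of the same model (not from the paper).
fit_a = {"mu[1]": 850.0, "mu[2]": 620.0, "sigma": 1400.0}
fit_b = {"mu[1]": 3000.0, "mu[2]": 150.0, "sigma": 2900.0}
print(efficiency_summary(fit_a, seconds=20.0))
print(efficiency_summary(fit_b, seconds=5.0))
```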

The main conclusion section (3.3) ends with this:

the marginalised stan model had computational efficiency similar to jags-full shows that, once again, the software implementation is a more substantial factor in performance; in this case, the benefit was enough to counteract the seeming disadvantage of marginalization.

I don’t think it’s the software per se; it’s the sampler (HMC vs. Gibbs). Also, Stan appears to be about twice as efficient as jags-full according to Figures 1 and 3. Even so, these figures underestimate Stan’s advantage because of the overestimation bias of coda’s ESS estimator.

In conclusion, I’d like to see this study redone with

  • Stan’s ESS estimator for both Gibbs and NUTS
  • more skew in the mixture probabilities
  • larger numbers of parameters
  • let JAGS do whatever it can to speed things up (should massively speed up jags-full, I suspect)

More params because HMC’s advantages grow as the number of parameters grows. Letting JAGS do whatever it can is a more fair comparison of conditional sampling approaches.

I’d also like to see it redone with more efficient Stan code. I created an issue in their repo with efficiency suggestions for the Stan programs.

Thank you all for your helpful comments!

See my responses below (I am an author on the paper):

@andrjohns

We are aware that there are many features beyond marginalization that affect the performance of individual methods. We included Stan in the comparisons because we were interested in whether Stan’s (good) performance was worth the mathematical effort of marginalization.

Ultimately, the paper aims to study both practical questions (should I use Stan? should I marginalize my model in JAGS?) and theoretical ones (does marginalization lead to more performant inference in general? should future probabilistic programming languages automatically marginalize certain models?). We should probably make these different types of questions more explicit in the paper.

Thank you for pointing out the generated quantities block in the Dawid–Skene model; although I wrote it, I had forgotten it existed… We will remove it to make the comparisons fairer.

@hhau

Thanks for making that point! We were aware of the distinction but weren’t sure how to address it. Using the posterior package, as noted by @andrjohns, should allow us to fix it. :)

@jsocolar

Thanks! Both adding in the tail ESS estimators and pooling chains across models are nice ideas we will try.

@Bob_Carpenter

I’m glad you appreciate the open access code :)

There is the Rao-Blackwell theorem but we (a strict subset of the authors of the current paper) are not convinced that the Rao-Blackwell theorem alone provides an iron-clad reason to marginalize, both because of the issues with ill-conditioned posteriors and limitations of the theory of marginalization (see our preprint describing the rater package, Section 6).

I agree, though, that marginalization should be particularly effective at estimating rare discrete events and that we should assess this better in our paper. (As an aside: it is not immediately clear to me what the better marginalized estimator of theta would be in the situation you describe; does one exist?)

The jags-full model in the manuscript is intended to have no restrictions on the samplers or other computational tricks JAGS can use. (I will double check this with the other authors). I think the text in Section 2.3.1 does not make the (intended) lack of restrictions clear.

Thanks for the details of the Coda ESS estimator! As mentioned above we will switch to using the posterior package across models in future versions of the manuscript.

We will tweak the text in section 3.3 to better highlight the impact of algorithm.

Thank you very much for the suggestions for making the code more efficient: we will implement them in future versions.

Thanks for your comments everyone!

Christian and Gareth have a nice paper, “Rao-Blackwellization in the MCMC era” (arXiv:2101.01011), that goes into the subtleties of the benefits of “Rao-Blackwellizing” target distributions, as well as the chaos in the meaning of the term in the first place. Long story short, Rao-Blackwellizing isn’t a universal improvement, but at the same time the exceptions are pretty unrealistic and are often limited to Gibbs samplers (where the conditionals of the marginal model are nastier than the conditionals of the full model). For methods like Hamiltonian Monte Carlo that work with the full model, Rao-Blackwellization almost always improves the geometry of the target distribution, so there’s a multi-fold improvement: marginalized expectands (the actual Rao-Blackwell part), lower-dimensional target parameter spaces, and nicer geometry that allows for faster sampling.

Hi all,

We have now updated the preprint, addressing many of the helpful comments we received in this thread.

In particular we have:

  • Updated the simulations to include more highly imbalanced simulation scenarios;
  • Consistently used the {posterior} package to calculate R hat etc. for all methods, and
  • Tightened the text and expanded on the motivations for this work.

Thanks once again for your helpful comments!

Best,
Jeffrey

Hi all,

Following on from @jeffreypullin, I’d like to add a bit more detail on some of the changes between the first and second versions of our preprint, and their (non-)impact on our conclusions:

  1. For the two-component Gaussian mixture model (Section 2.2.1), we made Dataset 2 more imbalanced (mixture proportions of 0.9 and 0.1, previously 0.7 and 0.3). We found that this imbalance didn’t have a noticeable impact on the relative performance of the models (see the new paragraph at the bottom of page 9, and the updated Figure 2). The same set of models was still tied for best performance (as measured by the time per minimum effective sample size) when comparing our two versions, and also when comparing the results for Dataset 1 and Dataset 2.

  2. As @jeffreypullin already mentioned, we switched to using the posterior package for all calculations of R-hat and effective sample size. The main thing we noticed with this change was that we got many fewer large R-hat values. Otherwise, the relative performance of the models was largely still the same, and thus our conclusions remain unchanged.

  3. For the Dawid-Skene model, we switched from using the rater R package to using a Stan model file directly, in order to add several model tweaks and optimisations suggested by @Bob_Carpenter. I don’t actually recall how much of an impact those made on their own; the results in our preprint include both this change and the switch to using the posterior package. The overall result was similar to before: the JAGS (non-marginalised) model was just slightly ahead of Stan (marginalised).

Cheers,
Damjan
