Thanks for sharing, @wenkatezhang. I’m especially happy you made the code repo open access.
I have some alternative interpretations and suggestions for more precise evals.
It’s not just Stan, nor is it just folk wisdom in the larger MCMC community—there’s the Rao-Blackwell theorem, which says that an estimator built from conditional expectations (i.e., marginalizing the discrete parameters analytically) has no higher variance than one built from the sampled values themselves.
As an example, we know that discrete sampling for small probabilities requires a very large number of draws to get any kind of reasonable relative error in the estimates. Try to estimate theta = 0.01 by averaging i.i.d. draws y ~ bernoulli(theta)—you’ll need thousands of draws. Absolute error is fine (the answer is near zero), but relative error is going to take a lot of discrete draws. You can expect around 10% relative error with 10K iterations. And that’s only 0.01. Table 2 of the paper shows that the smallest discrete probability being tested is 0.20. This is the easiest possible case for sampling—making those mixture distributions more skewed would make this much harder to fit with discrete sampling.
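Here's a quick simulation of that claim (my own sketch, not code from the paper), using plain Monte Carlo draws:

```python
# Relative error of estimating theta = 0.01 by averaging i.i.d. Bernoulli
# draws, replicated 1,000 times to get a stable error estimate.
import numpy as np

rng = np.random.default_rng(1234)
theta = 0.01
for n in (100, 1_000, 10_000, 100_000):
    # 1,000 independent replications of an n-draw Monte Carlo estimate
    estimates = rng.binomial(n, theta, size=1_000) / n
    rel_rmse = np.sqrt(np.mean((estimates - theta) ** 2)) / theta
    print(f"n = {n:>8,}: relative RMSE = {rel_rmse:.2f}")
```

You get roughly 10% relative RMSE at 10,000 draws, and it only shrinks at the usual 1/sqrt(n) Monte Carlo rate.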
You can see a textbook example of that in the section I wrote on changepoint models in the Stan User’s Guide, where marginalization lets you compute tail statistics in expectation (Rao-Blackwell style) that would be intractable to estimate by sampling the discrete parameter directly.
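To make the idea concrete, here’s a minimal sketch of the Rao-Blackwell step in generic numpy (this is not the User’s Guide code; `log_joint` is a hypothetical array you’d compute from each posterior draw of the continuous parameters):

```python
import numpy as np
from scipy.special import logsumexp

def latent_probs(log_joint):
    """log_joint[m, t] = log p(y, s = t | params[m]) for posterior draw m of
    the continuous parameters and each value t of the discrete latent s.
    Returns Pr[s = t | y], averaged over draws (Rao-Blackwellized)."""
    # exact conditional p(s = t | y, params[m]) for each draw m
    per_draw = np.exp(log_joint - logsumexp(log_joint, axis=1, keepdims=True))
    return per_draw.mean(axis=0)

# Tail statistics then come in expectation, e.g. Pr[s > t0 | y]:
#     probs = latent_probs(log_joint)
#     tail = probs[t0 + 1:].sum()
```

Even when Pr[s = t | y] is tiny, this estimate has no Monte Carlo error beyond what’s already in the draws of the continuous parameters, whereas estimating it by counting sampled values of s runs into exactly the relative-error problem above.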
Our main computational concern with marginalizing is that the marginal density will be poorly conditioned in a way that varies around the posterior. We’re specifically worried about introducing stiffness in the Hamiltonian dynamics (varying time scales leading to bad conditioning, and thus requiring small step sizes). We haven’t found that in most problems.
The evals that Zhang et al. do in the paper are very disadvantageous for full sampling. In section 2.3.1, the paper indicates that the evals turn off all the conjugate samplers and GLM optimizations, which forces JAGS to fall back to slice sampling for each conditional update. This is a huge hit in both efficiency and accuracy.
Table 3 in the paper shows why there can be good-looking ESS estimates for JAGS even when Figures 1 and 3 show R-hat values indicating non-convergence. The Coda-based ESS estimator is very optimistic and assumes convergence (i.e., R-hat = 1). It calculates an ESS estimate per chain and adds them up. This systematically overestimates ESS (as measured by calibration tests of squared error) when the chains have not converged (i.e., R-hat >> 1).
Stan’s more conservative ESS estimator (e.g., as found in the arviz package in Python or the posterior package in R) discounts ESS for non-convergence. Our estimators are much better calibrated against squared error (the distribution of which has the MCMC standard error as its scale), which is what you want. Also, I believe coda wrongly caps ESS at the number of iterations; in principle, ESS can be higher than that with antithetic sampling like you get from NUTS. We very often get ESS values that are higher than the number of draws (and yes, it’s the right answer as measured by squared error).
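Here’s a small illustration of the difference (my own sketch, not the paper’s setup), using two deliberately non-mixing chains:

```python
import arviz as az
import numpy as np

# illustrative data only: two chains that have clearly not mixed
rng = np.random.default_rng(0)
chain1 = rng.normal(0.0, 1.0, size=1_000)   # chain stuck near 0
chain2 = rng.normal(5.0, 1.0, size=1_000)   # chain stuck near 5
chains = np.stack([chain1, chain2])          # shape (chain, draw)

# coda-style: estimate ESS within each chain separately, then add; this comes
# out near the total number of draws even though the chains never mixed
per_chain_sum = az.ess(chain1[None, :]) + az.ess(chain2[None, :])
print("sum of per-chain ESS:", per_chain_sum)

# Stan-style estimator (as in arviz / posterior): rank-normalized and aware of
# between-chain variance, so it collapses to a handful of effective draws
print("multi-chain bulk ESS:", az.ess(chains))
print("split R-hat:", az.rhat(chains))       # far above 1 here
```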
What I would conclude from Figures 1 and 3 is that JAGS has a hard time with convergence. This is consistent with my informal observations fitting the Dawid-Skene model with BUGS compared to marginalizing and fitting with HMC. Marginalizing took the fit time down from an unstable 24 hours in BUGS to a stable 20 minutes in Stan.
The paper says in the conclusion section (3.3):

> The stan model (which is marginalised) was more efficient than the JAGS models, highlighting the difference due to software engineering.
At least as of 10 years ago, Stan was coded much more efficiently than JAGS (Matt and I originally started rewriting JAGS to vectorize before we started working on Stan). But I think the bigger difference is due to the algorithm—Stan uses HMC, which scales well with dimension, whereas JAGS uses Gibbs, which does not. I think one of the confounding factors in this analysis is that JAGS is probably very inefficient at dealing with marginalization unless they’ve come up with a better way to do it than the ones-and-zeros tricks.
I also didn’t understand the point of reporting speed independently of ESS. All that really matters for how well a system mixes is ESS per second after convergence. I think an even more important measure for most applications is the time it takes to reach min ESS = 100, but that’s even harder to measure stably.
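As a concrete example of the two measures I mean (illustrative numbers only, not from the paper), they’re related by a simple extrapolation:

```python
def ess_per_second(ess, sampling_seconds):
    """Mixing rate after convergence: effective draws per second."""
    return ess / sampling_seconds

def time_to_target_ess(min_ess, sampling_seconds, target=100):
    """Rough wall-clock time to reach ESS = target for the worst-mixing
    parameter, assuming ESS grows roughly linearly with sampling time."""
    return target * sampling_seconds / min_ess

# e.g., a fit whose worst-mixing parameter has ESS = 400 after 80 s of sampling
print(ess_per_second(400, 80))       # 5.0 effective draws per second
print(time_to_target_ess(400, 80))   # 20.0 seconds to reach ESS = 100
```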
The main conclusion section (3.3) ends with this:
> the marginalised stan model had computational efficiency similar to jags-full shows that, once again, the software implementation is a more substantial factor in performance; in this case, the benefit was enough to counteract the seeming disadvantage of marginalization.
I don’t think it’s the software per se; it’s the sampler—HMC vs. Gibbs. Also, stan appears to be about twice as efficient as jags-full according to Figures 1 and 3. Even so, these figures underestimate Stan’s advantage because of the overestimation bias of Coda’s ESS estimator.
In conclusion, I’d like to see this study redone with
- Stan’s ESS estimator for both Gibbs and NUTS
- more skew in the mixture probabilities
- larger numbers of parameters
- JAGS allowed to do whatever it can to speed things up (I suspect this would massively speed up jags-full)
More parameters, because HMC’s advantages grow as the number of parameters grows. Letting JAGS do whatever it can is a fairer comparison of conditional sampling approaches.
I’d also like to see it redone with more efficient Stan code. I created an issue in their repo with efficiency suggestions for the Stan programs.