I’m writing to notify the Stan developers and broader community about a new NUTS variant that may be of interest. This is a follow-up to @Bob_Carpenter's recent blog post, with a reprex and some more details. The basic idea is to precondition (i.e., decorrelate and descale) the posterior using sparse matrix algebra before sampling with Stan’s NUTS algorithms. In the presence of high correlations or large differences in marginal scales, this can substantially increase efficiency (minimum ESS per unit time, minESS/t) over typical NUTS defaults. We call it sparse NUTS (SNUTS).
Stan does not have the required sparse infrastructure, so we implemented and tested SNUTS in a platform called Template Model Builder (TMB). As in Stan, TMB users write a function that calculates the unnormalized log target density, and TMB then uses automatic differentiation (AD) to calculate gradients. TMB utilizes sparsity in the data inputs, multivariate density evaluations, and the Laplace approximation to the marginal posterior. Marginalization can be applied to arbitrary subsets of parameters specified by the user, but typically it is applied just to what we call “random effects,” using the frequentist language adopted by TMB.
The central idea is that the marginal posterior geometry of hierarchical models is much better behaved, so optimization is easy to do. Further, information about the global geometry is available at the conditional mode in the form of what we call the “joint precision” matrix Q, where “joint” refers to the whole parameter space and “precision” to the inverse covariance of a multivariate normal. Q can thus sometimes approximate the global geometry of hierarchical posteriors, and it is sparse due to TMB’s ability to automatically detect conditional independence of parameters.
Estimating Q prior to MCMC sampling has some overhead (optimization, calculation of Q, testing for correlations, etc.), but has some distinct advantages:
- It can be used to approximate the posterior during early model development. We found this generally outperformed Pathfinder.
- It can be used to precondition the posterior if large correlations or marginal scales are found.
- It provides a way to initialize NUTS chains by drawing samples centered on the mode, e.g. from N(\hat{x}, Q^{-1}).
- Together, the second and third points generally eliminate the need for a long warmup with adaptation of a mass matrix. Because the model is already descaled, adaptation of a diagonal mass matrix is not necessary. The warmup only needs to be long enough to tune the step size and for chains to move to the typical set, which is fast due to the informed initial values. We found that 150 warmup iterations were sufficient in most cases studied.
All of these are possible without inverting Q, which allows SNUTS to scale to very high dimensions that are simply not feasible with a dense mass matrix. I show this below with a simple reprex.
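As a side illustration of the "without inverting Q" point (separate from the reprex that follows), here is a minimal sketch of preconditioning with a sparse Cholesky factor. It is not the SparseNUTS implementation. The toy precision matrix, the zero mode `xhat`, and the helper names `to_x` and `grad_to_z` are all assumptions made up for this example; only the `Matrix` package calls are real. With Q = R'R, the change of variables z = R(x - \hat{x}) maps N(\hat{x}, Q^{-1}) to a standard normal, and both directions need only sparse triangular solves.

```r
library(Matrix)

set.seed(1)
n <- 1000  # number of site-level effects (hypothetical size)

# Toy "arrowhead" precision matrix: n conditionally independent effects that
# are each coupled to one shared parameter. The shared parameter is placed
# last so that the Cholesky factor stays sparse. Only the lower triangle is
# supplied because symmetric = TRUE.
Q <- sparseMatrix(
  i = c(1:n, rep(n + 1, n), n + 1),
  j = c(1:n, 1:n,           n + 1),
  x = c(rep(2, n), rep(0.5, n), n),
  dims = c(n + 1, n + 1), symmetric = TRUE
)
xhat <- rep(0, n + 1)  # stand-in for the mode (assumed, not estimated here)

R <- chol(Q)  # sparse upper-triangular factor with t(R) %*% R equal to Q

# Change of variables: z = R (x - xhat)  <=>  x = xhat + R^{-1} z.
# Only a sparse triangular solve is needed; Q^{-1} is never formed.
to_x <- function(z) xhat + as.numeric(solve(R, z))

# Gradient in z-space via the chain rule: grad_z = R^{-T} grad_x
grad_to_z <- function(grad_x) as.numeric(solve(t(R), grad_x))

# Dispersed initial values for 4 chains, x ~ N(xhat, Q^{-1})
inits <- replicate(4, to_x(rnorm(n + 1)))

# Check on a Gaussian target with precision Q, where grad_x = Q (x - xhat):
# in z-space the gradient should reduce to z itself.
z <- rnorm(n + 1)
grad_x <- as.numeric(Q %*% (to_x(z) - xhat))
max(abs(grad_to_z(grad_x) - z))  # close to zero, up to rounding error
```

For a real posterior the target is not exactly Gaussian, so the transformed density is only approximately standard normal; the sampler still has to tune a step size in z-space.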
Further details can be found in this preprint. We built an R package called SparseNUTS that implements this, and I use it here to compare SNUTS and Stan on a very simple Poisson GLMM with iid site-level effects, fit to simulated data. First I define the models in Stan and TMB. Here I use the RTMB interface to TMB, which lets users write a function in plain R; that function is used to determine the computational graph, which is then executed in C++ (see details here). To be clear, the model below is not executed in R but rather in TMB’s C++ backend.
```r
library(RTMB)

# RTMB function
# `dat` is a list of data (C, site) assumed to exist in the calling environment
f <- function(pars){
  getAll(pars, dat)
  lp <-
    # random effect prior
    sum(dnorm(D, mean=0, sd=exp(logsigma), log=TRUE)) +
    # data likelihood
    sum(dpois(x=C, lambda=exp(logmu+D)[site], log=TRUE))
  return(-lp) # TMB requires negative log posterior density
}

stancode <- '
data {
  int<lower=0> nobs;    // Number of observations (length of C)
  int<lower=0> nsites;  // Number of sites (length of D)
  array[nobs] int C;    // Data vector of counts
  array[nobs] int site; // Index mapping observations to sites
}
parameters {
  real logsigma;        // log of the site-effect standard deviation
  real logmu;           // hypermean in log space
  vector[nsites] D;     // site-level effects on the log scale
}
model {
  D ~ normal(0, exp(logsigma));
  C ~ poisson_log(logmu + D[site]);
}
'
```
RTMB does not allow constraints in the parameter declaration, so for ease of comparison I wrote a matching Stan model with unconstrained parameters. I also exclude Jacobian adjustments and priors for simplicity. This is just a model that runs fast but can exhibit sampling issues due to being poorly conditioned. Specifically, this model has a strong negative correlation between the logmu parameter and the site-level effects D. In higher dimensions logmu is precisely estimated, so those correlations shrink, but many small correlations remain.
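Both model definitions need a data list and starting values to run. The following is a minimal sketch of simulated inputs in base R, not the code from the attached script. The number of sites, the observations per site, and the "true" values are hypothetical choices for illustration; the list element names simply match the Stan data block and the parameters used by both models.

```r
set.seed(2025)
nsites <- 100   # number of sites (hypothetical)
nper   <- 5     # observations per site (hypothetical)
site   <- rep(seq_len(nsites), each = nper)
nobs   <- length(site)

# "True" values, used only to simulate the counts
logmu_true    <- 1
logsigma_true <- log(0.5)
D_true <- rnorm(nsites, mean = 0, sd = exp(logsigma_true))
C <- rpois(nobs, lambda = exp(logmu_true + D_true[site]))

# Data list with names matching the Stan data block and the RTMB function
dat <- list(nobs = nobs, nsites = nsites, C = C, site = site)

# Starting values with names matching the parameters of both models
pars <- list(logsigma = 0, logmu = 0, D = rep(0, nsites))
```

With RTMB loaded, something like `RTMB::MakeADFun(f, pars, random = "D")` would then build the TMB objective with `D` marked for marginalization, and `dat` can be passed as the `data` argument when sampling with cmdstanr.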
I fit the model with cmdstanr (using Stan defaults and 4 parallel chains) and with RTMB and SparseNUTS, increasing the dimension (number of sites) while tracking the minimum bulk ESS, the wall time, and the mean post-warmup trajectory length. Compilation time was excluded for Stan and is non-existent for RTMB. I also ran a Stan model with a sum-to-zero constraint on D, `sum_to_zero_vector[nsites] D;`, which mitigates the correlations but changes the model interpretation. Attached is a script to replicate this analysis, which produces this plot.
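For readers unfamiliar with the efficiency metric, here is a minimal sketch of how minESS/t could be computed with the `posterior` package. The helper name `min_ess_per_sec` is hypothetical, and the draws and the timing below are placeholders generated for illustration, not results from the experiment.

```r
library(posterior)

# Efficiency: smallest bulk ESS across parameters divided by wall time (seconds)
min_ess_per_sec <- function(draws, seconds) {
  ess <- summarise_draws(draws, "ess_bulk")$ess_bulk
  min(ess) / seconds
}

# Placeholder draws: 1000 iterations x 4 chains x 3 parameters of white noise
set.seed(1)
x <- array(rnorm(1000 * 4 * 3), dim = c(1000, 4, 3),
           dimnames = list(NULL, NULL, c("logsigma", "logmu", "D[1]")))
draws <- as_draws_array(x)

eff <- min_ess_per_sec(draws, seconds = 2.5)  # 2.5 s is a made-up wall time
eff
```

For a cmdstanr fit, the draws would come from `fit$draws()` and the wall time from `fit$time()$total`.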
My takeaways from this simple experiment:
- SNUTS is able to precondition the posterior and lower the post-warmup trajectory lengths compared to Stan. This highlights the potential of adding SNUTS to the Stan code base and is the main purpose of this post.
- SNUTS is slower than Stan on a per-gradient basis. This is no surprise given the optimized nature of Stan’s code and the overhead of passing the objective and gradient functions in R when using SNUTS. In higher dimensions this overhead is less important.
- Both SNUTS and the sum-to-zero version of Stan can sample this model effectively up to at least 16k sites, but the latter does so at the cost of changing the model (see this post). The standard Stan model struggles to produce effective samples, which is a known issue with correlated posteriors. SNUTS is invariant to these types of global covariances without changing the population model and is thus a more general solution.
In the manuscript we tested SNUTS on a wider set of hierarchical models and found that high correlations and sparsity are common (see Table 1 of the manuscript), so I believe these results would translate to a large class of models used by the Stan community. We also showed a few models where SNUTS fails because the global geometry is not well approximated by a normal distribution, so preconditioning does not help (e.g., an individual response theory model). We also tested the embedded Laplace approximation (ELA) approach on the case studies, which may be of interest to some but is not covered here.
SNUTS is not currently compatible with Stan because Stan lacks sparse matrix support, and specifically the automatic detection of conditional independence. A quick look at the Discourse topics shows that there has been a lot of discussion of and interest in these topics (e.g., this thread). TMB+SNUTS demonstrates the types of advantages that could be had in Stan with the adoption of sparse methods and marginalization with the Laplace approximation. I will say that my colleagues and I in the fisheries science field have found SNUTS to be a huge step forward in our statistical workflow.
Note that Kasper Kristensen is the lead developer and mastermind of TMB and RTMB and thus all credit goes to him. My role was to join TMB and the Stan samplers via the simple R package SparseNUTS. For now, I’m happy to answer questions about the approach and hope that this thread can serve as a place to discuss sparsity and SNUTS pros and cons in Stan.
Thanks,
Cole
snuts_vs_stan_glmm.R (5.7 KB)



