Sparse NUTS: preconditioning with sparse matrix operations

I’m writing to notify the Stan developers and broader community about a new NUTS variant that may be of interest. This is a follow-up to @Bob_Carpenter’s recent blog post with a reprex and some more details. The basic idea is to precondition (i.e., decorrelate and descale) the posterior using sparse matrix algebra prior to sampling with Stan’s NUTS algorithms. In the presence of high correlations or large differences in marginal scales, this can substantially increase efficiency (minESS/t) over typical NUTS defaults. We call it sparse NUTS (SNUTS).

Stan does not have the required sparse infrastructure and so we implemented and tested SNUTS in a platform called Template Model Builder (TMB). Like Stan, TMB users write a function that calculates the unnormalized log target density and then it uses AD to calculate gradients. TMB utilizes sparsity in the data inputs, multivariate density evaluations, and the Laplace approximation to the marginal posterior. Marginalization is applied to arbitrary subsets of parameters specified by the user, but typically is just what we call “random effects” using the frequentist language adopted by TMB.

The central idea is that the marginal posterior geometry of hierarchical models is much better behaved, so optimization is easy to do. Further, information about the global geometry is available at the conditional mode in the form of what we call the “joint precision” matrix Q, where “joint” means the whole parameter space and “precision” means the inverse covariance of a multivariate normal. Q can thus sometimes approximate the global geometry of hierarchical posteriors, and it is sparse due to TMB’s ability to automatically detect conditional independence of parameters.

Estimating Q prior to MCMC sampling has some overhead (optimization, calculation of Q, testing for correlations, etc.), but has some distinct advantages:

  1. It can be used to approximate the posterior during early model development. We found this generally outperformed Pathfinder.
  2. It can be used to precondition the posterior if large correlations or large differences in marginal scales are found.
  3. It provides a way to initialize NUTS chains by drawing samples centered on the mode, e.g. from N(\hat{x}, Q^{-1}).
  4. 2 and 3 together generally eliminate the need for a long warmup with adaptation of a mass matrix. This is because the model is already descaled so adaptation of a diagonal mass matrix is not necessary. The warmup only needs to be long enough for chains to move to the typical set, which is fast due to informed initial values, and tune the step size. We found that 150 warmup iterations were sufficient in most cases studied.
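
As a rough sketch of points 2 and 3 (not the SparseNUTS implementation), the base R code below draws initial values from N(\hat{x}, Q^{-1}) and maps between the original and preconditioned spaces using only a Cholesky factor and triangular solves. The tridiagonal Q, the mode xhat, and the helper names draw_init, to_x and to_y are made up for illustration; a real application would presumably use a sparse Cholesky routine rather than the dense chol() used here.

```r
# Hypothetical toy inputs: a tridiagonal precision matrix Q and a mode xhat.
# Neither comes from a real model.
set.seed(1)
n <- 5
Q <- diag(2, n)
Q[cbind(1:(n - 1), 2:n)] <- -0.9   # superdiagonal
Q[cbind(2:n, 1:(n - 1))] <- -0.9   # subdiagonal
xhat <- rep(0.5, n)

# Upper-triangular Cholesky factor R with t(R) %*% R == Q
R <- chol(Q)

# Draw x = xhat + R^{-1} z with z ~ N(0, I), so that Cov(x) = Q^{-1}.
# backsolve() is a triangular solve; Q itself is never inverted.
draw_init <- function(xhat, R) {
  z <- rnorm(length(xhat))
  xhat + backsolve(R, z)
}
inits <- t(replicate(4, draw_init(xhat, R)))  # one row per chain

# The same factor defines the preconditioning transform:
# with x = xhat + R^{-1} y, the Gaussian approximation in y is N(0, I).
to_x <- function(y) xhat + backsolve(R, y)
to_y <- function(x) as.numeric(R %*% (x - xhat))
```

Because Q is tridiagonal, its Cholesky factor has only two nonzero diagonals, which is the kind of structure a sparse solver can exploit.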

All of these are possible without the need to invert Q which allows SNUTS to scale into very high dimensions that are simply not possible with a dense mass matrix. I show this below with a simple reprex.

Further details can be found in this preprint. We built an R package called SparseNUTS which implements this, and I use it here to demonstrate SNUTS vs Stan using a very simple Poisson GLMM with iid site-level effects fit to simulated data. First I define the models in Stan and TMB. Here I use the RTMB interface to TMB, which lets users write a function in plain R; that function is used to determine the computational graph, which is then executed in C++ (see details here). To be clear, the model below is not executed in R, but rather in TMB’s C++ backend.

# RTMB function
library(RTMB)
f <- function(pars){
  getAll(pars, dat) # unpack parameters (logsigma, logmu, D) and data (C, site)
  lp <-
    # random effect prior
    sum(dnorm(D, mean=0, sd=exp(logsigma), log=TRUE)) +
    # data likelihood
    sum(dpois(x=C, lambda=exp(logmu+D)[site], log=TRUE))
  return(-lp) # TMB requires negative log posterior density
}

stancode <-'
data {
  int<lower=0> nobs;          // Number of observations (length of C)
  int<lower=0> nsites;        // Number of sites (length of D)
  array[nobs] int C;          // Data vector of counts
  array[nobs] int site;       // Index mapping observations to sites
}
parameters {
  real logsigma;              // log of the random-effect standard deviation
  real logmu;                 // hypermean in log space
  vector[nsites] D;           // estimated log site means
}
model {
  D ~ normal(0, exp(logsigma));
  C ~ poisson_log(logmu + D[site]);
}

'

RTMB does not allow constraints in the parameter declaration and so for ease of comparison I wrote a matching Stan model. I also exclude Jacobian adjustments and priors for simplicity. This is just a model that runs fast but can exhibit sampling issues due to being poorly conditioned. Specifically, this model has a strong negative correlation between the logmu parameter and the site-level effects D. In higher dimensions logmu is precisely estimated so those correlations shrink, but there are a lot of small correlations.
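
To make the two model definitions concrete, here is a hypothetical simulator for this Poisson GLMM in base R. The number of sites, the observations per site, and the true parameter values are arbitrary assumptions and are not taken from the attached script; the list element names match what the Stan data block and the RTMB function f expect.

```r
# Hypothetical simulation settings (arbitrary, for illustration only)
set.seed(123)
nsites <- 50
obs_per_site <- 5
logmu <- log(10)
logsigma <- log(0.5)

# Site-level effects and the observation-to-site index
D <- rnorm(nsites, mean = 0, sd = exp(logsigma))
site <- rep(seq_len(nsites), each = obs_per_site)

# Counts: Poisson with log-mean logmu + D[site]
C <- rpois(length(site), lambda = exp(logmu + D[site]))

# Data list matching the Stan data block; f() reads C and site from `dat`
dat <- list(nobs = length(C), nsites = nsites, C = C, site = site)

# Starting values named as in f() and in the Stan parameters block
pars <- list(logsigma = 0, logmu = 0, D = rep(0, nsites))
```

Increasing nsites in a simulator like this is one way to scale the dimension of the problem while keeping the model structure fixed.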

I fit the model in cmdstanr (using Stan defaults and 4 parallel chains) and in RTMB with SparseNUTS, with increasing dimension (number of sites), while tracking min bulk ESS, wall time, and the mean post-warmup trajectory lengths. Compilation time was excluded for Stan and is non-existent for RTMB. I also ran a Stan model with a sum-to-zero constraint on D: sum_to_zero_vector[nsites] D; which mitigates the correlations but changes the model interpretation. Attached is a script to replicate this analysis which produces this plot.

My takeaways from this simple experiment:

  1. SNUTS is able to precondition and lower the post-warmup trajectory lengths compared to Stan. This highlights the potential of adding SNUTS to the Stan code base and is the main purpose of this post.
  2. SNUTS is slower than Stan on a per gradient basis. This is no surprise due to the optimized nature of Stan’s code, and the overhead of passing the objective and gradient functions in R when using SNUTS. In higher dimensions this overhead is less important.
  3. Both SNUTS and the sum-to-zero version of Stan can sample this model effectively to at least 16k sites, but with the latter at the cost of changing the model (see this post). The standard Stan model struggles to produce effective samples. This is a known issue with correlated posteriors. SNUTS is invariant to these types of global covariances without changing the population model and thus a more general solution.

In the manuscript we tested SNUTS on a wider set of hierarchical models and found high correlations and sparsity common (see Table 1 of the manuscript), so I believe these results would translate to a large class of models used by the Stan community. We also showed a few models where SNUTS fails because the global geometry is not well approximated by a normal distribution and thus preconditioning fails to help (e.g., an item response theory model). We also tested the embedded Laplace approximation (ELA) approach on the case studies, which may be of interest to some but is ignored here.

SNUTS is not currently compatible with Stan because Stan lacks sparse matrix support, and specifically the automatic detection of conditional independence. A quick look at the Discourse topics shows that there has been a lot of discussion and interest in these topics (e.g., this thread). TMB+SNUTS demonstrates the types of advantages that could be had in Stan with adoption of sparse methods and marginalization with the Laplace approximation. I will say that my colleagues and I in the fisheries science field have found SNUTS to be a huge step forward in our statistical workflow.

Note that Kasper Kristensen is the lead developer and mastermind of TMB and RTMB and thus all credit goes to him. My role was to join TMB and the Stan samplers via the simple R package SparseNUTS. For now, I’m happy to answer questions about the approach and hope that this thread can serve as a place to discuss sparsity and SNUTS pros and cons in Stan.

Thanks,

Cole

snuts_vs_stan_glmm.R (5.7 KB)

That looks really cool! Making use of more structural information about the model (like conditional independence) in the sampler sounds like a great way to improve performance long term.

I usually work with pymc instead of stan, where we have a lot of that information around in some way, but I don’t think we really make use of that possibility.

Out of curiosity I ran your benchmark script, but also included the low rank mass matrix adaptation from nutpie:

I had trouble with the multithreading in sparseNUTS though, so I disabled it. I guess that means that sparseNUTS looks about 4 times less efficient here than it actually should be?

Looks like nutpie ends up faster when the model is smaller than a few thousand parameters, and then sparseNUTS starts to take over slightly? I think that makes sense, because while the runtime doesn’t grow that much with the low rank mass matrix, the space of covariance matrices really does grow, so the structure enforced in sparseNUTS starts to really help with the estimation? I’ll have to try a sparse version of the nutpie mass matrix adaptation at some point as well… :-)

As a small aside: You said that the zero-sum-normal parametrization changes the model. This is true if you do it by just replacing the prior. It is really not difficult to do it in a way that we keep the exact same model. ( Proper use of sum_to_zero_vector in nested multilevel models - #17 by aseyboldt ). This also generalizes to models with more levels, interaction effects and correlated coefficients.

script.R (9.0 KB)

@aseyboldt Thanks for the feedback. That’s a good idea to try it on nutpie. I haven’t used that actually but Bob keeps raving about it so I should have thought to do it.

Could you please file an issue about the parallel failure here with details about your platform?

The efficiency for nutpie is not comparable here because you don’t monitor lp__, which we do for Stan and SNUTS. For SNUTS, lp__ is always the slowest mixing parameter. We both have anticorrelation which leads to the bulk ESS being higher than the nominal iterations (ESS>4000) in both cases. Could you rerun it with lp__ being monitored too? I wonder if switching to ess_tail would also make more sense? That may be getting into the weeds a bit when the main point is about the algorithm of using new information to improve sampling.

Fair point about the sum-to-zero approach. I’m not an expert. I do find it cleaner to just write out the model in standard notation and have the correlations handled by the preconditioner.

@monnahc and I have discussed this, but I’ll share here. I agree that we’d ideally just have a way to write down any proper posterior density and fit it. Centered or non-centered, sum-to-zero or unconstrained, whatever.

I do think in cases where the varying effects are exhaustive, the sum-to-zero approach makes more conceptual sense. For example, if there are five levels for age (different age bands), then you only have 4 degrees of freedom if you also have an intercept. The traditional approach is to pin one value to zero to identify, but sum-to-zero is nicer in that it’s symmetric if you want to include a prior.

Consider a two-level case, like sex being male/female. You could have two varying effects, one for male and one for female, but that’s not identified if you also have an intercept. You can identify with a prior as Cole prefers, but that’s inefficient and you have to choose reasonable priors. The traditional approach is to set one level’s effect to zero (e.g., male), which works out to be identical to using sex as a fixed effect with a covariate of 0 for male and 1 for female. An alternative is to use sum-to-zero, where you get the male effect being the negation of the female effect. Provides the same degree of identification, but now it’s symmetric.
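
The identification argument can be checked numerically. The sketch below uses made-up data (the sex and y vectors are hypothetical) and base R design matrices to show that an intercept plus one effect per level is rank deficient, and that pinning one level to zero and sum-to-zero coding give the same fit.

```r
# Hypothetical data: a two-level factor and a response
sex <- factor(c("male", "female", "female", "male", "female"),
              levels = c("male", "female"))
y <- c(1.2, 2.9, 3.1, 0.8, 3.0)

# Intercept plus one effect per level: three columns but only rank 2,
# so the three coefficients are not identified by the likelihood alone
X_full <- cbind(intercept = 1,
                male = as.numeric(sex == "male"),
                female = as.numeric(sex == "female"))
rank_full <- qr(X_full)$rank

# Pin the male effect to zero: the covariate is 0 for male, 1 for female
X_pin <- model.matrix(~ sex, contrasts.arg = list(sex = "contr.treatment"))

# Sum-to-zero coding: the female effect is the negation of the male effect
X_sum <- model.matrix(~ sex, contrasts.arg = list(sex = "contr.sum"))

# Both codings span the same column space, so least-squares fits agree
fit_pin <- unname(lm.fit(X_pin, y)$fitted.values)
fit_sum <- unname(lm.fit(X_sum, y)$fitted.values)
all.equal(fit_pin, fit_sum)
```

The two codings differ only in how the coefficients are interpreted, not in the fitted means.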

Where this is problematic is when it’s open ended. Like we have varying effects for 100 athletes or 8 schools or 15 clinical trials and we’re really interested in the larger population. Then sum-to-zero alone doesn’t really make sense without finagling it to allow generation of new groups. @andrewgelman said that was possible, and I keep meaning to ask him how, because I don’t see it. Maybe @aseyboldt knows.

@monnahc Good point about the logp, I’ll run that tomorrow.

@Bob_Carpenter The post I linked above had an example of how to use the ZeroSumNormal strictly as a reparametrization.

If we just write the effect as a normal distribution effect ~ normal(0, sigma), we measure the effect size relative to the population mean. But we can decompose this into an effect relative to the subpopulation mean mean(effect) of the levels that we have in our dataset, and the scalar subpopulation mean itself:

subpop_effect[i] ~ zero_sum_normal(sigma)
subpop_mean ~ normal(0, sigma / sqrt(num_effects))

The old effect is then just effect[i] = subpop_effect[i] + subpop_mean.

And zero_sum_normal is the MVNormal that has eigenvalues sigma^2 in all directions except (1,1,1…), in which the variance is zero.

And we can get rid of subpop_mean by just absorbing it into the intercept (or another higher level coefficient).

And yes, I also think that there are many cases where we actually want to know subpop_effect, not effect. But if we can sample the model written in terms of effect, we can always just compute subpop_effect[i] = effect[i] - mean(effect) anyway. The other direction isn’t exactly difficult, but can get quite annoying if you have multiple levels or interaction effects.

Another way of thinking about this is that as long as we just add up linear functions of normal distributions, we can always think of it as one big multivariate normal distribution.
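
A quick numerical check of this decomposition in base R, under the description given above: zero_sum_normal is taken to have covariance sigma^2 times the projector onto sum-to-zero vectors, and the values of n and sigma are arbitrary.

```r
n <- 4          # number of levels (arbitrary)
sigma <- 1.3    # effect standard deviation (arbitrary)
ones <- rep(1, n)

# Covariance of subpop_effect: variance sigma^2 in every direction
# orthogonal to (1, ..., 1), and zero variance along (1, ..., 1)
P <- diag(n) - tcrossprod(ones) / n   # projector onto sum-to-zero vectors
cov_subpop_effect <- sigma^2 * P

# subpop_mean ~ normal(0, sigma / sqrt(n)) is added to every element
cov_subpop_mean <- (sigma^2 / n) * tcrossprod(ones)

# The two pieces are independent, so their covariances add
cov_effect <- cov_subpop_effect + cov_subpop_mean
all.equal(cov_effect, sigma^2 * diag(n))

# Eigenvalues of the zero-sum part: sigma^2 repeated n - 1 times, then 0
eig <- eigen(cov_subpop_effect, symmetric = TRUE)$values
```

The sum recovers the covariance of iid normal(0, sigma) effects, which is why the decomposition is a reparametrization rather than a different model.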

I guess I really should write that down properly at some point. Also for the case with interactions and correlated coefficients.

@aseyboldt If you send me an updated script with logp tracked I can run on my end and update the benchmark.

Note that SNUTS can deal with correlations/scales caused by other things too. For instance the ‘kilpisjarvi’ model in posteriordb has three parameters (and is not hierarchical) but has a max pairwise correlation of 0.9999887 and difference in marginal scale of 3982.5. These extremes are no problem at all for SNUTS. It samples hundreds of times more efficiently because the trajectories are a hundred times shorter than NUTS. See the attached figure for the main results from the paper.

Many of these models from our preprint likely aren’t feasible in Stan (e.g., SPDE ones). If we wanted to do a more thorough comparison of Stan, nutpie and SNUTS we could port more posteriordb models over. I only did ‘diamonds’, ‘gp_pois_regr’ ‘kilpisjarvi’, and ‘radon’.

fig6_perf_mods.pdf (31.8 KB)

Sorry, I was traveling for a few days.

Here is the updated plot that includes the lp ess (script attached)

I’d also be interested to see what this looks like for more models. I would expect nutpie to do decently well in low dim problems, but as the dimensionality increases, I think adding more structure could really help.

script.R (9.4 KB)

Thanks for doing it, it’s much clearer now. It motivates me to keep working on minimizing the overhead of SNUTS. I also thought about adding compile time, which doesn’t exist for RTMB. In low dimensions compiling often takes longer than sampling so in that sense SNUTS would do very well too.

Do you have some bigger models in mind? In my previous post I named the four from posteriordb I ported over, but they’re not that big. I couldn’t find a summary or easy way to summarize all the posteriors in the package. It would be interesting to pick some with high dimension and high condition factor. If you have some in mind I could write them in RTMB and we could try it out. Or if you can do SPDE models in nutpie then we could do those too, as they scale really easily in dimension and have complex correlations.

Not sure if this is of any help, but are there any models in particular you are looking for?

I think if this is particularly applicable to hierarchical models, it might be easy to just generate a bunch of Stan code for hierarchical (multilevel) models using brms to test, to see what happens. This might be fast.

There’s a boatload on my computer, but some are just simple intermediate models building up to a final one that’s proprietary, there’s some of the GP models transferred from GPStuff years ago (it’s all public on my website) or you could scrape the documentation. I just skimmed the paper. I don’t think anyone’s added to posterior DB for a while. Adding models with GP priors might be a waste of resources because it would take longer. No idea.

SNUTS can’t simply run all models in posteriordb, because it needs extra information that’s not directly visible from a stan model, so it needs some recoding.
But apart from that concrete question, I think it would be incredibly valuable to add more models to posteriordb. It has a lot of small models, and a few very big but close to intractable models. Adding more big, but more or less still approachable ones as might come out of brms would help a lot.

It’s probably easy to add lots of models; I think no one has uploaded any models. I’m not funded to do this, and I’m working on C++, which is more beneficial for my skills and career. Contributing to posteriordb is just annoying, and would require me to re-run data simulations so that the models could run, but I could copy-paste some models for now. There are some here:

For brms, I was thinking it would be easy to automate the implementation of many hierarchical linear models, if you want to quickly test your current research on a bunch of hierarchical models.

Stan doesn’t have the sparse machinery of TMB so it does require recoding so we can’t just grab brms models. One thing we could try is to use a TMB GLMM package (glmmTMB) to approximate brms, but the parameterizations and priors would be different so that’s tough too. One of the points of this thread is to be a place to discuss whether it would be worth the very large effort to add it to the Stan source code so we could compare NUTS vs SNUTS on all models.

So that makes these comparisons harder to do since we either have to convert TMB models to Stan or Stan models to TMB. I think the latter makes more sense on this forum. Note that they don’t have to be large. Even small models with really high condition factors are massively improved with SNUTS. E.g., the kilpisjarvi and diamonds models are relatively simple but are terribly conditioned. TMB and SNUTS have no issues though because of the linear nature of them.

Here are the RTMB ports of the diamonds and kilpisjarvi models (Stan equivalents here and here):

kilpisjarvi_dat <- readRDS('models/kilpisjarvi/dat.RDS')
pars <- list(alpha=1, beta=1, logsigma=0)
f <- function(pars){
  require(RTMB)
  getAll(pars,kilpisjarvi_dat)
  sigma <- exp(logsigma)
  lp <-
    dnorm(alpha, pmualpha, psalpha, log=TRUE) + # hyper prior
    dnorm(beta, pmubeta, psbeta, log=TRUE) +    # hyper prior
    sum(dnorm(y, alpha+beta*x, sigma, log=TRUE)) + # likelihood
    logsigma # jacobian adjustment
  REPORT(lp)
  return(-lp)
}
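
A small base R sketch of why a three-parameter regression like this can be so badly conditioned. It assumes, hypothetically, that x is an uncentered calendar year (1952 to 2013 is an arbitrary choice, not read from dat.RDS), that sigma is fixed, and that the priors are flat, so that the precision of (alpha, beta) is proportional to X'X. It is not the SNUTS code, only an illustration of the geometry.

```r
# Hypothetical covariate: calendar years, deliberately left uncentered
x <- 1952:2013
X <- cbind(alpha = 1, beta = x)

# With sigma fixed and flat priors, the posterior precision of
# (alpha, beta) is proportional to X'X
Q <- crossprod(X)
R <- chol(Q)                      # t(R) %*% R == Q

# Correlation and marginal scales implied by Q
Sigma <- chol2inv(R)              # Q^{-1}, computed from the factor
corr <- cov2cor(Sigma)[1, 2]      # very close to -1
sd_ratio <- sqrt(Sigma[1, 1] / Sigma[2, 2])

# Precision after the change of variables y = R %*% (theta - mode):
# R^{-T} Q R^{-1}, which should be the identity up to rounding
Rinv <- backsolve(R, diag(2))
Q_y <- t(Rinv) %*% Q %*% Rinv
```

In the transformed coordinates the correlation and the scale difference both disappear, which is the sense in which a linear model poses no difficulty for a preconditioner built from Q.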

For diamonds I preprocess the data in R which replaces the ‘transformed data’ section in Stan.

f <- function(pars){
  require(RTMB)
  getAll(pars,diamonds_dat)
  ## transformed parameters
  sigma <- exp(logsigma)
  lp <- logsigma + ## Jacobian
    ## priors
    sum(dnorm(b, 0, 1, log=TRUE)) +
    sum(dt((Intercept-8)/10, 3,log=TRUE) )+ # dropped constants
    sum(dt(sigma/10,df=3, log=TRUE))
  ## likelihood including all constants
  mu_hat <- Intercept + as.numeric(Xc%*%b)
  if (!prior_only)  lp <- lp+sum(dnorm(Y, mean=mu_hat, sd=sigma, log=TRUE))
  ## generated quantities
  REPORT(mu_hat)
  b_Intercept <- Intercept - sum(means_X*b)
  REPORT(b_Intercept)
  REPORT(lp)
  return(-lp) # TMB expects negative log posterior
}

Yeah, I’m not funded for research, so I’m just working on C++ right now, but if anyone wants to add the models, I made a PR on posteriordb, and I can make a Google Drive folder which should have all the R files generating data.

Cole, I’m not seeing these scripts on posterior DB.

But forgetting software dev, if your research paper was meant to estimate hierarchical models quickly.

Couldn’t you simulate data: one continuous covariate, and then for model1: rbinom(N, 1000, .5), then for model2, keep the same two covariates and add an additional rbinom(N, 1000, .35), and then keep decreasing prob so we have more sparse categorical covariates for each model? I haven’t thought about theory, I’m just using intuition, but you can create a multilevel model by x_1[cat1[i]], x_2[cat2[i]]'ing each one of these arbitrary categories, no? So we have covariates with increasingly sparse categories for each column. Maybe my intuition is wrong?

So the model block, in pseudocode, might be as follows, with x being the random continuous covariate, assumed to be standardized (subtract the mean, divide by the standard deviation) so that we’re omitting an intercept…

{
   y ~ rnormal_lpdf(beta_1 * x[cat1[i]] + beta_2 * x[cat2[i]] + ..., blah);
}

So just an arbitrarily generated hierarchical model? Anyone see me?

I’d have to think more if I wanted to do something serious. The simulation I’m simplifying a bit.

I haven’t worked with these R packages.

I put the RTMB code up just to demonstrate what porting a model looks like. It’s not too hard if it only uses basic functionality. I forgot to attach the links for diamonds and kilpisjarvi.

I see what you’re saying. But we don’t need to contrive sparsity like this. It arises naturally in hierarchical models because most random effects have sparse precision matrices, and this can be automatically detected. My main point here is that Stan (and other software) ignores this fact and it can provide really valuable information in a Bayesian analysis, in particular for preconditioning the posterior.

Below is a table from the MS that shows this. Of this set of models only kilpisjarvi and gp_pois_regr have no sparsity, the former because it’s not hierarchical, and the latter because GPs have dense precision matrices.

Note that both of the examples without sparsity have high correlations (high condition factors) but SNUTS can detect and precondition that prior to sampling and sample hundreds of times more efficiently. So SNUTS is not only good for big, sparse models, but it does seem that is the best use case for it.

Open a PR and we can investigate it more thoroughly? Sure I don’t disagree it’s well vetted that if there is a sparse precision matrix you can capitalize on the sparsity by reducing FLOPS in many applications. I still need to thoroughly read the paper. Did you benchmark this against previous release versions of stan?

Yeah, and for the latter I was just saying you can simulate datasets to test more robustly whether exploiting sparsity is, on average, more efficient and how beneficial it is. And I don’t doubt that it is. Do we have the runtime for each model? Moreover, these are special cases. So randomly generating hierarchical models and explicitly evaluating whether the two algorithms differ in speed after N trials would be helpful.

And then, max correlation? Between two random variables?

How did you compute correlation, via a posterior generated computationally? Aren’t we looking for a divergence metric? Like, take N(0,1) and then N(3,1), if we randomly sample from each distribution, wouldn’t taking the correlation of random samples be on average the same? I’d have to think. It’s the same thing just mean shifted, no? Are we quantifying this correctly?

moreover there are N parameters. wouldn’t there be variation in divergence between parameters?

and then the plot, “value relative to the original model,”?

And there’s overhead in comparing cmdstanR to different interfaces. I’m on a phone, and I still have to consider what the different metrics you’ve plotted means.

Do you have a branch implemented in C++ for this, or was this done in R? If so, we can open a PR and benchmark against existing Stan HMC, adjusting for differences in programming languages. I feel there are too many degrees of freedom, and max correlation between two RVs doesn’t make sense to me right now.

The MS does not focus on software at all, since all models are done in TMB (not Stan). Instead it compares standard NUTS (Stan defaults) vs SNUTS. Many of the models in the MS wouldn’t really work in Stan or be way too much work to port to Stan. If you have a hierarchical model in mind and write a simulator in R I will port it to RTMB and we can very easily compare NUTS vs SNUTS in the same software.

To get “max correlation” I take the posterior samples and calculate a correlation matrix (cor(as.data.frame(fit))), then find the max absolute off-diagonal element of that matrix. I also get it from the approximate precision matrix Q in R as cov2cor(solve(Q)), again taking the max absolute value across pairwise correlations. The point of that is to show that high correlations are widespread and also easily approximated via Q. Likewise for the ratio of marginal variances. Both of these cause high condition factors but it was nice to break it out like that, I thought. I’m not sure what you mean by “variation in divergence between parameters”… can you explain?
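
A self-contained sketch of those two calculations in base R. The 3 x 3 matrix Q is hypothetical, and the "posterior samples" are simulated from N(0, Q^{-1}) rather than taken from a fitted model, so fit here is only a stand-in for real draws; max_abs_cor is a helper name introduced for this example.

```r
set.seed(42)
# Hypothetical precision matrix standing in for Q
Q <- matrix(c( 4.0, -3.6, 0.00,
              -3.6,  4.0, 0.00,
               0.0,  0.0, 0.01), nrow = 3, byrow = TRUE,
            dimnames = list(c("a", "b", "c"), c("a", "b", "c")))

# Largest absolute pairwise (off-diagonal) correlation
max_abs_cor <- function(corr) max(abs(corr[upper.tri(corr)]))

# From Q: invert (fine for a small dense example) and convert
Sigma <- solve(Q)
cor_Q <- max_abs_cor(cov2cor(Sigma))
var_ratio_Q <- max(diag(Sigma)) / min(diag(Sigma))

# From draws: simulate stand-in posterior samples from N(0, Q^{-1})
R <- chol(Q)
draws <- t(backsolve(R, matrix(rnorm(3 * 4000), nrow = 3)))
colnames(draws) <- colnames(Q)
fit <- as.data.frame(draws)
cor_draws <- max_abs_cor(cor(fit))
vars <- sapply(fit, var)
var_ratio_draws <- max(vars) / min(vars)
```

When the posterior is close to normal, the Q-based and draw-based summaries should roughly agree, which is the comparison described above.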

The preconditioning is done in R, but TMB is a C++ backend and that’s where the sparse machinery is handled. If implemented in Stan it would all be internal and in C++ of course.

@aseyboldt Is there a way to take an R function and pass it through to nutpie similar to what StanEstimators does? If so I could add nutpie as an option for SparseNUTS and then we could more easily compare with the novel features (e.g., low rank mass matrix) on existing TMB models.

I don’t think that would be too difficult to implement, but it doesn’t exist yet. I think there might be some multithreading issues, as from what I know the R interpreter is single threaded? But I really don’t know R well at all.
I think nutpieR could implement something like that.

That’s outside my expertise too unfortunately. If it were possible I would be happy to run a bunch of TMB models and compare performance against SNUTS in the same platform.