What's the highest dimensional model Stan can fit using NUTS?

For my work in genomics, I want to fit a model that has about 10,000 parameters, and the posterior is probably going to be multimodal. I am aware that this may not be possible with any out-of-the-box software, so I am currently trying to assess my options and break the problem down into smaller questions. I combed through the examples on this site, and some seem to regard even 100 parameters as high-dimensional, but I don’t have much of a sense for that yet. So one of the smaller questions I have identified is: how quickly does Stan’s performance degrade as the dimension increases? And what exactly is the limiting factor that makes higher dimensions hard?

I ran the following code to test Stan on an iid Gaussian distribution of arbitrary dimension. It generates two samples, one from Stan and one using rnorm, and compares them by trying to discriminate between them with a random forest classifier.

Up to D=3, accuracy is about 50%, meaning Stan and rnorm are indistinguishable. At D=4 and 5, accuracy creeps up to 51%. At D=10, it is 53%; at D=100, 58%; at D=1000, 63%. That means there are regions where Stan’s samples or rnorm’s samples are enriched, but those regions affect only a small fraction of the data. In each case sampling is very fast; the random forest actually takes longer to run!
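To judge whether 51% really differs from a coin flip, a rough chance band helps. A minimal sketch, assuming the rstan defaults used in the benchmark (4 chains × (2000 − 1000 warmup) = 4000 Stan draws, plus 4000 rnorm draws) and treating each out-of-bag prediction as an independent fair coin flip, which is only an approximation:

```r
# Approximate central range of classifier accuracy under pure chance,
# treating each of the n_total out-of-bag predictions as a fair coin flip.
chance_accuracy_range <- function(n_total, level = 0.95) {
  alpha <- 1 - level
  qbinom(c(alpha / 2, 1 - alpha / 2), size = n_total, prob = 0.5) / n_total
}

# rstan defaults: 4 chains x 1000 post-warmup draws, plus as many rnorm draws.
n_total <- 2 * 4 * 1000
chance_accuracy_range(n_total)
```

Under this approximation the band is roughly 49% to 51%, so 51% sits at the edge of chance while 53% and above lie clearly outside it.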

This is an easy case (unimodal, with contours of constant curvature), but it’s still much better performance than I expected. It seems to me there’s nothing intrinsically bad about running HMC in high-dimensional parameter spaces, and that I should focus on other issues such as multimodality and funnels.

Do you agree with my interpretation? How would you make this experiment more informative?

Here’s my code (with RStan):

data {
  // Dimension will be provided by an R wrapper.
  int<lower=0> D;
}
parameters {
  // A big fat iid standard Gaussian
  vector[D] X;
}
model {
  for (d in 1:D) {
    X[d] ~ normal(0, 1);  // equivalently, vectorized: X ~ std_normal();
  }
}

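library(rstan)  # needed for the unqualified rstan_options() call below

stanfile = "stan_high_dim.stan"  # the Stan program above, saved under this name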

options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)

run_gaussian_benchmark = function(D){
  fitted_mc = rstan::stan(file = stanfile, data = list(D = D), seed = 0)
  # Drop the last column (lp__), keeping only the D columns of X.
  samples_mc = as.matrix(fitted_mc)[, -(D + 1)]
  # Exact iid draws with the same shape, for comparison.
  samples_genuine = matrix(rnorm(prod(dim(samples_mc))), nrow = nrow(samples_mc), ncol = ncol(samples_mc))
  dimnames(samples_genuine) = dimnames(samples_mc)
  both = rbind(samples_mc, samples_genuine)
  label = c(
    rep(0, nrow(samples_mc)),
    rep(1, nrow(samples_genuine))
  )
  forest_fitted = randomForest::randomForest(x = both, y = factor(label))
  # The OOB confusion matrix has an extra class.error column; keep only the counts.
  counts = forest_fitted$confusion[, 1:2]
  conf = counts / sum(counts)
  # Out-of-bag accuracy; about 0.5 means the two samples are indistinguishable.
  acc = sum(diag(conf))
  plot(colMeans(samples_mc) - colMeans(both), forest_fitted$importance)
  return(list(forest = forest_fitted, acc = acc))
}
run_gaussian_benchmark(D = 2)
run_gaussian_benchmark(D = 3)
run_gaussian_benchmark(D = 4)
run_gaussian_benchmark(D = 5)
run_gaussian_benchmark(D = 10)
run_gaussian_benchmark(D = 100)
run_gaussian_benchmark(D = 1000)

It has been done with a million parameters (with favorable geometry). I would be much more worried about the multimodality than 10000 nominal parameters, although the MCMC algorithm in Stan can deal with a few modes if they are not too far apart.


As @bgoodri noted, the dominant consideration is the geometry of the target density function, not the number of parameters. The cost of a Stan fit roughly scales as the number of Markov chain Monte Carlo iterations you need to generate a sufficient effective sample size, times the number of gradient evaluations per iteration, times the cost per gradient evaluation.
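Written out schematically (a restatement of the sentence above, not an exact cost model):

$$
\text{cost} \;\approx\;
\underbrace{\frac{\text{iterations}}{\text{effective sample}} \times \text{ESS}_{\text{needed}}}_{\text{number of iterations}}
\;\times\; \frac{\text{gradient evaluations}}{\text{iteration}}
\;\times\; \frac{\text{cost}}{\text{gradient evaluation}}
$$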

The cost per gradient evaluation will generally scale with the number of parameters, and there’s nothing you can do about that other than consider parallelization. The map_rect functionality, for example, provides a way of taking advantage of parallel computation resources if they are available.
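Parallelization helps because, for conditionally independent data, the log density and its gradient are sums over observations: disjoint shards can be differentiated separately and the pieces added. A minimal base-R sketch of that map-then-reduce idea for a logistic regression; the function names and shard split are illustrative only, and this is not how map_rect is called from Stan:

```r
# Gradient of the Bernoulli-logit log-likelihood with respect to beta:
#   sum_i x_i * (y_i - plogis(x_i' beta))
loglik_grad <- function(x, y, beta) {
  as.vector(crossprod(x, y - plogis(x %*% beta)))
}

# Same gradient computed shard by shard and then summed (map, then reduce).
# Each shard could run on its own core; lapply keeps the sketch portable.
sharded_grad <- function(x, y, beta, n_shards = 4) {
  shard_id <- cut(seq_len(nrow(x)), n_shards, labels = FALSE)
  pieces <- lapply(split(seq_len(nrow(x)), shard_id), function(rows) {
    loglik_grad(x[rows, , drop = FALSE], y[rows], beta)
  })
  Reduce(`+`, pieces)
}

set.seed(1)
x <- matrix(rnorm(1000 * 20), 1000, 20)
beta <- rnorm(20, sd = 0.1)
y <- rbinom(1000, 1, plogis(x %*% beta))
all.equal(loglik_grad(x, y, beta), sharded_grad(x, y, beta))
```

The per-shard work still grows with the number of parameters; sharding only divides it across cores.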

The number of gradient evaluations per iteration is determined by the local geometry of the target density function. If your posterior is poorly identified and concentrates around some closed subsurface, then HMC will require long trajectories and hence many gradient evaluations. Similarly, if you have regions of high curvature, then step size adaptation might leave you with a small step size that induces many gradient evaluations even for moderate trajectory lengths.
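The step-size effect can be seen with a plain leapfrog integrator on a two-dimensional Gaussian with scales 1 and 0.01. For a Gaussian direction with scale sigma, leapfrog is only stable when the step size is below 2 * sigma, so the narrow direction forces a tiny step, and covering the wide direction then takes on the order of 100 steps. A minimal sketch assuming an identity mass matrix and a fixed step size (no NUTS tree building); Stan's diagonal mass-matrix adaptation would largely rescale this particular axis-aligned case, but not curvature that changes across the space:

```r
# Leapfrog integration of Hamiltonian dynamics for a zero-mean Gaussian
# target with independent coordinates of scale `sigma` (identity mass matrix).
# Returns the absolute change in total energy after n_steps.
leapfrog_energy_error <- function(sigma, eps, n_steps, q0 = sigma, p0 = c(1, 1)) {
  grad_U <- function(q) q / sigma^2                  # U(q) = sum(q^2 / (2 sigma^2))
  H <- function(q, p) sum(q^2 / (2 * sigma^2)) + sum(p^2) / 2
  q <- q0
  p <- p0
  H0 <- H(q, p)
  for (i in seq_len(n_steps)) {
    p <- p - eps / 2 * grad_U(q)   # half step in momentum
    q <- q + eps * p               # full step in position
    p <- p - eps / 2 * grad_U(q)   # half step in momentum
  }
  abs(H(q, p) - H0)
}

sigma <- c(1, 0.01)
leapfrog_energy_error(sigma, eps = 0.01, n_steps = 100)  # stable: bounded error
leapfrog_energy_error(sigma, eps = 0.03, n_steps = 100)  # eps > 2 * 0.01: diverges
```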

The number of iterations per effective sample is dominated more by the global geometry of the target density function. You might be able to explore the local neighborhood around a single mode efficiently, requiring only a few gradient evaluations per iteration, but nothing will help you jump between modes often enough to provide a decent effective sample size per iteration.
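Iterations per effective sample are governed by autocorrelation: for a chain of length N with autocorrelations rho_k, ESS is roughly N / (1 + 2 * sum_k rho_k), and a chain that rarely switches modes behaves like one with rho close to 1. A crude single-chain estimator, applied to an AR(1) chain whose exact value N(1 - rho)/(1 + rho) is known; the helper name is hypothetical and Stan's own estimator is more careful:

```r
# Crude effective sample size: N / (1 + 2 * sum of leading positive autocorrelations).
# Stan combines chains and uses Geyer's initial-sequence truncation;
# this only illustrates the idea.
ess_crude <- function(x, max_lag = 500) {
  rho <- acf(x, lag.max = max_lag, plot = FALSE)$acf[-1]  # drop lag 0
  first_nonpos <- which(rho <= 0)[1]
  if (!is.na(first_nonpos)) rho <- rho[seq_len(first_nonpos - 1)]
  length(x) / (1 + 2 * sum(rho))
}

set.seed(42)
n <- 10000
iid_chain <- rnorm(n)
ar_chain <- as.vector(arima.sim(model = list(ar = 0.9), n = n))

ess_crude(iid_chain)  # close to n
ess_crude(ar_chain)   # near n * (1 - 0.9) / (1 + 0.9), i.e. about 526
```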

For nicer geometries HMC scales really, really well with dimension (ignoring the cost per gradient evaluation). Honestly, it would be hard to fit a model on a computer that is big enough for the dimensional scaling to become relevant in those nicer circumstances. The problem is when the geometry is multimodal or funnelish; the problem there isn’t so much the dimensionality as the model. Even low-dimensional versions of such a model are likely to cause problems.
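Neal's funnel is the standard small example: v ~ normal(0, 3) and each x_d ~ normal(0, exp(v / 2)). The conditional scale of x, which sets the usable step size, changes by orders of magnitude across plausible values of v, regardless of how many x_d there are. A short sketch evaluating that scale at the 0.5% and 99.5% quantiles of v:

```r
# Neal's funnel: v ~ N(0, 3), x_d | v ~ N(0, exp(v / 2)).
# Conditional scale of x at the 0.5% and 99.5% quantiles of v.
v_range <- qnorm(c(0.005, 0.995), mean = 0, sd = 3)
x_scale <- exp(v_range / 2)
x_scale                  # about 0.021 in the neck, about 48 in the mouth
x_scale[2] / x_scale[1]  # a ratio of roughly 2270
```

Since a leapfrog step larger than about twice the local scale is unstable in a Gaussian direction, a step size suited to the mouth diverges in the neck, and one suited to the neck needs thousands of steps to cross the mouth.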

So yes, HMC works really well under nice conditions, but what really matters is identifying when those nice conditions hold. The real power of HMC is its sensitive diagnostics, which indicate less-than-ideal conditions and can then help you identify the design flaws in your model.
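One such diagnostic, split R-hat, compares between-chain and within-chain variance, which is exactly what flags chains sitting in different modes. A minimal sketch of the classic (non-rank-normalized) split R-hat from Gelman et al.'s BDA3; recent Stan interfaces report a rank-normalized variant, so values will differ slightly:

```r
# Classic split R-hat. `draws` is an iterations x chains matrix.
split_rhat <- function(draws) {
  n_half <- floor(nrow(draws) / 2)
  # Split each chain into first and second halves -> 2 * chains columns.
  halves <- cbind(draws[seq_len(n_half), , drop = FALSE],
                  draws[n_half + seq_len(n_half), , drop = FALSE])
  n <- nrow(halves)
  W <- mean(apply(halves, 2, var))     # mean within-chain variance
  B <- n * var(colMeans(halves))       # between-chain variance
  var_plus <- (n - 1) / n * W + B / n  # pooled variance estimate
  sqrt(var_plus / W)
}

set.seed(7)
mixed <- matrix(rnorm(4000), ncol = 4)                   # 4 chains, same target
stuck <- cbind(matrix(rnorm(2000, mean = -3), ncol = 2), # 2 chains in one mode
               matrix(rnorm(2000, mean = 3), ncol = 2))  # 2 chains in another
split_rhat(mixed)  # close to 1
split_rhat(stuck)  # far above 1
```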


Thank you. This is very useful information. I will focus on how to reduce or handle the multimodality.

I’ve been using Stan to sample from a highly multimodal posterior for a model with 11935 parameters. On my laptop, which is more than 3 years old, it takes 2 hours for 1000 warmup and 3000 post-warmup iterations per chain. With a GPU this could be brought down to less than 40 minutes. With some optimization of the GPU code (which is currently optimized for big n, not for a large number of parameters) it could be made even faster. The sampling passes all convergence diagnostics (using 4 chains), and the results are consistent across repeated runs. Looking at the marginals, I estimate there are fewer than 2^9 modes. Here’s a figure showing the multimodality.

The model I used is logistic regression with a regularized horseshoe prior. I can email you the data and code if you want to test it.


Wow, that’s really cool. Yes, I would like to try it out. My email is [screenshot not included].

Very cool!
May I also have a peek at your program/data? I promise not to redistribute or copy/paste into my own stuff. Just wanting to learn new tricks.

Program and data are not secret. I was just in a hurry when writing that post.

The data set is Prostate_GE, listed at http://featureselection.asu.edu/datasets.php

The model is described in Piironen and Vehtari (2017) https://projecteuclid.org/euclid.ejs/1513306866. The code is slightly modified from appendix C.1: fewer transformed parameters are saved, and bernoulli_logit_glm() is used instead of bernoulli_logit() (this gives about a 4x speedup).
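For reference, bernoulli_logit_glm(x, alpha, beta) defines the same Bernoulli log-likelihood as bernoulli_logit(alpha + x * beta); the speedup comes from Stan's specialized implementation, not from a different model. A base-R check of that likelihood, with a hypothetical helper name, comparing the explicit formula to dbinom:

```r
# Bernoulli log-likelihood with a logit link, written out explicitly:
#   sum_i [ y_i * eta_i - log(1 + exp(eta_i)) ],  eta = alpha + x %*% beta
bernoulli_logit_loglik <- function(y, x, alpha, beta) {
  eta <- as.vector(alpha + x %*% beta)
  sum(y * eta - log1p(exp(eta)))
}

set.seed(3)
x <- matrix(rnorm(50 * 5), 50, 5)
beta <- rnorm(5)
alpha <- 0.5
y <- rbinom(50, 1, plogis(alpha + x %*% beta))

explicit <- bernoulli_logit_loglik(y, x, alpha, beta)
via_dbinom <- sum(dbinom(y, 1, plogis(alpha + x %*% beta), log = TRUE))
all.equal(explicit, via_dbinom)
```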

The bernoulli_logit_glm_rhs.stan code is:

data {
  int<lower=0> n;                // number of observations
  int<lower=0> d;                // number of predictors
  int<lower=0,upper=1> y[n];     // outputs
  matrix[n,d] x;                 // inputs
  real<lower=0> scale_icept;     // prior std for the intercept
  real<lower=0> scale_global;    // scale for the half-t prior for tau
  real<lower=1> nu_global;       // degrees of freedom for the half-t prior for tau
  real<lower=1> nu_local;        // degrees of freedom for the half-t priors for lambdas
                                 // (nu_local = 1 corresponds to the horseshoe)
  real<lower=0> slab_scale;      // for the regularized horseshoe
  real<lower=0> slab_df;
}

parameters {
  real beta0;
  vector[d] z;                   // for non-centered parameterization
  real<lower=0> tau;             // global shrinkage parameter
  vector<lower=0>[d] lambda;     // local shrinkage parameter
  real<lower=0> caux;
}

transformed parameters {
  vector[d] beta;                       // regression coefficients
  {
    // local block: these are not saved in the output
    vector[d] lambda_tilde;             // 'truncated' local shrinkage parameter
    real c = slab_scale * sqrt(caux);   // slab scale
    lambda_tilde = sqrt(c^2 * square(lambda) ./ (c^2 + tau^2 * square(lambda)));
    beta = z .* lambda_tilde * tau;
  }
}

model {
  // half-t priors for lambdas and tau, and inverse-gamma for c^2
  z ~ std_normal();
  lambda ~ student_t(nu_local, 0, 1);
  tau ~ student_t(nu_global, 0, scale_global * 2);
  caux ~ inv_gamma(0.5 * slab_df, 0.5 * slab_df);
  beta0 ~ normal(0, scale_icept);
  y ~ bernoulli_logit_glm(x, beta0, beta);
}

generated quantities {
  vector[n] f = beta0 + x * beta;
  vector[n] log_lik;
  for (i in 1:n)
    log_lik[i] = bernoulli_logit_glm_lpmf({y[i]} | [x[i]], beta0, beta);
}
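The transformed parameters block implements the regularized horseshoe truncation lambda_tilde^2 = c^2 lambda^2 / (c^2 + tau^2 lambda^2). When tau * lambda is small relative to the slab scale c, the effective coefficient scale tau * lambda_tilde is essentially tau * lambda (the plain horseshoe); when it is large, it saturates at c. A small R transcription of that line with illustrative values; in the Stan program c is itself random (through caux), whereas here it is held fixed:

```r
# R transcription of the Stan line
#   lambda_tilde = sqrt(c^2 * square(lambda) ./ (c^2 + tau^2 * square(lambda)));
lambda_tilde <- function(lambda, tau, c_slab) {
  sqrt(c_slab^2 * lambda^2 / (c_slab^2 + tau^2 * lambda^2))
}

tau <- 0.01
c_slab <- 2
lambda <- c(1, 10, 1e4, 1e6)
# Effective prior scale of each beta_j (before multiplying by z_j):
cbind(lambda,
      horseshoe = tau * lambda,
      regularized = tau * lambda_tilde(lambda, tau, c_slab))
```

The last column stays close to tau * lambda for small lambda and levels off just below c_slab = 2 for large lambda.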

R code for preparing the data and running the model:

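library(rstan)  # needed for the unqualified stan() call below

prostate <- read.csv("prostate.csv", header=FALSE)
# Not shown in the original post: extract the n x d predictor matrix x and the
# 0/1 outcome vector y from `prostate` here (the column layout depends on the file).

n = nrow(x);
p = ncol(x);

## regularized horseshoe prior
p0 <- 5                                # prior guess for the number of relevant predictors
sigma <- sqrt(1/mean(y)/(1-mean(y)))   # pseudo-sigma for logistic regression...
sigma <- 1                             # ...immediately overridden: sigma is fixed to 1
tau0 <- p0/(p-p0)*sigma/sqrt(n)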
## data
data <- list(n = n, d = ncol(x), x = x, y = y,
             nu_local = 1, nu_global = 1, scale_global = tau0,
             scale_icept=5, slab_scale=2, slab_df=100)

fitpb1 <- stan("bernoulli_logit_glm_rhs.stan", data=data, chains=4, iter=3000, warmup=1000,
             control=list(adapt_delta = 0.8, max_treedepth=10), cores=2)
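Counting the parameters block (beta0, tau and caux, plus the length-d vectors z and lambda) gives 2d + 3 parameters, so the 11935 parameters mentioned above imply d = 5966 predictors. A quick arithmetic check:

```r
# Parameters declared in bernoulli_logit_glm_rhs.stan:
#   beta0 (1), z (d), tau (1), lambda (d), caux (1)  ->  2 * d + 3
n_params <- function(d) 2 * d + 3

# Solving 2 * d + 3 = 11935 for d:
d_implied <- (11935 - 3) / 2
d_implied            # 5966
n_params(d_implied)  # 11935
```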

@maxbiostat, I’m curious: what was wrong in my post? I don’t see an EDIT comment.

All I did was add the Stan syntax highlighting. I understand it’s internet etiquette to add an “EDIT” comment, but I didn’t want to pollute your post with (what I think is) unnecessary blubber. Happy to start adding that if people would prefer.

Thanks, I didn’t know it was possible to add syntax highlighting, and I wouldn’t have learned that if I hadn’t asked. Although now I realize that I could have seen your edits by clicking edit, too (not the most logical user interface).

I would prefer that, although now I also know how to check such edits on posts I wrote.


Thank you!

What does non-centered parameterization mean here? Does it mean beta is not centered around zero?

See https://mc-stan.org/docs/2_23/stan-users-guide/reparameterization-section.html#hierarchical-models-and-the-non-centered-parameterization
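In this program "non-centered" refers to how beta is built, not to where it is centered: instead of sampling beta_j ~ normal(0, tau * lambda_tilde_j) directly (centered), it samples z_j ~ std_normal() and sets beta_j = z_j * lambda_tilde_j * tau. Both give beta the same prior; what changes is the geometry the sampler sees. A quick Monte Carlo check of the equal-distribution part, holding the scale fixed at an illustrative value:

```r
set.seed(11)
n_draws <- 1e5
scale <- 0.3  # stands in for tau * lambda_tilde_j, held fixed here

beta_centered <- rnorm(n_draws, mean = 0, sd = scale)  # beta ~ N(0, scale)
z <- rnorm(n_draws)                                    # z ~ N(0, 1)
beta_noncentered <- z * scale                          # beta = z * scale

c(sd(beta_centered), sd(beta_noncentered))  # both close to 0.3
```

When tau and lambda are parameters rather than constants, the centered form ties the scale of beta to them and creates a funnel; the non-centered form lets the sampler move in z, whose prior scale never changes.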