Optimizing Stan Performance for Single-Cell RNA-seq Mixed Effects Model (10K features, 50K cells)

Hi everyone,

Many thanks in advance -

I’m looking for advice to improve the speed of my current model.

I have single-cell RNA sequencing data from cancer patients (SAMPLE) with different mutations (FUSION). Each cell has an annotated CELLTYPE.

I have a matrix X that contains gene activity scores across single cells, and I want to model how different mutations impact gene activities. All the gene activities are centered at zero.

Below is my current model:

mu = alpha + b_fusion + b_celltype
X ~ Normal(mu, sigma)
  • I am modeling each patient as a random effect (varying intercepts alpha).
  • Then I estimate b_fusion while controlling for cell types.
  • There are ~10,000 FEATURES, so I am essentially fitting 10K linear models each with 50K (number of cells) observations.

The current model, with data downsampled to 5 features and 2,000 cells, takes ~5 minutes to fit with num_warmup=500; num_samples=2000. But the entire dataset takes 2 days to draw ~100 samples.

I’ve implemented the following to improve speed:

  • Prior predictive check to make my priors more narrow

  • Non-centered parameterization for alpha

  • Precompute mu in the transformed parameters block instead of looping over cells during likelihood calculation

But it’s still taking way too long. My final dataset will likely grow another ~5 fold.

I am doing this on an HPC cluster; memory and compute power are not rate limiting.

data {
  int<lower=0> NCELL;
  int<lower=0> NFEAT;
  int<lower=0> NSAMPLE;
  int<lower=0> NFUSION;
  int<lower=0> NCELLTYPE;
  
  matrix[NFEAT, NCELL] X; //activity matrix
  array[NCELL] int<lower=1> SAMPLE; 
  array[NCELL] int<lower=1> FUSION; 
  array[NCELL] int<lower=1> CELLTYPE;
}

parameters {
  matrix[NFEAT, NFUSION] b_fusion;
  matrix[NFEAT, NCELLTYPE] b_celltype;
  
  // Non-centered random effects
  matrix[NFEAT, NSAMPLE] alpha_std;
  real alpha0; // Mean of the random effects
  real<lower=0> sigma_alpha; // SD of the random effects
  
  // residual sd
  vector<lower=0>[NFEAT] sigma;
}

transformed parameters {
  // Center and scale the random effects
  matrix[NFEAT, NSAMPLE] alpha;
  for (i in 1:NFEAT) {
    alpha[i, ] = alpha0 + alpha_std[i, ] * sigma_alpha;
  }
  
  // Pre-calculate the mean matrix for the likelihood
  matrix[NFEAT, NCELL] mu;
  for (i in 1:NFEAT) {
    for (j in 1:NCELL) {
      mu[i, j] = alpha[i, SAMPLE[j]] + b_fusion[i, FUSION[j]] + b_celltype[i, CELLTYPE[j]];
    }
  }
}

model {
  alpha0 ~ normal(0, .4);
  sigma_alpha ~ lognormal(-1.5, 1);
   // Prior for standardized random effects
  to_vector(alpha_std) ~ normal(0, 1);
  
  sigma ~ gamma(.1, .1);
  to_vector(b_fusion) ~ normal(0, .4);
  to_vector(b_celltype) ~ normal(0, .4);
  
  for(i in 1:NFEAT){
    X[i, ] ~ normal(mu[i, ], sigma[i]);
  }
}

  1. Check out the ways to speed up sampling with C++ compiler and stanc options: “Options for improving Stan sampling speed” (The Stan Blog)
  2. Compute the transformed parameters in the model block instead, which saves I/O time because they are then not written to file
  3. You can avoid a couple of loops, although it’s possible that this doesn’t affect the speed
  // Center and scale the random effects
  matrix[NFEAT, NSAMPLE] alpha = alpha0 + alpha_std * sigma_alpha;
  
  // Pre-calculate the mean matrix for the likelihood
  matrix[NFEAT, NCELL] mu;
  for (j in 1:NCELL) {
    mu[, j] = alpha[, SAMPLE[j]] + b_fusion[, FUSION[j]] + b_celltype[, CELLTYPE[j]];
  }
  4. Use reduce_sum parallelization

I don’t see much else to do, but I know @Bob_Carpenter has run MCMC successfully with similar size models, so he might have further ideas

Note that when combined with tip 1, I would expect this to be faster, as the stanc compiler is able to better optimize (several parameters end up as the special var-matrix type, which doesn’t happen in the way it is coded in the original post)

It might also be possible that the normal density evaluation ~ normal() gets faster with a GPU. See “Running Stan on the GPU with OpenCL” (cmdstanr) and “Stan Math Library: OpenCL CPU/GPU Support”.

In addition to what’s already been said, I’d also have a careful look at the parametrization.

If you have a bad parametrization, you will see that the number of leapfrog steps per draw blows up (probably to 500-1000) as you increase the data size. If you can fix those issues, you can often see a drop in computation time by a large factor.

So for instance:

Non-identifiability can quickly become a much bigger problem if you have a large dataset: if there is a subspace where the posterior variance is always equal to the prior variance, but everywhere else the posterior variance goes to zero, then the condition number of the posterior covariance will go to infinity. See for instance The Sum-to-Zero Constraint in Stan for an example of how to remove extra degrees of freedom in an example like yours.
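A minimal sketch of what removing the extra degrees of freedom could look like for a single feature. This is not taken from the linked case study. It assumes a Stan release recent enough to provide the built-in sum_to_zero_vector type (2.36 or later, as far as I know), it leaves out the patient intercepts for brevity, and the names x and intercept are hypothetical. The priors are copied from the original post.

```stan
data {
  int<lower=1> NCELL;
  int<lower=2> NFUSION;
  int<lower=2> NCELLTYPE;
  vector[NCELL] x; // activity scores of ONE feature (hypothetical name)
  array[NCELL] int<lower=1, upper=NFUSION> FUSION;
  array[NCELL] int<lower=1, upper=NCELLTYPE> CELLTYPE;
}
parameters {
  real intercept; // the overall level lives here and only here
  sum_to_zero_vector[NFUSION] b_fusion;     // effects sum to zero
  sum_to_zero_vector[NCELLTYPE] b_celltype; // effects sum to zero
  real<lower=0> sigma;
}
model {
  intercept ~ normal(0, 0.4);
  b_fusion ~ normal(0, 0.4);
  b_celltype ~ normal(0, 0.4);
  sigma ~ gamma(0.1, 0.1);
  x ~ normal(intercept + b_fusion[FUSION] + b_celltype[CELLTYPE], sigma);
}
```

Without the constraints, adding a constant to every b_fusion entry and subtracting it from the intercept leaves the likelihood unchanged, so only the prior pins that direction down.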

I’d also check what the effect of the non-centered parametrization is. Often, the centered parametrization works much better if the population standard deviation can be estimated very precisely, and with large datasets that is quite common.
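For comparison, a centered version of the patient intercepts might look like the sketch below. It is reduced to a single feature with no fusion or cell-type effects so that the difference is easy to see; the name x is hypothetical and the priors are copied from the original post.

```stan
data {
  int<lower=1> NCELL;
  int<lower=1> NSAMPLE;
  vector[NCELL] x; // activity scores of ONE feature (hypothetical name)
  array[NCELL] int<lower=1, upper=NSAMPLE> SAMPLE;
}
parameters {
  vector[NSAMPLE] alpha;     // sampled directly, no alpha_std
  real alpha0;               // mean of the random effects
  real<lower=0> sigma_alpha; // SD of the random effects
  real<lower=0> sigma;       // residual SD
}
model {
  alpha0 ~ normal(0, 0.4);
  sigma_alpha ~ lognormal(-1.5, 1);
  // Centered: the hierarchical prior is placed on alpha itself.
  alpha ~ normal(alpha0, sigma_alpha);
  sigma ~ gamma(0.1, 0.1);
  x ~ normal(alpha[SAMPLE], sigma);
}
```

Which of the two versions samples better is an empirical question; comparing the number of leapfrog steps per draw on a downsampled dataset is one way to find out.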

Purely on a time-per-gradient basis, I’d also experiment a bit with the order of matrices you are working with. I think Stan stores matrices in column-major format, so extracting a row of a matrix should be slower than extracting a column, and iterating such that the first index varies faster should also be faster than the other way round. I don’t know if that effect is drowned out by loop overhead in Stan though…

My first thought is that transposing X and mu to be matrix[NCELL, NFEAT] would help, as your likelihood would then be in column-major order, which should be faster. Nesting rows within columns while building mu would also help with speed.
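A rough sketch of the transposed layout, assuming X is transposed before it is passed to Stan. Parameter names and priors are unchanged from the original post; the only changes are that every matrix now has one column per feature, that mu is built with multi-indexing instead of explicit loops, and that alpha and mu are local to the model block.

```stan
data {
  int<lower=0> NCELL;
  int<lower=0> NFEAT;
  int<lower=0> NSAMPLE;
  int<lower=0> NFUSION;
  int<lower=0> NCELLTYPE;

  matrix[NCELL, NFEAT] X; // transposed: one column per feature
  array[NCELL] int<lower=1> SAMPLE;
  array[NCELL] int<lower=1> FUSION;
  array[NCELL] int<lower=1> CELLTYPE;
}
parameters {
  matrix[NFUSION, NFEAT] b_fusion;
  matrix[NCELLTYPE, NFEAT] b_celltype;
  matrix[NSAMPLE, NFEAT] alpha_std; // non-centered random effects
  real alpha0;
  real<lower=0> sigma_alpha;
  vector<lower=0>[NFEAT] sigma;
}
model {
  // Local variables in the model block are not written to the output.
  matrix[NSAMPLE, NFEAT] alpha = alpha0 + alpha_std * sigma_alpha;
  // Row multi-indexing gathers one row of coefficients per cell.
  matrix[NCELL, NFEAT] mu
    = alpha[SAMPLE, ] + b_fusion[FUSION, ] + b_celltype[CELLTYPE, ];

  alpha0 ~ normal(0, 0.4);
  sigma_alpha ~ lognormal(-1.5, 1);
  to_vector(alpha_std) ~ normal(0, 1);
  sigma ~ gamma(0.1, 0.1);
  to_vector(b_fusion) ~ normal(0, 0.4);
  to_vector(b_celltype) ~ normal(0, 0.4);

  for (i in 1:NFEAT) {
    // Each feature is now a contiguous column of X and mu.
    X[, i] ~ normal(mu[, i], sigma[i]);
  }
}
```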

Ultimately, reduce_sum() would be really helpful for increasing your speed. My intuition is that building mu within the function would probably be more efficient than building it outside and passing it along.
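A sketch of how slicing over features with reduce_sum() could look, with mu built inside the partial-sum function. Several things here are assumptions rather than part of the original model: X is passed as an array with one vector per feature, the coefficient matrices have one column per feature, and partial_sum and grainsize are hypothetical names. The model has to be compiled with threading enabled for this to run in parallel.

```stan
functions {
  // Log-likelihood of the features start..end; x_slice holds only those.
  real partial_sum(array[] vector x_slice, int start, int end,
                   matrix alpha, matrix b_fusion, matrix b_celltype,
                   vector sigma, array[] int SAMPLE,
                   array[] int FUSION, array[] int CELLTYPE) {
    real lp = 0;
    for (k in 1:(end - start + 1)) {
      int i = start + k - 1; // global feature index
      lp += normal_lpdf(x_slice[k] | alpha[SAMPLE, i] + b_fusion[FUSION, i]
                                     + b_celltype[CELLTYPE, i], sigma[i]);
    }
    return lp;
  }
}
data {
  int<lower=0> NCELL;
  int<lower=0> NFEAT;
  int<lower=0> NSAMPLE;
  int<lower=0> NFUSION;
  int<lower=0> NCELLTYPE;
  int<lower=1> grainsize; // tuning knob, start with 1

  array[NFEAT] vector[NCELL] X; // one vector of cells per feature
  array[NCELL] int<lower=1> SAMPLE;
  array[NCELL] int<lower=1> FUSION;
  array[NCELL] int<lower=1> CELLTYPE;
}
parameters {
  matrix[NFUSION, NFEAT] b_fusion;
  matrix[NCELLTYPE, NFEAT] b_celltype;
  matrix[NSAMPLE, NFEAT] alpha_std;
  real alpha0;
  real<lower=0> sigma_alpha;
  vector<lower=0>[NFEAT] sigma;
}
model {
  matrix[NSAMPLE, NFEAT] alpha = alpha0 + alpha_std * sigma_alpha;

  alpha0 ~ normal(0, 0.4);
  sigma_alpha ~ lognormal(-1.5, 1);
  to_vector(alpha_std) ~ normal(0, 1);
  sigma ~ gamma(0.1, 0.1);
  to_vector(b_fusion) ~ normal(0, 0.4);
  to_vector(b_celltype) ~ normal(0, 0.4);

  target += reduce_sum(partial_sum, X, grainsize, alpha, b_fusion,
                       b_celltype, sigma, SAMPLE, FUSION, CELLTYPE);
}
```

Note that the shared arguments (alpha, b_fusion, b_celltype, sigma) are passed whole to every partial sum, so the benefit depends on how that overhead compares with the likelihood work per slice.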

This model should also be amenable to fitting in brms, which can automatically set up reduce_sum() for you. You would probably need to flatten X into a vector for that to work, but that shouldn’t be too much of a problem. I’d try switching your rows and columns first, though.

But I didn’t have 10K linear regressions inside! It was just a pure random effects model.

This is just a lot of arithmetic no matter how you organize it. But it does mean you can factor this all into multiple models and fit subsets of the data. This should make it very easy to parallelize if you’re willing to forego the hierarchical component. In fact, the strategy that @andrewgelman often recommends here is to just fix the variance to fit—you might get good enough results without the hierarchical model, then you can look at the posterior variance of the parameters and find a reasonable prior for them in an expanded model.
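As a sketch of that strategy, the per-feature model below takes the group-level SD as data instead of estimating it, so each feature can be fit as a separate job. The names x and sigma_alpha_fixed are hypothetical, and alpha0 becomes a per-feature parameter here, which differs from the shared alpha0 in the original model.

```stan
data {
  int<lower=1> NCELL;
  int<lower=1> NSAMPLE;
  int<lower=1> NFUSION;
  int<lower=1> NCELLTYPE;
  vector[NCELL] x; // one row of the original X (hypothetical name)
  array[NCELL] int<lower=1, upper=NSAMPLE> SAMPLE;
  array[NCELL] int<lower=1, upper=NFUSION> FUSION;
  array[NCELL] int<lower=1, upper=NCELLTYPE> CELLTYPE;
  real<lower=0> sigma_alpha_fixed; // group-level SD pinned in advance
}
parameters {
  real alpha0;
  vector[NSAMPLE] alpha;
  vector[NFUSION] b_fusion;
  vector[NCELLTYPE] b_celltype;
  real<lower=0> sigma;
}
model {
  alpha0 ~ normal(0, 0.4);
  // With the SD fixed there is no funnel between alpha and its scale.
  alpha ~ normal(alpha0, sigma_alpha_fixed);
  b_fusion ~ normal(0, 0.4);
  b_celltype ~ normal(0, 0.4);
  sigma ~ gamma(0.1, 0.1);
  x ~ normal(alpha[SAMPLE] + b_fusion[FUSION] + b_celltype[CELLTYPE], sigma);
}
```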

The transformed parameters can be coded without explicit loops, going one step further than @avehtari went:

matrix[NFEAT, NCELL] mu
  = alpha[ , SAMPLE] + b_fusion[ , FUSION] + b_celltype[ , CELLTYPE];

These should both be faster. Stan has to allocate the intermediate matrices, but then the resulting matrix arithmetic and autodiff should more than make up for it.

You also want to move those transformed parameters into the model block to avoid having to save them.

I’m not convinced that parallelization would help a lot here as there’s very little arithmetic in this model compared to the number of parameters. The parallelization is usually more efficient when there is much more compute being done than there are parameters being passed, like in an ODE solver.

You should definitely do this.

Stan matrices are column-major, so you want the row index, which is the first index, to vary fastest.

It is much slower for large matrices.

This will also help.

What was the motivation for using a gamma prior rather than inverse gamma prior for sigma? We use the shape/rate (rate is inverse scale) parameterization, so the gamma(0.1, 0.1) prior has a mean of 1 and a variance of 10, so a scale of \sqrt{10}. And the mode’s at zero, which can be problematic if you find a lot of values clustering near there as they get sent to negative infinity in the unconstrained parameterization over which Stan samples.
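For reference, the arithmetic behind those numbers, writing a for the shape and b for the rate (my notation, with a = b = 0.1):

$$
\mathbb{E}[\sigma] = \frac{a}{b} = \frac{0.1}{0.1} = 1,
\qquad
\operatorname{Var}[\sigma] = \frac{a}{b^{2}} = \frac{0.1}{0.01} = 10,
\qquad
\operatorname{sd}[\sigma] = \sqrt{10} \approx 3.16
$$

The density is proportional to $\sigma^{a-1} e^{-b\sigma}$, which for $a < 1$ grows without bound as $\sigma \to 0$; that is the sense in which the mode is at zero.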

If sigma_alpha isn’t sampling well, turning down the variance on that can also help.

The one thing I’m going to recommend that other people haven’t recommended is doing a simple maximum likelihood fit without the hierarchical priors to initialize all the coefficients. This can make a huge difference as Stan’s slow to initialize. A better adaptation scheme can be found in Adrian Seyboldt’s nutpie package, which can sample Stan models: nutpie · PyPI
Using both nutpie and a warm init should go a long way to at least having things start faster, which is where a huge amount of the time you’re seeing will be going—into those first 100 or so iterations. We’re working to revise the way Stan does warmup along the lines of nutpie.

1. Yes, I often recommend pinning the group-level variance parameter or covariance matrix to a pre-chosen value based on subject-matter information. Often the inference isn’t super-sensitive to this group-level variance, as long as it’s not so small that it causes all the estimates to disappear to zero and not so large that the estimates are wildly noisy.

I’ve toyed with the idea of making this a more formal procedure, for example drawing 10 values of the set of variance parameters from a prior, then using these to run 10 fast inferences (could be MCMC or even just plain old optimization and Laplace approx), then averaging over them using stacking. I think this could work, but I’ve never actually tried it, let alone evaluated the idea. It’s a research idea!

2. Sometimes we do use gamma priors for group-level variance parameters. The gamma prior with 1 or more degrees of freedom has the pleasant property of being zero-avoiding, which is especially helpful when doing marginal maximum likelihood, as we discuss in our 2013 paper: https://sites.stat.columbia.edu/gelman/research/published/chung_etal_Pmetrika2013.pdf or, for covariance matrices (using the Wishart, _not_ the inverse-Wishart, prior), in our 2014 paper: https://sites.stat.columbia.edu/gelman/research/published/chung_cov_matrices.pdf

3. Another thing that’s worked well for me is to use Pathfinder to get starting values. It varies, but sometimes Pathfinder runs very fast and then we can jointly estimate all the parameters and not worry so much about the funnel.

Hi all,

Apologies for the delay, and thanks everyone for the useful input.

I gave up on the hierarchical component in the end; it was impossible to fit without downsampling.

Using reduce_sum to slice by features sped things up a lot.

Thanks again for helping out a biologist - these kinds of things don’t come naturally to me!
