stan_lmer really slow

Hi everyone,

So I'm having an issue with stan_lmer (from rstanarm). I'm trying to fit a model with about 200k observations and a grouping factor with around 435 levels, each getting its own random intercept. The model has the form

stan_lmer(y ~ X + (1|id))

The model makes no visible progress even after several hours. I've tried running it on subsets of the data: samples of 1,000, 5,000, or 10,000 observations run pretty quickly, but once I go above 10,000, it chugs.

I’ve tried centering the variables - no dice.

Any advice?

Morning,

Can you post your full model call, including the number of chains, iterations, and so on? Thanks.

1 Like

For a model that simple but data that large, you probably want to use cmdstanr and write the Stan model directly:

data{
  int<lower=1> nXY ; // number of observations in X/Y
  vector[nXY] Y ; // Y (outcome) observations
  vector[nXY] X ; // X (covariate) observations
  int<lower=1> nZ ; // number of levels in Z
  array[nXY] int<lower=1,upper=nZ> Z ; // Z-label for each outcome in Y
  // centered: binary toggle for centered/non-centered parameterization of the intercepts
  int<lower=0,upper=1> centered ;
}
parameters{
  real<lower=0> noise ; //measurement noise scale
  real betaX ; // effect of the covariate X
  real mu ; // intercept
  real<lower=0> sigma ;  // scale of variability among levels of Z
  vector<
    offset = (centered ? 0 : mu)
    , multiplier = (centered ? 1 : sigma)
  >[nZ] Zval ; // value for each level of Z
}
model{
  // Priors
  noise ~ weibull(2,1) ; //must be changed to reflect domain expertise 
  betaX ~ std_normal() ; //must be changed to reflect domain expertise 
  mu ~ std_normal() ; //must be changed to reflect domain expertise
  sigma ~ std_normal() ; //must be changed to reflect domain expertise
  //hierarchical prior for Zval:
  Zval ~ normal( mu, sigma ) ;
  //likelihood (can be sped up via reduce_sum if lots of cores available!)
  Y ~ normal( Zval[Z] + betaX*X, noise ) ;
}

As noted in the comment above the likelihood, if you have lots of CPU cores available, you can parallelize that likelihood computation using the reduce_sum framework.
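
In case it's useful, here is a minimal sketch of how the program above could be fit with cmdstanr. The file name ("model.stan"), the data frame dat, and its columns y, x, and id are placeholders for your own data, and the sampler settings are just the defaults spelled out:

library(cmdstanr)

# assemble the data list to match the Stan data block above
stan_data <- list(
  nXY = nrow(dat)
  , Y = dat$y
  , X = dat$x
  , nZ = length(unique(dat$id))
  , Z = as.integer(factor(dat$id))
  , centered = 1 # with many observations per level, centered usually samples better
)

mod <- cmdstan_model("model.stan") # the program above, saved to disk
fit <- mod$sample(
  data = stan_data
  , chains = 4
  , parallel_chains = 4
)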

1 Like

For that much data with that simple a model, do you really need to do MCMC? Does variational inference (vb) give a reasonable answer?
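
For instance (a sketch only, assuming the OP's model and a placeholder data frame dat), rstanarm can do this directly by swapping the algorithm argument:

library(rstanarm)

# variational approximation (ADVI) instead of full MCMC
fit_vb <- stan_lmer(
  y ~ X + (1|id)
  , data = dat
  , algorithm = "meanfield" # or "fullrank" for a more flexible (but slower) approximation
)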

1 Like

At the moment brms is likely to scale to larger datasets a little better than rstanarm, especially if you use the cmdstanr backend.

Additionally, with the cmdstanr backend you can use within-chain parallelisation via the threads argument, and if you have a discrete GPU in your system you can use it (via OpenCL) to accelerate your models.
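
For example, something along these lines (a sketch; the data frame and settings are illustrative, not recommendations):

library(brms)

fit <- brm(
  y ~ X + (1|id)
  , data = dat # placeholder data frame
  , backend = "cmdstanr" # requires the cmdstanr package and a CmdStan installation
  , chains = 4
  , cores = 4
  , threads = threading(2) # within-chain parallelisation: 2 threads per chain
  # , opencl = opencl(c(0, 0)) # optionally offload to a GPU via OpenCL (platform & device ids)
)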

2 Likes

I have a very similar situation to OP.

I haven't noticed much difference between brms with the cmdstanr backend vs rstan, but I have found that threads = threading(2) tends to slow down my brm(y ~ a + b + c + (1|id), chains = 4, cores = 4) sampling by a factor of 2-3 on a 12-physical-core Ryzen.

Is this thread the best reference on using GPU parallelism?

What is the main difference between writing that model by hand and letting brms generate one, such that the hand-written version would run faster?

Nothing directly, but you have more control (e.g., for the OP's data, they probably want a centered parameterization, and I'm not sure that brms makes it easy to switch from non-centered, which I presume is its default). I'm probably just biased by knowing raw Stan well and brms not so well.

Note that if you use raw Stan, you can do things like this and this, whereas I'm not sure that brms includes those performance tweaks.

2 Likes

Are people actually using MCMC on models like

y ~ a + b + c + (1|id)

with n>10k (or >100k) in practice (coded directly in Stan or otherwise)? It seems like it would take days or weeks, at least with the techniques I've tried so far (namely cores=4, threads=threading(2), and the non-centered parameterization).

That’s about the scale of a dataset I’m currently working on, except I’m doing:

x ~ (a+b+c+d+e+f+g)^3 + (1+(a+b+c+d+e+f+g)^3  | q | id)
y ~ (a+b+c+d+e+f+g)^3 + (1+(a+b+c+d+e+f+g)^3  | q | id)
z ~ (a+b+c+d+e+f+g)^3 + (1+(a+b+c+d+e+f+g)^3  | q | id)

But aside from some over-regularization of the correlations, which I'm still working on, it at least fits in a few hours. Mind you, all my variables are dichotomous and thereby benefit greatly from the reduced-redundant-computation trick I linked above.

Oh, and I guess it helps that while the raw number of observations is in the 30k+ range, there are only 60 levels of id.

1 Like

See also here.