Trying to make a "three-level nested linear model" run faster

I am trying to fit a three-level hierarchical random-intercept model.
To understand how to tame such a beast, I looted this code by Kazuki Yoshida.

So I tried his unvectorized code (sorry, no simulated data here; I am using real data):

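library(rstan)
classroom <- read.csv("http://www-personal.umich.edu/~bwest/classroom.csv")
# One row per class, ordered by classid so that schoolLookup[j] is the school of class j
lookup <- unique(classroom[c("classid","schoolid")])
schoolLookupVec <- lookup[order(lookup$classid), "schoolid"]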
dat <- with(classroom,
            list(Ni           = length(unique(childid)),
                 Nj           = length(unique(classid)),
                 Nk           = length(unique(schoolid)),
                 classid      = classid,
                 schoolid     = schoolid,
                 schoolLookup = schoolLookupVec,
                 mathgain     = mathgain))
resStan <- stan(model_code = stan_code, data = dat, chains = 4, iter = 10000, warmup = 1000, thin = 10)
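Whatever tool builds the data list, `schoolLookup` has to be indexed by class id: entry j must hold the school of class j, because the model evaluates `beta_0k[schoolLookup[j]]` for class j. R's `unique()` keeps rows in order of first appearance, so if the file is not sorted by class id the vector can end up misaligned. A minimal Python sketch of that check (not part of the original R workflow; the function name and the toy ids are hypothetical), assuming class ids run contiguously from 1 to Nj:

```python
def build_school_lookup(classid, schoolid):
    """Return a list whose (j-1)-th entry is the school of class j.

    classid and schoolid are parallel per-child lists of 1-based ids,
    mirroring the columns of classroom.csv.
    """
    school_of = {}
    for c, s in zip(classid, schoolid):
        # Nesting requires each class to sit in exactly one school.
        if school_of.setdefault(c, s) != s:
            raise ValueError(f"class {c} appears in schools {school_of[c]} and {s}")
    n_classes = len(school_of)
    # The Stan model indexes by class id, so ids must be exactly 1..Nj.
    if sorted(school_of) != list(range(1, n_classes + 1)):
        raise ValueError("class ids must be contiguous 1..Nj")
    return [school_of[j] for j in range(1, n_classes + 1)]


# Toy rows (hypothetical), sorted by school rather than by class id.
classid = [3, 3, 1, 2, 2]
schoolid = [1, 1, 2, 2, 2]
print(build_school_lookup(classid, schoolid))  # [2, 2, 1]
```

With these toy rows, taking schools in order of first appearance would give [1, 2, 2] instead, silently attaching class 1 to the wrong school.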

with the following Stan code:

data {
  // Define variables in data
  // Number of level-1 observations (an integer)
  int<lower=0> Ni;
  // Number of level-2 clusters
  int<lower=0> Nj;
  // Number of level-3 clusters
  int<lower=0> Nk;

  // Cluster IDs
  int<lower=1> classid[Ni];
  int<lower=1> schoolid[Ni];

  // Level 3 look up vector for level 2
  int<lower=1> schoolLookup[Nj];

  // Continuous outcome
  real mathgain[Ni];
  
  // Continuous predictor
  // real X_1ijk[Ni];
}

parameters {
  // Define parameters to estimate
  // Population intercept (a real number)
  real beta_0;
  // Population slope
  // real beta_1;

  // Level-1 errors
  real<lower=0> sigma_e0;

  // Level-2 random effect
  real u_0jk[Nj];
  real<lower=0> sigma_u0jk;

  // Level-3 random effect
  real u_0k[Nk];
  real<lower=0> sigma_u0k;
}

transformed parameters  {
  // Varying intercepts
  real beta_0jk[Nj];
  real beta_0k[Nk];

  // Individual mean
  real mu[Ni];

  // Varying intercepts definition
  // Level-3 (one random intercept per school)
  for (k in 1:Nk) {
    beta_0k[k] <- beta_0 + u_0k[k];
  }
  // Level-2 (one random intercept per class, offset from its school's intercept)
  for (j in 1:Nj) {
    beta_0jk[j] <- beta_0k[schoolLookup[j]] + u_0jk[j];
  }
  // Individual mean
  for (i in 1:Ni) {
    mu[i] <- beta_0jk[classid[i]];
  }
}

model {
  // Prior part of Bayesian inference
  // No priors are given for beta_0 and the sigmas, so they get improper flat priors

  // Random effects distribution
  u_0k  ~ normal(0, sigma_u0k);
  u_0jk ~ normal(0, sigma_u0jk);

  // Likelihood part of Bayesian inference
  // Outcome model N(mu, sigma^2) (use SD rather than Var)
  for (i in 1:Ni) {
    mathgain[i] ~ normal(mu[i], sigma_e0);
  }
}
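The transformed parameters block reads most easily as a generative process: schools get intercepts, classes inherit their school's intercept plus a deviation, and each child's mean is its class intercept. A Python sketch of that process (not Stan; the function name, toy structure and parameter values are hypothetical), useful for building intuition or fake data, assuming `school_lookup` maps class j to its school and ids are 1-based as in the Stan data block:

```python
import random


def simulate_three_level(beta_0, sigma_u0k, sigma_u0jk, sigma_e0,
                         school_lookup, classid, seed=1):
    """Draw one fake outcome vector from the centered three-level model.

    school_lookup[j-1] is the school of class j; classid[i-1] is the class
    of child i (1-based ids, as in the Stan data block).
    """
    rng = random.Random(seed)
    n_schools = max(school_lookup)
    # Level 3: beta_0k[k] = beta_0 + u_0k[k], u_0k ~ normal(0, sigma_u0k)
    beta_0k = [beta_0 + rng.gauss(0.0, sigma_u0k) for _ in range(n_schools)]
    # Level 2: beta_0jk[j] = beta_0k[schoolLookup[j]] + u_0jk[j]
    beta_0jk = [beta_0k[k - 1] + rng.gauss(0.0, sigma_u0jk)
                for k in school_lookup]
    # Level 1: mu[i] = beta_0jk[classid[i]], then add child-level noise
    mu = [beta_0jk[j - 1] for j in classid]
    return [rng.gauss(m, sigma_e0) for m in mu]


# Toy structure (hypothetical): 2 schools, 3 classes, 5 children.
fake = simulate_three_level(57.0, 8.0, 10.0, 32.0,
                            school_lookup=[2, 2, 1], classid=[3, 3, 1, 2, 2])
print(len(fake))  # 5
```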

This model takes 574.467 seconds (total) to run with all 4 chains, and I get a warning that the estimated Bayesian Fraction of Missing Information was low. I have not managed to use the pairs() command to identify the problem. The summary:

print(resStan, pars = c("beta_0","sigma_e0","sigma_u0jk","sigma_u0k"))
            mean se_mean   sd  2.5%   25%   50%   75% 97.5% n_eff Rhat
beta_0     57.40    0.02 1.47 54.57 56.43 57.42 58.37 60.35  3494 1.00
sigma_e0   32.14    0.02 0.76 30.67 31.62 32.14 32.65 33.66  1854 1.00
sigma_u0jk  9.99    0.11 2.31  4.90  8.61 10.10 11.52 14.20   452 1.01
sigma_u0k   8.62    0.07 2.10  3.97  7.33  8.76 10.01 12.46   821 1.00


I tried a vectorized version of the same model (hoping to speed it up a little):

data {
  // Define variables in data
  // Number of level-1 observations (an integer)
  int<lower=0> Ni;
  // Number of level-2 clusters
  int<lower=0> Nj;
  // Number of level-3 clusters
  int<lower=0> Nk;

  // Cluster IDs
  int<lower=1, upper=Nj> classid[Ni];
  int<lower=1, upper=Nk> schoolid[Ni];

  // Level 3 look up vector for level 2
  int<lower=1, upper=Nk> schoolLookup[Nj];

  // Continuous outcome
  vector[Ni] mathgain;
}

parameters {
  // Population intercept
  real beta_0;

  // Level-1 errors
  real<lower=0> sigma_e0;

  // Level-2 random effect
  vector[Nj] u_0jk;
  real<lower=0> sigma_u0jk;

  // Level-3 random effect
  vector[Nk] u_0k;
  real<lower=0> sigma_u0k;
}

transformed parameters  {

  // Varying intercepts
  vector[Nj] beta_0jk;
  vector[Nk] beta_0k;

  // Varying intercepts definition
  // Level-3
  beta_0k = beta_0 + u_0k;
  
  // Level-2
  beta_0jk = beta_0k[schoolLookup] + u_0jk;
}

model {
  // Prior part of Bayesian inference

  // Random effects distribution
  u_0k  ~ normal(0, sigma_u0k);
  u_0jk ~ normal(0, sigma_u0jk);

  // Likelihood part of Bayesian inference
  mathgain ~ normal(beta_0jk[classid], sigma_e0);
}
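The vectorized model computes exactly the same quantities as the loops in the first one: Stan's multi-indexing `beta_0k[schoolLookup]` and `beta_0jk[classid]` are gathers. A small Python sketch (toy numbers, 1-based ids as in Stan; the `gather` helper is hypothetical) checking that the loop form and the gather form agree:

```python
def gather(values, idx):
    """Stan-style multi-indexing values[idx] with 1-based indices."""
    return [values[i - 1] for i in idx]


beta_0k = [10.0, 20.0]          # two schools
u_0jk = [0.5, -0.5, 1.0]        # three classes
school_lookup = [2, 2, 1]       # class j -> school
classid = [3, 3, 1, 2, 2]       # child i -> class

# Vectorized form: beta_0jk = beta_0k[schoolLookup] + u_0jk
beta_0jk = [b + u for b, u in zip(gather(beta_0k, school_lookup), u_0jk)]

# Loop form, as in the first model
beta_0jk_loop = []
for j in range(len(u_0jk)):
    beta_0jk_loop.append(beta_0k[school_lookup[j] - 1] + u_0jk[j])

assert beta_0jk == beta_0jk_loop
mu = gather(beta_0jk, classid)  # location passed to normal(..., sigma_e0)
print(mu)  # [11.0, 11.0, 20.5, 19.5, 19.5]
```

Because both versions define the same posterior, vectorization can at most reduce the cost of each gradient evaluation; it cannot fix sampling problems that come from the shape of the posterior.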

The fitted values are similar (with better Rhat), but this vectorized model takes longer to sample: in every run it takes more or less double the time of the unvectorized code (here 837.244 seconds total), while I expected an improvement. I also get the same warning about a low estimated Bayesian Fraction of Missing Information (which I am still not able to explore).

            mean se_mean   sd  2.5%   25%   50%   75% 97.5% n_eff Rhat
beta_0     57.37    0.02 1.44 54.56 56.42 57.35 58.35 60.14  3571    1
sigma_e0   32.15    0.02 0.78 30.68 31.60 32.14 32.68 33.67  1389    1
sigma_u0jk  9.95    0.13 2.47  4.56  8.37 10.15 11.62 14.38   354    1
sigma_u0k   8.47    0.09 2.19  3.45  7.19  8.60  9.89 12.54   593    1
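The low-BFMI warning can be explored chain by chain. E-BFMI compares how much the Hamiltonian energy changes from one iteration to the next with how much it varies overall; values near zero mean the sampler moves through the energy distribution too slowly, which often points to the need for a reparametrization. rstan's energy check flags chains below 0.2. A Python sketch of the usual estimator (the function name is hypothetical; the energies would come from rstan's `get_sampler_params()`, column `energy__`):

```python
import statistics


def e_bfmi(energies):
    """Estimated Bayesian fraction of missing information for one chain.

    energies: the Hamiltonian energy at each post-warmup iteration of a
    single chain.
    """
    n = len(energies)
    # Average squared jump in energy between consecutive iterations...
    numer = sum((energies[i] - energies[i - 1]) ** 2 for i in range(1, n)) / n
    # ...compared with the overall (marginal) variance of the energy.
    return numer / statistics.variance(energies)


# A slowly trending energy trace gives a low value;
# one that jumps back and forth gives a high value.
print(e_bfmi([1.0, 2.0, 3.0, 4.0, 5.0]))
print(e_bfmi([0.0, 1.0, 0.0, 1.0, 0.0, 1.0]))
```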

Any advice on making my model run faster?
Could the QR reparametrization be applied to this kind of model? Or should I try something else, like adding some weakly informative priors on the parameters?

Have you tried the non-centered parametrization? It is likely very relevant for such a model. In your formulation you replace

beta_0k = beta_0 + u_0k;

with

beta_0k = beta_0 + u_0k * sigma_u0k;

and then place a normal(0,1) prior on u_0k; same for the others. The manual has a ton of material on this.
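This swap changes only the coordinates the sampler works in, not the model: `beta_0 + u` with `u ~ normal(0, sigma)` and `beta_0 + u_raw * sigma` with `u_raw ~ normal(0, 1)` have the same distribution. A Python sketch (not Stan; beta_0 = 57 and sigma = 9 are arbitrary illustration values, the function names are hypothetical) drawing from both forms and comparing moments:

```python
import random
import statistics


def draw_centered(beta_0, sigma, n, rng):
    """beta = beta_0 + u with u ~ normal(0, sigma)."""
    return [beta_0 + rng.gauss(0.0, sigma) for _ in range(n)]


def draw_noncentered(beta_0, sigma, n, rng):
    """beta = beta_0 + u_raw * sigma with u_raw ~ normal(0, 1)."""
    return [beta_0 + rng.gauss(0.0, 1.0) * sigma for _ in range(n)]


rng = random.Random(2017)
a = draw_centered(57.0, 9.0, 50_000, rng)
b = draw_noncentered(57.0, 9.0, 50_000, rng)
# Same distribution: means and standard deviations agree up to Monte Carlo error.
print(statistics.mean(a), statistics.stdev(a))
print(statistics.mean(b), statistics.stdev(b))
```

What differs is what NUTS sees: in the non-centered form its coordinates are `u_raw`, which stay on a unit scale whatever value sigma takes.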

And yes! Please add weakly informative priors.


Thanks @wds15!
I have read about the non-centered reparametrization in the manual and on the forum (many users ask about it over and over).
Still, it is a little “obscure” to me. I will study it more and try it, but I still feel it is like “a hand grenade in my pocket”.

Well, you are just reparametrizing the very same density. In a nutshell it works so well because:

  • NUTS “likes” to sample unit normals

  • hierarchical models are typically used in data-sparse situations to borrow the little information there is in the data across the units (it’s a data-driven prior)

  • sparse data loosely means that the data have only a weak impact on the prior, so the posterior does not differ much from the prior.

  • so if you represent your model in a way such that your prior is essentially a bunch of N(0,1), then that won’t change much for sparse data, and hence NUTS runs insanely better.

This is the poor man’s explanation. The more involved explanation is about the geometry of the posterior: the centered version has a funnel shape that is hard to sample, and non-centering removes it. @betanalpha has good papers on this, which are referenced in the manual.
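To make the funnel concrete, a Python sketch (not Stan) that draws from a hypothetical prior with log(sigma) ~ normal(0, 1.5), chosen only for illustration, and compares the spread of the centered coordinate `u` with the non-centered `u_raw` in the neck (small sigma) and the mouth (large sigma) of the funnel. A single step size cannot suit both regions in the centered coordinates, while `u_raw` has unit scale everywhere:

```python
import math
import random
import statistics

rng = random.Random(42)
draws = []
for _ in range(20_000):
    log_sigma = rng.gauss(0.0, 1.5)   # hypothetical prior on log(sigma)
    u_raw = rng.gauss(0.0, 1.0)       # non-centered coordinate
    u = u_raw * math.exp(log_sigma)   # centered coordinate, u ~ normal(0, sigma)
    draws.append((log_sigma, u, u_raw))

neck = [d for d in draws if d[0] < -1.5]   # small sigma
mouth = [d for d in draws if d[0] > 1.5]   # large sigma


def spreads(group):
    """Return (sd of u, sd of u_raw) within one region of the funnel."""
    return (statistics.stdev([d[1] for d in group]),
            statistics.stdev([d[2] for d in group]))


sd_u_neck, sd_raw_neck = spreads(neck)
sd_u_mouth, sd_raw_mouth = spreads(mouth)
print(f"neck:  sd(u) = {sd_u_neck:.2f}, sd(u_raw) = {sd_raw_neck:.2f}")
print(f"mouth: sd(u) = {sd_u_mouth:.2f}, sd(u_raw) = {sd_raw_mouth:.2f}")
```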

Don’t be shy with non-centered! It works amazingly well.

Looking at the sigma estimates you are getting makes me think you could rescale the data so that the sigma estimates (and the betas as well) land on a unit scale.


Thanks for the ‘poor man’s’ explanation! I started to read the Betancourt & Girolami paper, but your explanation helps me to proceed with my work.

From what I read, the non-centered parametrization applies only to the parameters.

For the terms involving the data, it is better to keep the centered parametrization. So, in my case:

mathgain ~ normal(beta_0jk[classid], sigma_e0);

will stay centered, but with the outcome rescaled so that the parameters are on a unit scale (and with good priors on them):

transformed data {
   vector[Ni] Y_ijk;
   Y_ijk = (mathgain - mean(mathgain)) / sd(mathgain);
}

model {
  ...
  beta_0jk ~ normal(0, 1);
  sigma_e0 ~ normal(0, 1);
  Y_ijk ~ normal(beta_0jk[classid], sigma_e0);
}

Have I understood correctly?

I think that’s correct, yes.

So, with the help of @wds15, here is the vectorized code of the model with the non-centered parametrization.
I hope someone else will find it useful in the future. Many thanks to Kazuki Yoshida for initially sharing his code.

data {
  // Define variables in data
  // Number of level-1 observations
  int<lower=0> Ni;
  // Number of level-2 clusters
  int<lower=0> Nj;
  // Number of level-3 clusters
  int<lower=0> Nk;

  // Cluster IDs
  int<lower=1, upper=Nj> classid[Ni];
  int<lower=1, upper=Nk> schoolid[Ni];

  // Level 3 look up vector for level 2
  int<lower=1, upper=Nk> schoolLookup[Nj];

  // Continuous outcome
  vector[Ni] mathgain;
}

transformed data {
   vector[Ni] Y_ijk;
   Y_ijk = (mathgain - mean(mathgain))/sd(mathgain);
}

parameters {
  // Population intercept
  real beta_0;

  // Level-1 errors
  real<lower=0> sigma_e0;

  // Level-2 random effect
  vector[Nj] u_0jk;
  real<lower=0> sigma_u0jk;

  // Level-3 random effect
  vector[Nk] u_0k;
  real<lower=0> sigma_u0k;
}

transformed parameters  {

  // Varying intercepts
  vector[Nj] beta_0jk;
  vector[Nk] beta_0k;

  // Varying intercepts definition
  // Level-3 with non-centered parametrization
  beta_0k = beta_0 + u_0k * sigma_u0k;
  
  // Level-2
  beta_0jk = beta_0k[schoolLookup] + u_0jk * sigma_u0jk;
}

model {
  // Prior part of Bayesian inference
  
  beta_0 ~ normal(0, 1);
  sigma_e0 ~ normal(0, 1);
  
  sigma_u0k ~ normal(0,1);
  sigma_u0jk ~ normal(0,1);
  
  
  // Random effects distribution
  u_0k  ~ normal(0, 1);
  u_0jk ~ normal(0, 1);
    
  // Likelihood 
  Y_ijk ~ normal(beta_0jk[classid], sigma_e0);
}

It runs in 281.321 seconds (total), which is a great improvement over the initial version, with a higher effective sample size for the estimated model parameters.

4 chains, each with iter=10000; warmup=1000; thin=10;
post-warmup draws per chain=900, total post-warmup draws=3600.

           mean se_mean   sd  2.5%   25%  50%  75% 97.5% n_eff Rhat
beta_0     0.00       0 0.04 -0.09 -0.03 0.00 0.02  0.08  3573    1
sigma_e0   0.93       0 0.02  0.89  0.91 0.93 0.94  0.97  3436    1
sigma_u0jk 0.28       0 0.07  0.14  0.24 0.29 0.33  0.41  2456    1
sigma_u0k  0.25       0 0.06  0.11  0.21 0.25 0.29  0.36  2528    1
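These summaries are on the standardized scale of Y_ijk, so they are not directly comparable with the two earlier printouts. Since Y_ijk = (mathgain - mean(mathgain)) / sd(mathgain), the intercept maps back as mean(mathgain) + sd(mathgain) * beta_0, and every sigma maps back as sd(mathgain) * sigma. A Python sketch of that mapping (the helper names and toy outcome values are hypothetical); because the map is affine with a positive scale, it can be applied to posterior draws or directly to the printed means and quantiles:

```python
import statistics


def standardize(y):
    """Return (z, center, scale) with z = (y - mean(y)) / sd(y)."""
    center = statistics.mean(y)
    scale = statistics.stdev(y)  # n - 1 denominator, like R's sd() and Stan's sd()
    return [(v - center) / scale for v in y], center, scale


def to_original_scale(beta_0_std, sigmas_std, center, scale):
    """Map estimates from the standardized fit back to the outcome's units.

    The intercept is scaled and shifted; the standard deviations
    (sigma_e0, sigma_u0jk, sigma_u0k) are only scaled.
    """
    return center + scale * beta_0_std, [scale * s for s in sigmas_std]


# Hypothetical outcome values, just to exercise the functions.
y = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
z, center, scale = standardize(y)
beta_0, sigmas = to_original_scale(0.0, [0.5, 0.25], center, scale)
print(center, scale, beta_0, sigmas)
```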

More technically, centering or non-centering affects the joint posterior geometry only for latent parameters. For terms that involve observed data, non-centering can affect the marginal posteriors, but the effect is very different and typically not as impactful.


Thank you @betanalpha. I am reading the paper you co-authored with Girolami right now!

You might also find some of the discussion in https://betanalpha.github.io/assets/case_studies/divergences_and_bias.html useful as a complement to the paper.


@betanalpha WOW. I totally missed “Diagnosing Biased Inference with Divergences”, and it is actually available among the Stan Case Studies. Adding this to the help that @wds15 gave me… this is pure gold!

This post has been really helpful for me as well. One comment on the latest Stan model: shouldn’t the sigmas have a half-Cauchy or exponential distribution? It is not possible to have a negative SD, so a standard normal prior doesn’t make a lot of sense.

The variable declarations for the sigmas include the lower bound of zero, so those normal priors get implemented as half-normal automatically.
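The lower bound simply removes the negative half of normal(0, 1), and that is exactly a half-normal. Because the bound and the prior's location and scale are constants, the truncation only multiplies the density by a constant (2 here), which does not affect sampling, so no explicit truncation is needed. A Python sketch (not Stan; sample size and seed are arbitrary) comparing "normal(0, 1) restricted to positive values" with |z|, z ~ normal(0, 1):

```python
import math
import random
import statistics

rng = random.Random(7)
n = 50_000

# real<lower=0> sigma; sigma ~ normal(0, 1):
# the constraint keeps only the positive half of normal(0, 1).
truncated = []
while len(truncated) < n:
    x = rng.gauss(0.0, 1.0)
    if x > 0:
        truncated.append(x)

# Half-normal(0, 1) written directly: |z| with z ~ normal(0, 1).
half_normal = [abs(rng.gauss(0.0, 1.0)) for _ in range(n)]

# Both should have mean close to sqrt(2 / pi), about 0.798.
print(statistics.mean(truncated), statistics.mean(half_normal), math.sqrt(2 / math.pi))
```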
