Speed up multilevel logit: large dataset and group-level variables

Hi all,

I’m writing with a question about fitting a multilevel logit model to a dataset with ~300,000 observations. The data are cross-national survey data where individual respondents are nested within countries and years. I’m including group-level covariates to predict the country and year varying intercepts, like the county-level uranium variable in the radon example or the state-level variables in Gelman et al.’s Red State, Blue State project.

Even after using reduce_sum (huge thanks to the Stan developers for this), this model takes about 7 days on a computer with 8 cores and 32GB of RAM. Looking at the code below, is there any way I could speed this up? Would writing the model differently by putting the group-level variables straight into the individual level regression in partial_sum speed things up? The non-centered parameterization of both sets of intercepts is necessary, as far as I can tell.


I can write the model as

y \sim \text{BernoulliLogit}(\alpha + \alpha_{state} + \alpha_{year} + \textbf{X}\beta)

where the varying intercepts come from two group-level regressions, each with a constant in its data matrix (\textbf{G} and \textbf{Z}):

\alpha_{year} \sim N(\textbf{G}\lambda, \sigma_{year})
\alpha_{state} \sim N(\textbf{Z}\gamma, \sigma_{state})
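For reference, substituting the group-level regressions directly into the likelihood (the flattened alternative I mention above) would give, with standardized offsets z_{state} \sim N(0, 1) and z_{year} \sim N(0, 1):

```latex
y \sim \text{BernoulliLogit}\left(\alpha
  + \textbf{Z}_{s}\gamma + \sigma_{state}\, z_{state}
  + \textbf{G}_{t}\lambda + \sigma_{year}\, z_{year}
  + \textbf{X}\beta\right)
```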

The Stan program is below. Country (state) intercepts are indexed with S, while year intercepts are indexed with T. The group-level regressions are in the transformed parameters block.


functions {
  real partial_sum(int[] y_slice,
                   int start, int end,
                   matrix X, vector beta,
                   int[] state,
                   vector alpha_state,
                   int[] year,
                   vector alpha_year,
                   real alpha) {
    return bernoulli_logit_lpmf(y_slice | alpha +
      alpha_state[state[start:end]] + alpha_year[year[start:end]] +
      X[start:end, ] * beta);
  }
}

// data
data {
  int<lower = 0> N; // number of individual obs
  int<lower = 0, upper = 1> y[N]; // binary outcome
  
  int<lower = 0> S; // number of state-year obs
  int<lower = 1, upper = S> state[N]; // state indicator

  int<lower = 0> T; // number of year obs
  int<lower = 1, upper = T> year[N];
  
  int<lower = 1> I; // number of individual variables
  matrix[N, I] X; // individual-level reg matrix
  int<lower = 1> J; // number of state-year groups
  matrix[S, J] Z; // state-year reg matrix
  int<lower = 1> L; // number of system-year groups
  matrix[T, L] G; // year/system reg matrix
}

// parameters.
parameters {
  real alpha; // overall intercept
//  real<lower = 0> sigma; // outcome variance
 
  vector[S] alpha_state_std; // state intercepts - non-centered
  real<lower = 0> sigma_state; // state variance hyperparameter

  vector[T] alpha_year_std; // year intercepts - non-centered
  real<lower = 0> sigma_year; // year variance hyperparameter

  vector[I] beta; // individual coefs
  vector[J] gamma; // state-year coefs
  vector[L] lambda; // alliance-year coefs

}

transformed parameters {
   vector[T] mu_year; 
   vector[T] alpha_year; // year intercepts
   vector[S] alpha_state; // state intercepts
   vector[S] mu_state;
   
   // regression models of state and year varying intercepts
   mu_year = G * lambda; // year/system level 
   mu_state = Z * gamma; // state-year level

  // non-centered parameterization of state intercepts
  for (s in 1:S)
    alpha_state[s] = mu_state[s] + sigma_state * alpha_state_std[s];

  // non-centered parameterization of year intercepts
  for (t in 1:T)
    alpha_year[t] = mu_year[t] + sigma_year * alpha_year_std[t];


}

// model 
model {
  // define grain size (automatic selection)
  int grainsize = 1; 
  
  // define priors
  alpha ~ std_normal(); 
  
  // state parameters
  sigma_state ~ normal(0, 1); // half-normal
  alpha_state_std ~ std_normal(); // state intercepts non-centered
  // year parameters
  alpha_year_std ~ std_normal(); // year intercepts non-centered
  sigma_year ~ normal(0, 1); // half-normal
  
  // regression coef priors 
  // robust priors for indiv params and 
  //weak info normal for state-year level
  beta ~ student_t(7, 0, 1); 
  gamma ~ normal(0, 1); 
  lambda ~ normal(0, 1);
  
  // split outcome w/ reduce_sum
  target += reduce_sum(partial_sum, y,
                       grainsize,
                       X, beta,
                       state, alpha_state,
                       year, alpha_year,
                       alpha);
}

bernoulli_logit_glm should speed up the X * beta part of the model, though it won’t speed up the random effects (see 12.3 Bernoulli-Logit Generalized Linear Model (Logistic Regression) in the Stan Functions Reference).
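As a rough sketch (untested, and assuming the same data layout as your program), the partial_sum function could pass the per-observation intercepts through the vector-valued alpha argument of bernoulli_logit_glm_lpmf:

```stan
functions {
  real partial_sum_glm(int[] y_slice,
                       int start, int end,
                       matrix X, vector beta,
                       int[] state, vector alpha_state,
                       int[] year, vector alpha_year,
                       real alpha) {
    // varying intercepts enter via the vector alpha argument;
    // X * beta is handled by the GLM function's internal analytic gradients
    return bernoulli_logit_glm_lpmf(y_slice |
      X[start:end, ],
      alpha + alpha_state[state[start:end]] + alpha_year[year[start:end]],
      beta);
  }
}
```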

You can vectorize these:

for(s in 1:S)
  alpha_state[s] = mu_state[s] + sigma_state * alpha_state_std[s];

  // non-centered parameterization of year intercepts
for(t in 1:T)
  alpha_year[t] = mu_year[t] + sigma_year * alpha_year_std[t];

like this:

alpha_state = mu_state + sigma_state * alpha_state_std;
alpha_year = mu_year + sigma_year * alpha_year_std;

It might give a tiny bump, but I wouldn’t expect much from that.