How to Improve Convergence in Random-Slopes Hierarchical Model

I am having trouble getting the chains of my random-slopes hierarchical model to converge and would appreciate some pointers. Because I am a masochist I am doing it in raw Stan rather than through a wrapper like brms or rstanarm.

This model is the equivalent of a random-slopes hierarchical regression, run on a panel design testing for group differences in linear trend over time. The outcome score is predicted by a three-level between-Ss predictor, a ten-measurement within-Ss predictor, their interaction, a subject-level term (the random intercept), and the interaction between the within-Ss and subject-level predictors (the random slope).
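
In lme4/brms formula notation this is roughly the following (a sketch; the double bar keeps the random intercepts and slopes uncorrelated, matching the independent priors used below):

# rough formula-interface equivalent of the model described above
score ~ group * time + (1 + time || id)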

Here is the toy data. There are three groups of 40 participants, each participant with 10 measurements. Group A has intercept 50 and slope 3, B has intercept 100 and slope 0, C has intercept 150 and slope -3. Noise has been built into all levels.

# packages
library(dplyr)
library(ggplot2)
library(rstan)

set.seed(1234)

# true group-level intercepts and slopes
aInt <- 50
bInt <- 100
cInt <- 150
aSlope <- 3
bSlope <- 0
cSlope <- -3

# create data frame
df <- data.frame(id = factor(rep(1:120, each = 10, length.out = 1200)),
                 group = factor(rep(LETTERS[1:3], each = 400, length.out = 1200)),
                 time = rep(1:10, length.out = 1200),
                 intercept = rep(rnorm(120, rep(c(aInt, bInt, cInt), each = 40), 10), each = 10),
                 slope = rep(rnorm(120, rep(c(aSlope, bSlope, cSlope), each = 40), .1), each = 10)) %>%
                 mutate(score = intercept + slope*time + rnorm(n(), 0, 1)) %>%  # independent noise per observation
                 dplyr::select(id, group, time, score)

# graph it
ggplot(df, aes(x = time, y = score, group = group, colour = group, linetype = group)) +
       geom_point(shape = 1) +
       geom_smooth(method = "lm", se = FALSE, colour = "black") +
       scale_x_continuous(breaks = 1:10) +
       scale_y_continuous(breaks = seq(50, 150, 50))

[Figure: score against time, coloured by group, with per-group linear fits]

Now to run the model and generate the chains. Note the likelihood has a grand intercept (a), a term for the categorical group predictor (bGroup), a term for the continuous time predictor (bTime), a group x time interaction term (bGxT), a random intercept term (bSubj), and a random slope term (bSxT). There are hyper-priors on the scale parameters in the priors for all the predictors, except the overall noise term sigma, which gets a normal prior truncated at 0 with a standard deviation five times the overall standard deviation of the scores. The prior on the intercept term a is a normal distribution centred on the grand mean score with a standard deviation twice the overall standard deviation.

### Step 1: data list

dList <- list(N = nrow(df),
              nSubj = nlevels(df$id),
              nGroup = nlevels(df$group),
              sIndex = as.integer(df$id),
              gIndex = as.integer(df$group),
              time = df$time,
              score = df$score,
              gMean = mean(df$score),
              gSD = sd(df$score))

### Step 2: make model

write("
data{
  int<lower=1> N;
  int<lower=1> nSubj;
  int<lower=1> nGroup;
  int<lower=1,upper=nSubj> sIndex[N];
  int<lower=1,upper=nGroup> gIndex[N];
  real time[N];
  real score[N];
  real gMean;
  real gSD;
}

parameters{
  real a;                    // grand intercept
  real bTime;                // overall time trend
  vector[nGroup] bGroup;     // group effects
  vector[nGroup] bGxT;       // group x time interaction
  vector[nSubj] bSubj;       // random intercepts
  vector[nSubj] bSxT;        // random slopes
  real<lower=0> sigma;       // residual noise
  real<lower=0> sigma_g;
  real<lower=0> sigma_t;
  real<lower=0> sigma_gt;
  real<lower=0> sigma_s;
  real<lower=0> sigma_st;
}

model{
  vector[N] mu;

  // hyper-priors
  sigma_g ~ normal(0, 5);
  sigma_t ~ normal(0, 5);
  sigma_gt ~ normal(0, 5);
  sigma_s ~ normal(0, 5);
  sigma_st ~ normal(0, 5);

  // priors
  sigma ~ normal(0, 5*gSD);
  a ~ normal(gMean, 2*gSD);
  bGroup ~ normal(0, sigma_g);
  bTime ~ normal(0, sigma_t);
  bGxT ~ normal(0, sigma_gt);    // prior on Group x Time interaction
  bSubj ~ normal(0, sigma_s);
  bSxT ~ normal(0, sigma_st);    // prior on Subject x Time interaction

  // likelihood
  for (i in 1:N) {
    mu[i] = a + bGroup[gIndex[i]] + bTime*time[i] + bGxT[gIndex[i]]*time[i]
            + bSubj[sIndex[i]] + bSxT[sIndex[i]]*time[i];
  }
  score ~ normal(mu, sigma);
}
", file = "temp.stan")


### Step 3: generate chains
trendMod <- stan(file = "temp.stan",
                 data = dList,
                 warmup = 1e3,
                 iter = 2e3,
                 cores = 1,
                 chains = 1)

Now when I run some diagnostics

print(trendMod, pars = c("a", "bGroup", "bTime", "bGxT", "sigma"), probs = c(0.025, 0.975))

the output shows a model with severe convergence problems:

#             mean se_mean sd   2.5%  97.5% n_eff Rhat
# a          69.31       0  0  69.31  69.31     6 1.00
# bGroup[1] -23.43       0  0 -23.43 -23.43     5 1.32
# bGroup[2]  28.87       0  0  28.87  28.87     6 1.19
# bGroup[3]  81.07       0  0  81.07  81.07     5 1.04
# bTime       0.00       0  0   0.00   0.00    16 1.16
# bGxT[1]     3.02       0  0   3.02   3.02    15 1.05
# bGxT[2]     0.01       0  0   0.01   0.01    16 1.00
# bGxT[3]    -2.96       0  0  -2.96  -2.96     7 1.45
# sigma       0.00       0  0   0.00   0.00     3 2.39

What am I doing wrong?

I’ve tried different hyper-priors, longer warmups, etc., but I get the same terminally low n_eff and inflated Rhat. Perhaps my model is wrong? Or perhaps the data are too uniform?

The strange thing is that, despite the terrible diagnostics, the model seems to be recovering the correct parameter estimates for the group, time, and interaction terms. The sigma estimate of 0 is a worry, though.
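
For instance, adding the grand terms to the group offsets recovers the simulated values (a quick sketch; trendMod is the fitted object from the call above):

# recover per-group intercepts and slopes from the posterior draws
post <- rstan::extract(trendMod)
colMeans(post$a + post$bGroup)      # should be near 50, 100, 150
colMeans(post$bTime + post$bGxT)    # should be near 3, 0, -3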

Any help much appreciated.

These are just some quick guesses. The very low n_eff values could be an indication that the model is not fully identified. From the means, I would guess that a and bGroup are not uniquely identified: any constant can be shifted between them without changing the likelihood. The cause here could be that the prior on a is too narrow, but I am not sure. You could also try forcing bGroup to have a mean of zero by setting bGroup[3] = -bGroup[1] - bGroup[2], as sketched below.
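
A minimal sketch of what that sum-to-zero constraint might look like in Stan (only the changed blocks are shown; bGroup_free is a made-up name, and the normal(0, sigma_g) prior would now go on bGroup_free):

parameters{
  vector[nGroup-1] bGroup_free;    // only nGroup - 1 free group effects
  // ... other parameters as before ...
}

transformed parameters{
  vector[nGroup] bGroup;
  bGroup[1:(nGroup-1)] = bGroup_free;
  bGroup[nGroup] = -sum(bGroup_free);   // forces sum(bGroup) == 0
}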

Thank you for the advice. I recall Kruschke’s code in JAGS for factorial designs performing these sum-to-zero transformations of parameters at the end of each script. I will look into how to do this in Stan.

I tried some other things (getting rid of the grand mean, etc.) but I couldn’t really resolve the identification problem. The other piece of advice I can give is to center time, i.e. use time - 5.5; that at least gave a speed-up (see the snippet below). Your best bet might actually be to move to rstanarm for these kinds of models.
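
The centering itself is one line on the data list before sampling (a sketch; mean(df$time) is 5.5 for time = 1:10):

# center time so the intercepts refer to the middle of the series
dList$time <- df$time - mean(df$time)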

It’s a pretty fundamental experimental design, used in the vast majority of clinical trials, so I’m surprised it is so hard in Stan. Am I specifying the likelihood correctly? Are there too many parameters?

My guess is that you have too many degrees of freedom because all the ‘fixed’ effects are defined as hierarchical parameters. The extra parameters are all the sigmas, and you only have three groups to identify those scales; see the sketch below for one way to flatten them.
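
For example, one way to flatten those priors (a sketch; the scale 10*gSD is an arbitrary weakly informative choice, and sigma_g, sigma_t, sigma_gt would be dropped from the parameters block):

// fixed, weakly informative priors instead of hierarchical ones
bGroup ~ normal(0, 10*gSD);
bTime  ~ normal(0, 10*gSD);
bGxT   ~ normal(0, 10*gSD);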

Thanks that’s a good tip. I will play around and see what happens.