Strategies for Improving (Growth) Mixture Model Convergence

I’m trying to estimate what I think would be called a latent class linear mixed model (or possibly a growth mixture model) and having trouble getting it to converge.

I have repeated measures of a binary outcome – past-six-month heroin use – from 2,923 participants, totaling 62,065 observations. The goal is to describe common life-course trajectories of heroin use: the model has K latent trajectory classes, and within each class the conditional log odds of use are modeled as a cubic function of age with a multilevel intercept (level 1 = observation; level 2 = person). There is what I think is a modestly informative student-t prior on the grand intercept in each class, to break the symmetry between classes – the priors encode a weak belief that, at age 45, the probability of use is 35% in one class and 65% in the other (plausible values in this cohort).
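For concreteness, here is roughly the model I have in mind (notation mine; $c_i$ is person $i$'s latent class and $x_{ij}$ collects the other covariates):

$$\Pr(c_i = k) = \theta_k$$
$$\operatorname{logit} \Pr(y_{ij} = 1 \mid c_i = k) = \beta_{0k} + u_{ik} + \beta_{1k}\,\mathrm{age}_{ij} + \beta_{2k}\,\mathrm{age}_{ij}^2 + \beta_{3k}\,\mathrm{age}_{ij}^3 + x_{ij}^\top \gamma_k$$
$$u_{ik} \sim \mathcal{N}(0, \sigma_k^2)$$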

Here is the brms code used to generate the 2-class model (chains = 0 because I only use brms to generate the code; I run the models in CmdStan, which works better on the computing cluster I’m using):

# 2 Class Model
mix2 <- mixture(bernoulli,bernoulli)
pri2 <- c(
  prior(student_t(3, -0.619, 5), Intercept, dpar = mu1),  # logit(0.35) = -0.619
  prior(student_t(3, 0.619, 5), Intercept, dpar = mu2)    # logit(0.65) =  0.619
)
mcmc.2class <- brm(heroin ~ age.std + age2.std + age3.std + female + black + cohort + visit.era + (1|newid),
                   data = alive.complete,
                   family = mix2,
                   prior = pri2,
                   chains = 0
)
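For completeness, the same Stan program and data can also be extracted without calling brm() at all; a minimal sketch of how I move things over to CmdStan (file names are placeholders):

f2 <- heroin ~ age.std + age2.std + age3.std + female + black + cohort + visit.era + (1 | newid)
scode <- make_stancode(f2, data = alive.complete, family = mix2, prior = pri2)
sdata <- make_standata(f2, data = alive.complete, family = mix2, prior = pri2)
writeLines(scode, "mix2.stan")  # Stan program for CmdStan
rstan::stan_rdump(names(sdata), file = "mix2.data.R", envir = list2env(sdata, envir = new.env()))  # data dump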

Here is the actual Stan code for that model:

// generated with brms 2.5.0
functions { 
} 
data { 
  int<lower=1> N;  // total number of observations 
  int Y[N];  // response variable 
  int<lower=1> K_mu1;  // number of population-level effects 
  matrix[N, K_mu1] X_mu1;  // population-level design matrix 
  int<lower=1> K_mu2;  // number of population-level effects 
  matrix[N, K_mu2] X_mu2;  // population-level design matrix 
  vector[2] con_theta;  // prior concentration 
  // data for group-level effects of ID 1
  int<lower=1> J_1[N];
  int<lower=1> N_1;
  int<lower=1> M_1;
  vector[N] Z_1_mu1_1;
  // data for group-level effects of ID 2
  int<lower=1> J_2[N];
  int<lower=1> N_2;
  int<lower=1> M_2;
  vector[N] Z_2_mu2_1;
  int prior_only;  // should the likelihood be ignored? 
} 
transformed data { 
  int Kc_mu1 = K_mu1 - 1; 
  matrix[N, K_mu1 - 1] Xc_mu1;  // centered version of X_mu1 
  vector[K_mu1 - 1] means_X_mu1;  // column means of X_mu1 before centering 
  int Kc_mu2 = K_mu2 - 1; 
  matrix[N, K_mu2 - 1] Xc_mu2;  // centered version of X_mu2 
  vector[K_mu2 - 1] means_X_mu2;  // column means of X_mu2 before centering 
  for (i in 2:K_mu1) { 
    means_X_mu1[i - 1] = mean(X_mu1[, i]); 
    Xc_mu1[, i - 1] = X_mu1[, i] - means_X_mu1[i - 1]; 
  } 
  for (i in 2:K_mu2) { 
    means_X_mu2[i - 1] = mean(X_mu2[, i]); 
    Xc_mu2[, i - 1] = X_mu2[, i] - means_X_mu2[i - 1]; 
  } 
} 
parameters { 
  vector[Kc_mu1] b_mu1;  // population-level effects 
  vector[Kc_mu2] b_mu2;  // population-level effects 
  simplex[2] theta;  // mixing proportions 
  ordered[2] ordered_Intercept;  // to identify mixtures 
  vector<lower=0>[M_1] sd_1;  // group-level standard deviations
  vector[N_1] z_1[M_1];  // unscaled group-level effects
  vector<lower=0>[M_2] sd_2;  // group-level standard deviations
  vector[N_2] z_2[M_2];  // unscaled group-level effects
} 
transformed parameters { 
  // identify mixtures via ordering of the intercepts 
  real temp_mu1_Intercept = ordered_Intercept[1]; 
  // identify mixtures via ordering of the intercepts 
  real temp_mu2_Intercept = ordered_Intercept[2]; 
  // mixing proportions 
  real<lower=0,upper=1> theta1 = theta[1]; 
  real<lower=0,upper=1> theta2 = theta[2]; 
  // group-level effects 
  vector[N_1] r_1_mu1_1 = sd_1[1] * (z_1[1]);
  // group-level effects 
  vector[N_2] r_2_mu2_1 = sd_2[1] * (z_2[1]);
} 
model { 
  vector[N] mu1 = temp_mu1_Intercept + Xc_mu1 * b_mu1;
  vector[N] mu2 = temp_mu2_Intercept + Xc_mu2 * b_mu2;
  for (n in 1:N) { 
    mu1[n] += r_1_mu1_1[J_1[n]] * Z_1_mu1_1[n];
    mu2[n] += r_2_mu2_1[J_2[n]] * Z_2_mu2_1[n];
  } 
  // priors including all constants 
  target += student_t_lpdf(temp_mu1_Intercept | 3, -0.619, 5); 
  target += student_t_lpdf(temp_mu2_Intercept | 3, 0.619, 5); 
  target += dirichlet_lpdf(theta | con_theta); 
  target += student_t_lpdf(sd_1 | 3, 0, 10)
    - 1 * student_t_lccdf(0 | 3, 0, 10); 
  target += normal_lpdf(z_1[1] | 0, 1);
  target += student_t_lpdf(sd_2 | 3, 0, 10)
    - 1 * student_t_lccdf(0 | 3, 0, 10); 
  target += normal_lpdf(z_2[1] | 0, 1);
  // likelihood including all constants 
  if (!prior_only) { 
    for (n in 1:N) {
      real ps[2];
      ps[1] = log(theta1) + bernoulli_logit_lpmf(Y[n] | mu1[n]);
      ps[2] = log(theta2) + bernoulli_logit_lpmf(Y[n] | mu2[n]);
      target += log_sum_exp(ps);
    }
  } 
} 
generated quantities { 
  // actual population-level intercept 
  real b_mu1_Intercept = temp_mu1_Intercept - dot_product(means_X_mu1, b_mu1); 
  // actual population-level intercept 
  real b_mu2_Intercept = temp_mu2_Intercept - dot_product(means_X_mu2, b_mu2); 
} 

Here are the trace plots of the grand intercepts (and a few other covariates) after running three chains with 1,000 warmup and 1,000 sampling iterations each. No sign of convergence, and seemingly pretty high autocorrelation.
[Figure: trace plots of the grand intercepts and selected covariates, three chains]
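In case it helps, here is how I am checking the numeric diagnostics after reading the CmdStan CSV output back into R (a sketch; the CSV file names are placeholders):

library(rstan)
fit <- read_stan_csv(c("mix2_1.csv", "mix2_2.csv", "mix2_3.csv"))
# Rhat and effective sample size for the class-specific grand intercepts
summary(fit, pars = c("b_mu1_Intercept", "b_mu2_Intercept"))$summary[, c("Rhat", "n_eff")]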

Any suggestions about how to get the model to converge? Some thoughts:

  • Am I even specifying the model correctly?
  • I could make the intercept priors more informative (see the sketch after this list). I’m reluctant to do this because I really don’t have strong prior beliefs, but maybe there is a way to tighten them without biasing the model?
  • I could run with more warmup or more iterations, but nothing about the trace plots makes me think even a very long run will get me there.
  • I could thin the chains, but that doesn’t look promising either, and I gather thinning is generally not very helpful for HMC anyway.
  • I could allow for multilevel age slopes (right now only the intercept is multilevel; see the sketch after this list). This would add many more parameters, of course, but I gather that a more complex model sometimes converges faster because it fits the data better.
  • I could go back to the drawing board and consider whether this is really an appropriate model for the data.
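To make the second and fifth bullets concrete, here is roughly what I would try (the tighter prior scale of 1.5 is just an illustrative guess, not a settled choice):

# Tighter intercept priors; qlogis(0.35) = -0.619, qlogis(0.65) = 0.619
pri2.tight <- c(
  prior(student_t(3, -0.619, 1.5), Intercept, dpar = mu1),
  prior(student_t(3, 0.619, 1.5), Intercept, dpar = mu2)
)

# Person-level age slopes in addition to the person-level intercept
mcmc.2class.slopes <- brm(
  heroin ~ age.std + age2.std + age3.std + female + black + cohort + visit.era +
    (1 + age.std + age2.std + age3.std | newid),
  data = alive.complete,
  family = mix2,
  prior = pri2.tight,
  chains = 0
)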

Thanks

Start with the simplest thing you can think of. If you’re having to answer all those questions simultaneously, then we’re a bridge too far into this model.

What does the mixture here represent?