Mixed/hierarchical logit code optimization

I have a mixed logit model (also called a categorical logit or a hierarchical model with a softmax, among other names in different fields) with a very long runtime. After some testing, I thought it was time to seek some help with it. I began with synthetic data and ran the model on my local machine, and everything looked fine. My real dataset has 178,000 records. Running the model on 100 data points with 500/500 warmup/sampling iterations gives me:

Warning: 1212 of 2000 (61.0%) transitions ended with a divergence.
This may indicate insufficient exploration of the posterior distribution.
Possible remedies include: 
  * Increasing adapt_delta closer to 1 (default is 0.8) 
  * Reparameterizing the model (e.g. using a non-centered parameterization)
  * Using informative or weakly informative prior distributions 

788 of 2000 (39.0%) transitions hit the maximum treedepth limit of 10 or 2^10-1 leapfrog steps.
Trajectories that are prematurely terminated due to this limit will result in slow exploration.
Increasing the max_treedepth limit can avoid this at the expense of more computation.
If increasing max_treedepth does not remove warnings, try to reparameterize the model.

I’m running the full model on an HPC. I allocate 16 cores and 36 GB RAM, so it’s running with 4 parallel chains and 4 threads per chain. It’s been running for 18 hours and hasn’t finished 100 iterations yet. There are nodes on the HPC with 36 cores. I could slightly modify my priors, but they’re already quite tight N(0,2.5). The runs with 100 data points give me the following warning:

The current Metropolis proposal is about to be rejected because of the following issue:
Chain 4 Exception: Exception: categorical_logit_lpmf: log odds parameter[1] is -nan, but must be finite!

This disappears after the first 100 iterations, so I thought it should be OK. The softmax involves exponentials, which will obviously inflate some of the initial values. I've looked at individual outputs and there are just a few observations given large initial values.

I've tried VI using meanfield on a simplified model (removing (1|trippurp) and simplifying to (1 + hhfaminc|personid)), but that ran out of RAM with about 36 GB allocated. I gave it 8 cores but didn't use parallel threads on the final test, to minimize RAM use.

I've written mixed logit models in Stan before (partially based on discussion here), but this is a much larger model/dataset. For other hierarchical models I've estimated in Stan, the brms-generated code ran much faster than my hand-written code, so I figured it would be a good starting point here, too.

I have the following script. I use brms to generate the data inputs and the initial Stan code, modify the Stan code a bit, and run it with cmdstanr. The brms model is based on one here. I realize this is a lot of code to review; most of it is variable definitions, and the model is also given in the R code via the brms specification. There are two hierarchies in the model: 1) individuals, for whom I include contextual effects, and 2) trip purpose, where an individual can have multiple trip purposes.

library(brms)       # for brm() and standata()
library(cmdstanr)   # for cmdstan_model() and sampling

model.1.data <-
  standata(brm(data = mod,
      family = categorical(link = logit, refcat = NA),
      bf(trptrans ~ 1,
         nlf(mu0 ~ btc * (a1 + b1tt * travelTimeDrive + travelCostDrive)),
         nlf(mu1 ~ btc * (a2 + b2tt * travelTimeWalk)),
         nlf(mu2 ~ btc * (a3 + b3tt * travelTimeBike)),
         nlf(mu3 ~ btc * (b4tt * travelTimeTransit)),
         mvbind(btc + a1 + a2 + a3 + b1tt + b2tt + b3tt + b4tt) ~ 1 + (1 + race_2 + race_3 + race_4 + race_5 + hhfaminc|personid) + (1|trippurp)),
      prior = c(prior(normal(0, 2.5), class = b, nlpar = a1),
                prior(normal(0, 2.5), class = b, nlpar = a2),
                prior(normal(0, 2.5), class = b, nlpar = a3),
                prior(normal(0, 2.5), class = b, nlpar = b1tt),
                prior(normal(0, 2.5), class = b, nlpar = btc),
                prior(normal(0, 2.5), class = b, nlpar = b2tt),
                prior(normal(0, 2.5), class = b, nlpar = b3tt),
                prior(normal(0, 2.5), class = b, nlpar = b4tt)),
      empty = TRUE,
      backend = "cmdstanr",
      threads = threading(4)))

model.1 = cmdstan_model("equity_brms.stan", cpp_options = list(stan_threads = TRUE))

model_fit = model.1$sample(data = model.1.data,
                           seed = 24567,
                           iter_warmup = 1000,
                           iter_sampling = 1000,
                           chains = 4,
                           parallel_chains = 4,
                           threads_per_chain = 4)

The Stan model (with minor modifications from that generated from brms) is:

functions {
 /* compute correlated group-level effects
  * Args: 
  *   z: matrix of unscaled group-level effects
  *   SD: vector of standard deviation parameters
  *   L: cholesky factor correlation matrix
  * Returns: 
  *   matrix of scaled group-level effects
  */ 
  matrix scale_r_cor(matrix z, vector SD, matrix L) {
    // r is stored in another dimension order than z
    return transpose(diag_pre_multiply(SD, L) * z);
  }
  /* integer sequence of values
   * Args: 
   *   start: starting integer
   *   end: ending integer
   * Returns: 
   *   an integer sequence from start to end
   */ 
  int[] sequence(int start, int end) { 
    int seq[end - start + 1];
    for (n in 1:num_elements(seq)) {
      seq[n] = n + start - 1;
    }
    return seq; 
  } 
  // compute partial sums of the log-likelihood
  real partial_log_lik_lpmf(int[] seq, int start, int end, data int ncat, data int[] Y, data matrix X_btc, vector b_btc, data matrix X_a1, vector b_a1, data matrix X_a2, vector b_a2, data matrix X_a3, vector b_a3, data matrix X_b1tt, vector b_b1tt, data matrix X_b2tt, vector b_b2tt, data matrix X_b3tt, vector b_b3tt, data matrix X_b4tt, vector b_b4tt, data vector C_mu0_1, data vector C_mu0_2, data vector C_mu1_1, data vector C_mu2_1, data vector C_mu3_1, data int[] J_1, data vector Z_1_btc_1, data vector Z_1_btc_2, data vector Z_1_btc_3, data vector Z_1_btc_4, data vector Z_1_btc_5, data vector Z_1_btc_6, vector r_1_btc_1, vector r_1_btc_2, vector r_1_btc_3, vector r_1_btc_4, vector r_1_btc_5, vector r_1_btc_6, data int[] J_2, data vector Z_2_btc_1, vector r_2_btc_1, data int[] J_3, data vector Z_3_a1_1, data vector Z_3_a1_2, data vector Z_3_a1_3, data vector Z_3_a1_4, data vector Z_3_a1_5, data vector Z_3_a1_6, vector r_3_a1_1, vector r_3_a1_2, vector r_3_a1_3, vector r_3_a1_4, vector r_3_a1_5, vector r_3_a1_6, data int[] J_4, data vector Z_4_a1_1, vector r_4_a1_1, data int[] J_5, data vector Z_5_a2_1, data vector Z_5_a2_2, data vector Z_5_a2_3, data vector Z_5_a2_4, data vector Z_5_a2_5, data vector Z_5_a2_6, vector r_5_a2_1, vector r_5_a2_2, vector r_5_a2_3, vector r_5_a2_4, vector r_5_a2_5, vector r_5_a2_6, data int[] J_6, data vector Z_6_a2_1, vector r_6_a2_1, data int[] J_7, data vector Z_7_a3_1, data vector Z_7_a3_2, data vector Z_7_a3_3, data vector Z_7_a3_4, data vector Z_7_a3_5, data vector Z_7_a3_6, vector r_7_a3_1, vector r_7_a3_2, vector r_7_a3_3, vector r_7_a3_4, vector r_7_a3_5, vector r_7_a3_6, data int[] J_8, data vector Z_8_a3_1, vector r_8_a3_1, data int[] J_9, data vector Z_9_b1tt_1, data vector Z_9_b1tt_2, data vector Z_9_b1tt_3, data vector Z_9_b1tt_4, data vector Z_9_b1tt_5, data vector Z_9_b1tt_6, vector r_9_b1tt_1, vector r_9_b1tt_2, vector r_9_b1tt_3, vector r_9_b1tt_4, vector r_9_b1tt_5, vector r_9_b1tt_6, data int[] J_10, data vector Z_10_b1tt_1, vector r_10_b1tt_1, data int[] J_11, data vector Z_11_b2tt_1, data vector Z_11_b2tt_2, data vector Z_11_b2tt_3, data vector Z_11_b2tt_4, data vector Z_11_b2tt_5, data vector Z_11_b2tt_6, vector r_11_b2tt_1, vector r_11_b2tt_2, vector r_11_b2tt_3, vector r_11_b2tt_4, vector r_11_b2tt_5, vector r_11_b2tt_6, data int[] J_12, data vector Z_12_b2tt_1, vector r_12_b2tt_1, data int[] J_13, data vector Z_13_b3tt_1, data vector Z_13_b3tt_2, data vector Z_13_b3tt_3, data vector Z_13_b3tt_4, data vector Z_13_b3tt_5, data vector Z_13_b3tt_6, vector r_13_b3tt_1, vector r_13_b3tt_2, vector r_13_b3tt_3, vector r_13_b3tt_4, vector r_13_b3tt_5, vector r_13_b3tt_6, data int[] J_14, data vector Z_14_b3tt_1, vector r_14_b3tt_1, data int[] J_15, data vector Z_15_b4tt_1, data vector Z_15_b4tt_2, data vector Z_15_b4tt_3, data vector Z_15_b4tt_4, data vector Z_15_b4tt_5, data vector Z_15_b4tt_6, vector r_15_b4tt_1, vector r_15_b4tt_2, vector r_15_b4tt_3, vector r_15_b4tt_4, vector r_15_b4tt_5, vector r_15_b4tt_6, data int[] J_16, data vector Z_16_b4tt_1, vector r_16_b4tt_1) {
    real ptarget = 0;
    int N = end - start + 1;
    // initialize linear predictor term
    vector[N] nlp_btc = X_btc[start:end] * b_btc;
    // initialize linear predictor term
    vector[N] nlp_a1 = X_a1[start:end] * b_a1;
    // initialize linear predictor term
    vector[N] nlp_a2 = X_a2[start:end] * b_a2;
    // initialize linear predictor term
    vector[N] nlp_a3 = X_a3[start:end] * b_a3;
    // initialize linear predictor term
    vector[N] nlp_b1tt = X_b1tt[start:end] * b_b1tt;
    // initialize linear predictor term
    vector[N] nlp_b2tt = X_b2tt[start:end] * b_b2tt;
    // initialize linear predictor term
    vector[N] nlp_b3tt = X_b3tt[start:end] * b_b3tt;
    // initialize linear predictor term
    vector[N] nlp_b4tt = X_b4tt[start:end] * b_b4tt;
    // initialize non-linear predictor term
    vector[N] mu0;
    // initialize non-linear predictor term
    vector[N] mu1;
    // initialize non-linear predictor term
    vector[N] mu2;
    // initialize non-linear predictor term
    vector[N] mu3;
    // linear predictor matrix
    vector[ncat] mu[N];
    for (n in 1:N) {
      // add more terms to the linear predictor
      int nn = n + start - 1;
      nlp_btc[n] += exp(r_1_btc_1[J_1[nn]] * Z_1_btc_1[nn] + r_1_btc_2[J_1[nn]] * Z_1_btc_2[nn] + r_1_btc_3[J_1[nn]] * Z_1_btc_3[nn] + r_1_btc_4[J_1[nn]] * Z_1_btc_4[nn] + r_1_btc_5[J_1[nn]] * Z_1_btc_5[nn] + r_1_btc_6[J_1[nn]] * Z_1_btc_6[nn] + r_2_btc_1[J_2[nn]] * Z_2_btc_1[nn]);
    }
    for (n in 1:N) {
      // add more terms to the linear predictor
      int nn = n + start - 1;
      nlp_a1[n] += r_3_a1_1[J_3[nn]] * Z_3_a1_1[nn] + r_3_a1_2[J_3[nn]] * Z_3_a1_2[nn] + r_3_a1_3[J_3[nn]] * Z_3_a1_3[nn] + r_3_a1_4[J_3[nn]] * Z_3_a1_4[nn] + r_3_a1_5[J_3[nn]] * Z_3_a1_5[nn] + r_3_a1_6[J_3[nn]] * Z_3_a1_6[nn] + r_4_a1_1[J_4[nn]] * Z_4_a1_1[nn];
    }
    for (n in 1:N) {
      // add more terms to the linear predictor
      int nn = n + start - 1;
      nlp_a2[n] += r_5_a2_1[J_5[nn]] * Z_5_a2_1[nn] + r_5_a2_2[J_5[nn]] * Z_5_a2_2[nn] + r_5_a2_3[J_5[nn]] * Z_5_a2_3[nn] + r_5_a2_4[J_5[nn]] * Z_5_a2_4[nn] + r_5_a2_5[J_5[nn]] * Z_5_a2_5[nn] + r_5_a2_6[J_5[nn]] * Z_5_a2_6[nn] + r_6_a2_1[J_6[nn]] * Z_6_a2_1[nn];
    }
    for (n in 1:N) {
      // add more terms to the linear predictor
      int nn = n + start - 1;
      nlp_a3[n] += r_7_a3_1[J_7[nn]] * Z_7_a3_1[nn] + r_7_a3_2[J_7[nn]] * Z_7_a3_2[nn] + r_7_a3_3[J_7[nn]] * Z_7_a3_3[nn] + r_7_a3_4[J_7[nn]] * Z_7_a3_4[nn] + r_7_a3_5[J_7[nn]] * Z_7_a3_5[nn] + r_7_a3_6[J_7[nn]] * Z_7_a3_6[nn] + r_8_a3_1[J_8[nn]] * Z_8_a3_1[nn];
    }
    for (n in 1:N) {
      // add more terms to the linear predictor
      int nn = n + start - 1;
      nlp_b1tt[n] += r_9_b1tt_1[J_9[nn]] * Z_9_b1tt_1[nn] + r_9_b1tt_2[J_9[nn]] * Z_9_b1tt_2[nn] + r_9_b1tt_3[J_9[nn]] * Z_9_b1tt_3[nn] + r_9_b1tt_4[J_9[nn]] * Z_9_b1tt_4[nn] + r_9_b1tt_5[J_9[nn]] * Z_9_b1tt_5[nn] + r_9_b1tt_6[J_9[nn]] * Z_9_b1tt_6[nn] + r_10_b1tt_1[J_10[nn]] * Z_10_b1tt_1[nn];
    }
    for (n in 1:N) {
      // add more terms to the linear predictor
      int nn = n + start - 1;
      nlp_b2tt[n] += r_11_b2tt_1[J_11[nn]] * Z_11_b2tt_1[nn] + r_11_b2tt_2[J_11[nn]] * Z_11_b2tt_2[nn] + r_11_b2tt_3[J_11[nn]] * Z_11_b2tt_3[nn] + r_11_b2tt_4[J_11[nn]] * Z_11_b2tt_4[nn] + r_11_b2tt_5[J_11[nn]] * Z_11_b2tt_5[nn] + r_11_b2tt_6[J_11[nn]] * Z_11_b2tt_6[nn] + r_12_b2tt_1[J_12[nn]] * Z_12_b2tt_1[nn];
    }
    for (n in 1:N) {
      // add more terms to the linear predictor
      int nn = n + start - 1;
      nlp_b3tt[n] += r_13_b3tt_1[J_13[nn]] * Z_13_b3tt_1[nn] + r_13_b3tt_2[J_13[nn]] * Z_13_b3tt_2[nn] + r_13_b3tt_3[J_13[nn]] * Z_13_b3tt_3[nn] + r_13_b3tt_4[J_13[nn]] * Z_13_b3tt_4[nn] + r_13_b3tt_5[J_13[nn]] * Z_13_b3tt_5[nn] + r_13_b3tt_6[J_13[nn]] * Z_13_b3tt_6[nn] + r_14_b3tt_1[J_14[nn]] * Z_14_b3tt_1[nn];
    }
    for (n in 1:N) {
      // add more terms to the linear predictor
      int nn = n + start - 1;
      nlp_b4tt[n] += r_15_b4tt_1[J_15[nn]] * Z_15_b4tt_1[nn] + r_15_b4tt_2[J_15[nn]] * Z_15_b4tt_2[nn] + r_15_b4tt_3[J_15[nn]] * Z_15_b4tt_3[nn] + r_15_b4tt_4[J_15[nn]] * Z_15_b4tt_4[nn] + r_15_b4tt_5[J_15[nn]] * Z_15_b4tt_5[nn] + r_15_b4tt_6[J_15[nn]] * Z_15_b4tt_6[nn] + r_16_b4tt_1[J_16[nn]] * Z_16_b4tt_1[nn];
    }
    for (n in 1:N) {
      int nn = n + start - 1;
      // compute non-linear predictor values
      mu0[n] = nlp_btc[n] * (nlp_a1[n] + nlp_b1tt[n] * C_mu0_1[nn] + C_mu0_2[nn]);
    }
    for (n in 1:N) {
      int nn = n + start - 1;
      // compute non-linear predictor values
      mu1[n] = nlp_btc[n] * (nlp_a2[n] + nlp_b2tt[n] * C_mu1_1[nn]);
    }
    for (n in 1:N) {
      int nn = n + start - 1;
      // compute non-linear predictor values
      mu2[n] = nlp_btc[n] * (nlp_a3[n] + nlp_b3tt[n] * C_mu2_1[nn]);
    }
    for (n in 1:N) {
      int nn = n + start - 1;
      // compute non-linear predictor values
      mu3[n] = nlp_btc[n] * (nlp_b4tt[n] * C_mu3_1[nn]);
    }
    
    for (n in 1:N) {
      mu[n] = transpose([mu0[n], mu1[n], mu2[n], mu3[n]]);
    }
    for (n in 1:N) {
      int nn = n + start - 1;
      ptarget += categorical_logit_lpmf(Y[nn] | mu[n]);
    }
    return ptarget;
  }
}
data {
  int<lower=1> N;  // total number of observations
  int<lower=2> ncat;  // number of categories
  int Y[N];  // response variable
  int<lower=1> K_btc;  // number of population-level effects
  matrix[N, K_btc] X_btc;  // population-level design matrix
  int<lower=1> K_a1;  // number of population-level effects
  matrix[N, K_a1] X_a1;  // population-level design matrix
  int<lower=1> K_a2;  // number of population-level effects
  matrix[N, K_a2] X_a2;  // population-level design matrix
  int<lower=1> K_a3;  // number of population-level effects
  matrix[N, K_a3] X_a3;  // population-level design matrix
  int<lower=1> K_b1tt;  // number of population-level effects
  matrix[N, K_b1tt] X_b1tt;  // population-level design matrix
  int<lower=1> K_b2tt;  // number of population-level effects
  matrix[N, K_b2tt] X_b2tt;  // population-level design matrix
  int<lower=1> K_b3tt;  // number of population-level effects
  matrix[N, K_b3tt] X_b3tt;  // population-level design matrix
  int<lower=1> K_b4tt;  // number of population-level effects
  matrix[N, K_b4tt] X_b4tt;  // population-level design matrix
  // covariate vectors for non-linear functions
  vector[N] C_mu0_1;
  vector[N] C_mu0_2;
  // covariate vectors for non-linear functions
  vector[N] C_mu1_1;
  // covariate vectors for non-linear functions
  vector[N] C_mu2_1;
  // covariate vectors for non-linear functions
  vector[N] C_mu3_1;
  int grainsize;  // grainsize for threading
  // data for group-level effects of ID 1
  int<lower=1> N_1;  // number of grouping levels
  int<lower=1> M_1;  // number of coefficients per level
  int<lower=1> J_1[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_1_btc_1;
  vector[N] Z_1_btc_2;
  vector[N] Z_1_btc_3;
  vector[N] Z_1_btc_4;
  vector[N] Z_1_btc_5;
  vector[N] Z_1_btc_6;
  int<lower=1> NC_1;  // number of group-level correlations
  // data for group-level effects of ID 2
  int<lower=1> N_2;  // number of grouping levels
  int<lower=1> M_2;  // number of coefficients per level
  int<lower=1> J_2[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_2_btc_1;
  // data for group-level effects of ID 3
  int<lower=1> N_3;  // number of grouping levels
  int<lower=1> M_3;  // number of coefficients per level
  int<lower=1> J_3[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_3_a1_1;
  vector[N] Z_3_a1_2;
  vector[N] Z_3_a1_3;
  vector[N] Z_3_a1_4;
  vector[N] Z_3_a1_5;
  vector[N] Z_3_a1_6;
  int<lower=1> NC_3;  // number of group-level correlations
  // data for group-level effects of ID 4
  int<lower=1> N_4;  // number of grouping levels
  int<lower=1> M_4;  // number of coefficients per level
  int<lower=1> J_4[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_4_a1_1;
  // data for group-level effects of ID 5
  int<lower=1> N_5;  // number of grouping levels
  int<lower=1> M_5;  // number of coefficients per level
  int<lower=1> J_5[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_5_a2_1;
  vector[N] Z_5_a2_2;
  vector[N] Z_5_a2_3;
  vector[N] Z_5_a2_4;
  vector[N] Z_5_a2_5;
  vector[N] Z_5_a2_6;
  int<lower=1> NC_5;  // number of group-level correlations
  // data for group-level effects of ID 6
  int<lower=1> N_6;  // number of grouping levels
  int<lower=1> M_6;  // number of coefficients per level
  int<lower=1> J_6[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_6_a2_1;
  // data for group-level effects of ID 7
  int<lower=1> N_7;  // number of grouping levels
  int<lower=1> M_7;  // number of coefficients per level
  int<lower=1> J_7[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_7_a3_1;
  vector[N] Z_7_a3_2;
  vector[N] Z_7_a3_3;
  vector[N] Z_7_a3_4;
  vector[N] Z_7_a3_5;
  vector[N] Z_7_a3_6;
  int<lower=1> NC_7;  // number of group-level correlations
  // data for group-level effects of ID 8
  int<lower=1> N_8;  // number of grouping levels
  int<lower=1> M_8;  // number of coefficients per level
  int<lower=1> J_8[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_8_a3_1;
  // data for group-level effects of ID 9
  int<lower=1> N_9;  // number of grouping levels
  int<lower=1> M_9;  // number of coefficients per level
  int<lower=1> J_9[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_9_b1tt_1;
  vector[N] Z_9_b1tt_2;
  vector[N] Z_9_b1tt_3;
  vector[N] Z_9_b1tt_4;
  vector[N] Z_9_b1tt_5;
  vector[N] Z_9_b1tt_6;
  int<lower=1> NC_9;  // number of group-level correlations
  // data for group-level effects of ID 10
  int<lower=1> N_10;  // number of grouping levels
  int<lower=1> M_10;  // number of coefficients per level
  int<lower=1> J_10[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_10_b1tt_1;
  // data for group-level effects of ID 11
  int<lower=1> N_11;  // number of grouping levels
  int<lower=1> M_11;  // number of coefficients per level
  int<lower=1> J_11[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_11_b2tt_1;
  vector[N] Z_11_b2tt_2;
  vector[N] Z_11_b2tt_3;
  vector[N] Z_11_b2tt_4;
  vector[N] Z_11_b2tt_5;
  vector[N] Z_11_b2tt_6;
  int<lower=1> NC_11;  // number of group-level correlations
  // data for group-level effects of ID 12
  int<lower=1> N_12;  // number of grouping levels
  int<lower=1> M_12;  // number of coefficients per level
  int<lower=1> J_12[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_12_b2tt_1;
  // data for group-level effects of ID 13
  int<lower=1> N_13;  // number of grouping levels
  int<lower=1> M_13;  // number of coefficients per level
  int<lower=1> J_13[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_13_b3tt_1;
  vector[N] Z_13_b3tt_2;
  vector[N] Z_13_b3tt_3;
  vector[N] Z_13_b3tt_4;
  vector[N] Z_13_b3tt_5;
  vector[N] Z_13_b3tt_6;
  int<lower=1> NC_13;  // number of group-level correlations
  // data for group-level effects of ID 14
  int<lower=1> N_14;  // number of grouping levels
  int<lower=1> M_14;  // number of coefficients per level
  int<lower=1> J_14[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_14_b3tt_1;
  // data for group-level effects of ID 15
  int<lower=1> N_15;  // number of grouping levels
  int<lower=1> M_15;  // number of coefficients per level
  int<lower=1> J_15[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_15_b4tt_1;
  vector[N] Z_15_b4tt_2;
  vector[N] Z_15_b4tt_3;
  vector[N] Z_15_b4tt_4;
  vector[N] Z_15_b4tt_5;
  vector[N] Z_15_b4tt_6;
  int<lower=1> NC_15;  // number of group-level correlations
  // data for group-level effects of ID 16
  int<lower=1> N_16;  // number of grouping levels
  int<lower=1> M_16;  // number of coefficients per level
  int<lower=1> J_16[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_16_b4tt_1;
  int prior_only;  // should the likelihood be ignored?
}
transformed data {
  int seq[N] = sequence(1, N);
}
parameters {
  vector[K_btc] b_btc;  // population-level effects
  vector[K_a1] b_a1;  // population-level effects
  vector[K_a2] b_a2;  // population-level effects
  vector[K_a3] b_a3;  // population-level effects
  vector[K_b1tt] b_b1tt;  // population-level effects
  vector[K_b2tt] b_b2tt;  // population-level effects
  vector[K_b3tt] b_b3tt;  // population-level effects
  vector[K_b4tt] b_b4tt;  // population-level effects
  vector<lower=0>[M_1] sd_1;  // group-level standard deviations
  matrix[M_1, N_1] z_1;  // standardized group-level effects
  cholesky_factor_corr[M_1] L_1;  // cholesky factor of correlation matrix
  vector<lower=0>[M_2] sd_2;  // group-level standard deviations
  vector[N_2] z_2[M_2];  // standardized group-level effects
  vector<lower=0>[M_3] sd_3;  // group-level standard deviations
  matrix[M_3, N_3] z_3;  // standardized group-level effects
  cholesky_factor_corr[M_3] L_3;  // cholesky factor of correlation matrix
  vector<lower=0>[M_4] sd_4;  // group-level standard deviations
  vector[N_4] z_4[M_4];  // standardized group-level effects
  vector<lower=0>[M_5] sd_5;  // group-level standard deviations
  matrix[M_5, N_5] z_5;  // standardized group-level effects
  cholesky_factor_corr[M_5] L_5;  // cholesky factor of correlation matrix
  vector<lower=0>[M_6] sd_6;  // group-level standard deviations
  vector[N_6] z_6[M_6];  // standardized group-level effects
  vector<lower=0>[M_7] sd_7;  // group-level standard deviations
  matrix[M_7, N_7] z_7;  // standardized group-level effects
  cholesky_factor_corr[M_7] L_7;  // cholesky factor of correlation matrix
  vector<lower=0>[M_8] sd_8;  // group-level standard deviations
  vector[N_8] z_8[M_8];  // standardized group-level effects
  vector<lower=0>[M_9] sd_9;  // group-level standard deviations
  matrix[M_9, N_9] z_9;  // standardized group-level effects
  cholesky_factor_corr[M_9] L_9;  // cholesky factor of correlation matrix
  vector<lower=0>[M_10] sd_10;  // group-level standard deviations
  vector[N_10] z_10[M_10];  // standardized group-level effects
  vector<lower=0>[M_11] sd_11;  // group-level standard deviations
  matrix[M_11, N_11] z_11;  // standardized group-level effects
  cholesky_factor_corr[M_11] L_11;  // cholesky factor of correlation matrix
  vector<lower=0>[M_12] sd_12;  // group-level standard deviations
  vector[N_12] z_12[M_12];  // standardized group-level effects
  vector<lower=0>[M_13] sd_13;  // group-level standard deviations
  matrix[M_13, N_13] z_13;  // standardized group-level effects
  cholesky_factor_corr[M_13] L_13;  // cholesky factor of correlation matrix
  vector<lower=0>[M_14] sd_14;  // group-level standard deviations
  vector[N_14] z_14[M_14];  // standardized group-level effects
  vector<lower=0>[M_15] sd_15;  // group-level standard deviations
  matrix[M_15, N_15] z_15;  // standardized group-level effects
  cholesky_factor_corr[M_15] L_15;  // cholesky factor of correlation matrix
  vector<lower=0>[M_16] sd_16;  // group-level standard deviations
  vector[N_16] z_16[M_16];  // standardized group-level effects
}
transformed parameters {
  matrix[N_1, M_1] r_1;  // actual group-level effects
  // using vectors speeds up indexing in loops
  vector[N_1] r_1_btc_1;
  vector[N_1] r_1_btc_2;
  vector[N_1] r_1_btc_3;
  vector[N_1] r_1_btc_4;
  vector[N_1] r_1_btc_5;
  vector[N_1] r_1_btc_6;
  vector[N_2] r_2_btc_1;  // actual group-level effects
  matrix[N_3, M_3] r_3;  // actual group-level effects
  // using vectors speeds up indexing in loops
  vector[N_3] r_3_a1_1;
  vector[N_3] r_3_a1_2;
  vector[N_3] r_3_a1_3;
  vector[N_3] r_3_a1_4;
  vector[N_3] r_3_a1_5;
  vector[N_3] r_3_a1_6;
  vector[N_4] r_4_a1_1;  // actual group-level effects
  matrix[N_5, M_5] r_5;  // actual group-level effects
  // using vectors speeds up indexing in loops
  vector[N_5] r_5_a2_1;
  vector[N_5] r_5_a2_2;
  vector[N_5] r_5_a2_3;
  vector[N_5] r_5_a2_4;
  vector[N_5] r_5_a2_5;
  vector[N_5] r_5_a2_6;
  vector[N_6] r_6_a2_1;  // actual group-level effects
  matrix[N_7, M_7] r_7;  // actual group-level effects
  // using vectors speeds up indexing in loops
  vector[N_7] r_7_a3_1;
  vector[N_7] r_7_a3_2;
  vector[N_7] r_7_a3_3;
  vector[N_7] r_7_a3_4;
  vector[N_7] r_7_a3_5;
  vector[N_7] r_7_a3_6;
  vector[N_8] r_8_a3_1;  // actual group-level effects
  matrix[N_9, M_9] r_9;  // actual group-level effects
  // using vectors speeds up indexing in loops
  vector[N_9] r_9_b1tt_1;
  vector[N_9] r_9_b1tt_2;
  vector[N_9] r_9_b1tt_3;
  vector[N_9] r_9_b1tt_4;
  vector[N_9] r_9_b1tt_5;
  vector[N_9] r_9_b1tt_6;
  vector[N_10] r_10_b1tt_1;  // actual group-level effects
  matrix[N_11, M_11] r_11;  // actual group-level effects
  // using vectors speeds up indexing in loops
  vector[N_11] r_11_b2tt_1;
  vector[N_11] r_11_b2tt_2;
  vector[N_11] r_11_b2tt_3;
  vector[N_11] r_11_b2tt_4;
  vector[N_11] r_11_b2tt_5;
  vector[N_11] r_11_b2tt_6;
  vector[N_12] r_12_b2tt_1;  // actual group-level effects
  matrix[N_13, M_13] r_13;  // actual group-level effects
  // using vectors speeds up indexing in loops
  vector[N_13] r_13_b3tt_1;
  vector[N_13] r_13_b3tt_2;
  vector[N_13] r_13_b3tt_3;
  vector[N_13] r_13_b3tt_4;
  vector[N_13] r_13_b3tt_5;
  vector[N_13] r_13_b3tt_6;
  vector[N_14] r_14_b3tt_1;  // actual group-level effects
  matrix[N_15, M_15] r_15;  // actual group-level effects
  // using vectors speeds up indexing in loops
  vector[N_15] r_15_b4tt_1;
  vector[N_15] r_15_b4tt_2;
  vector[N_15] r_15_b4tt_3;
  vector[N_15] r_15_b4tt_4;
  vector[N_15] r_15_b4tt_5;
  vector[N_15] r_15_b4tt_6;
  vector[N_16] r_16_b4tt_1;  // actual group-level effects
  // compute actual group-level effects
  r_1 = scale_r_cor(z_1, sd_1, L_1);
  r_1_btc_1 = r_1[, 1];
  r_1_btc_2 = r_1[, 2];
  r_1_btc_3 = r_1[, 3];
  r_1_btc_4 = r_1[, 4];
  r_1_btc_5 = r_1[, 5];
  r_1_btc_6 = r_1[, 6];
  r_2_btc_1 = (sd_2[1] * (z_2[1]));
  // compute actual group-level effects
  r_3 = scale_r_cor(z_3, sd_3, L_3);
  r_3_a1_1 = r_3[, 1];
  r_3_a1_2 = r_3[, 2];
  r_3_a1_3 = r_3[, 3];
  r_3_a1_4 = r_3[, 4];
  r_3_a1_5 = r_3[, 5];
  r_3_a1_6 = r_3[, 6];
  r_4_a1_1 = (sd_4[1] * (z_4[1]));
  // compute actual group-level effects
  r_5 = scale_r_cor(z_5, sd_5, L_5);
  r_5_a2_1 = r_5[, 1];
  r_5_a2_2 = r_5[, 2];
  r_5_a2_3 = r_5[, 3];
  r_5_a2_4 = r_5[, 4];
  r_5_a2_5 = r_5[, 5];
  r_5_a2_6 = r_5[, 6];
  r_6_a2_1 = (sd_6[1] * (z_6[1]));
  // compute actual group-level effects
  r_7 = scale_r_cor(z_7, sd_7, L_7);
  r_7_a3_1 = r_7[, 1];
  r_7_a3_2 = r_7[, 2];
  r_7_a3_3 = r_7[, 3];
  r_7_a3_4 = r_7[, 4];
  r_7_a3_5 = r_7[, 5];
  r_7_a3_6 = r_7[, 6];
  r_8_a3_1 = (sd_8[1] * (z_8[1]));
  // compute actual group-level effects
  r_9 = scale_r_cor(z_9, sd_9, L_9);
  r_9_b1tt_1 = r_9[, 1];
  r_9_b1tt_2 = r_9[, 2];
  r_9_b1tt_3 = r_9[, 3];
  r_9_b1tt_4 = r_9[, 4];
  r_9_b1tt_5 = r_9[, 5];
  r_9_b1tt_6 = r_9[, 6];
  r_10_b1tt_1 = (sd_10[1] * (z_10[1]));
  // compute actual group-level effects
  r_11 = scale_r_cor(z_11, sd_11, L_11);
  r_11_b2tt_1 = r_11[, 1];
  r_11_b2tt_2 = r_11[, 2];
  r_11_b2tt_3 = r_11[, 3];
  r_11_b2tt_4 = r_11[, 4];
  r_11_b2tt_5 = r_11[, 5];
  r_11_b2tt_6 = r_11[, 6];
  r_12_b2tt_1 = (sd_12[1] * (z_12[1]));
  // compute actual group-level effects
  r_13 = scale_r_cor(z_13, sd_13, L_13);
  r_13_b3tt_1 = r_13[, 1];
  r_13_b3tt_2 = r_13[, 2];
  r_13_b3tt_3 = r_13[, 3];
  r_13_b3tt_4 = r_13[, 4];
  r_13_b3tt_5 = r_13[, 5];
  r_13_b3tt_6 = r_13[, 6];
  r_14_b3tt_1 = (sd_14[1] * (z_14[1]));
  // compute actual group-level effects
  r_15 = scale_r_cor(z_15, sd_15, L_15);
  r_15_b4tt_1 = r_15[, 1];
  r_15_b4tt_2 = r_15[, 2];
  r_15_b4tt_3 = r_15[, 3];
  r_15_b4tt_4 = r_15[, 4];
  r_15_b4tt_5 = r_15[, 5];
  r_15_b4tt_6 = r_15[, 6];
  r_16_b4tt_1 = (sd_16[1] * (z_16[1]));
}
model {
  // likelihood including constants
  target += reduce_sum(partial_log_lik_lpmf, seq, grainsize, ncat, Y, X_btc, b_btc, X_a1, b_a1, X_a2, b_a2, X_a3, b_a3, X_b1tt, b_b1tt, X_b2tt, b_b2tt, X_b3tt, b_b3tt, X_b4tt, b_b4tt, C_mu0_1, C_mu0_2, C_mu1_1, C_mu2_1, C_mu3_1, J_1, Z_1_btc_1, Z_1_btc_2, Z_1_btc_3, Z_1_btc_4, Z_1_btc_5, Z_1_btc_6, r_1_btc_1, r_1_btc_2, r_1_btc_3, r_1_btc_4, r_1_btc_5, r_1_btc_6, J_2, Z_2_btc_1, r_2_btc_1, J_3, Z_3_a1_1, Z_3_a1_2, Z_3_a1_3, Z_3_a1_4, Z_3_a1_5, Z_3_a1_6, r_3_a1_1, r_3_a1_2, r_3_a1_3, r_3_a1_4, r_3_a1_5, r_3_a1_6, J_4, Z_4_a1_1, r_4_a1_1, J_5, Z_5_a2_1, Z_5_a2_2, Z_5_a2_3, Z_5_a2_4, Z_5_a2_5, Z_5_a2_6, r_5_a2_1, r_5_a2_2, r_5_a2_3, r_5_a2_4, r_5_a2_5, r_5_a2_6, J_6, Z_6_a2_1, r_6_a2_1, J_7, Z_7_a3_1, Z_7_a3_2, Z_7_a3_3, Z_7_a3_4, Z_7_a3_5, Z_7_a3_6, r_7_a3_1, r_7_a3_2, r_7_a3_3, r_7_a3_4, r_7_a3_5, r_7_a3_6, J_8, Z_8_a3_1, r_8_a3_1, J_9, Z_9_b1tt_1, Z_9_b1tt_2, Z_9_b1tt_3, Z_9_b1tt_4, Z_9_b1tt_5, Z_9_b1tt_6, r_9_b1tt_1, r_9_b1tt_2, r_9_b1tt_3, r_9_b1tt_4, r_9_b1tt_5, r_9_b1tt_6, J_10, Z_10_b1tt_1, r_10_b1tt_1, J_11, Z_11_b2tt_1, Z_11_b2tt_2, Z_11_b2tt_3, Z_11_b2tt_4, Z_11_b2tt_5, Z_11_b2tt_6, r_11_b2tt_1, r_11_b2tt_2, r_11_b2tt_3, r_11_b2tt_4, r_11_b2tt_5, r_11_b2tt_6, J_12, Z_12_b2tt_1, r_12_b2tt_1, J_13, Z_13_b3tt_1, Z_13_b3tt_2, Z_13_b3tt_3, Z_13_b3tt_4, Z_13_b3tt_5, Z_13_b3tt_6, r_13_b3tt_1, r_13_b3tt_2, r_13_b3tt_3, r_13_b3tt_4, r_13_b3tt_5, r_13_b3tt_6, J_14, Z_14_b3tt_1, r_14_b3tt_1, J_15, Z_15_b4tt_1, Z_15_b4tt_2, Z_15_b4tt_3, Z_15_b4tt_4, Z_15_b4tt_5, Z_15_b4tt_6, r_15_b4tt_1, r_15_b4tt_2, r_15_b4tt_3, r_15_b4tt_4, r_15_b4tt_5, r_15_b4tt_6, J_16, Z_16_b4tt_1, r_16_b4tt_1);
  // priors including constants
  target += normal_lpdf(b_btc | 0, 2.5);
  target += normal_lpdf(b_a1 | 0, 2.5);
  target += normal_lpdf(b_a2 | 0, 2.5);
  target += normal_lpdf(b_a3 | 0, 2.5);
  target += normal_lpdf(b_b1tt | 0, 2.5);
  target += normal_lpdf(b_b2tt | 0, 2.5);
  target += normal_lpdf(b_b3tt | 0, 2.5);
  target += normal_lpdf(b_b4tt | 0, 2.5);
  target += student_t_lpdf(sd_1 | 3, 0, 2.5)
    - 6 * student_t_lccdf(0 | 3, 0, 2.5);
  target += std_normal_lpdf(to_vector(z_1));
  target += lkj_corr_cholesky_lpdf(L_1 | 1);
  target += student_t_lpdf(sd_2 | 3, 0, 2.5)
    - 1 * student_t_lccdf(0 | 3, 0, 2.5);
  target += std_normal_lpdf(z_2[1]);
  target += student_t_lpdf(sd_3 | 3, 0, 2.5)
    - 6 * student_t_lccdf(0 | 3, 0, 2.5);
  target += std_normal_lpdf(to_vector(z_3));
  target += lkj_corr_cholesky_lpdf(L_3 | 1);
  target += student_t_lpdf(sd_4 | 3, 0, 2.5)
    - 1 * student_t_lccdf(0 | 3, 0, 2.5);
  target += std_normal_lpdf(z_4[1]);
  target += student_t_lpdf(sd_5 | 3, 0, 2.5)
    - 6 * student_t_lccdf(0 | 3, 0, 2.5);
  target += std_normal_lpdf(to_vector(z_5));
  target += lkj_corr_cholesky_lpdf(L_5 | 1);
  target += student_t_lpdf(sd_6 | 3, 0, 2.5)
    - 1 * student_t_lccdf(0 | 3, 0, 2.5);
  target += std_normal_lpdf(z_6[1]);
  target += student_t_lpdf(sd_7 | 3, 0, 2.5)
    - 6 * student_t_lccdf(0 | 3, 0, 2.5);
  target += std_normal_lpdf(to_vector(z_7));
  target += lkj_corr_cholesky_lpdf(L_7 | 1);
  target += student_t_lpdf(sd_8 | 3, 0, 2.5)
    - 1 * student_t_lccdf(0 | 3, 0, 2.5);
  target += std_normal_lpdf(z_8[1]);
  target += student_t_lpdf(sd_9 | 3, 0, 2.5)
    - 6 * student_t_lccdf(0 | 3, 0, 2.5);
  target += std_normal_lpdf(to_vector(z_9));
  target += lkj_corr_cholesky_lpdf(L_9 | 1);
  target += student_t_lpdf(sd_10 | 3, 0, 2.5)
    - 1 * student_t_lccdf(0 | 3, 0, 2.5);
  target += std_normal_lpdf(z_10[1]);
  target += student_t_lpdf(sd_11 | 3, 0, 2.5)
    - 6 * student_t_lccdf(0 | 3, 0, 2.5);
  target += std_normal_lpdf(to_vector(z_11));
  target += lkj_corr_cholesky_lpdf(L_11 | 1);
  target += student_t_lpdf(sd_12 | 3, 0, 2.5)
    - 1 * student_t_lccdf(0 | 3, 0, 2.5);
  target += std_normal_lpdf(z_12[1]);
  target += student_t_lpdf(sd_13 | 3, 0, 2.5)
    - 6 * student_t_lccdf(0 | 3, 0, 2.5);
  target += std_normal_lpdf(to_vector(z_13));
  target += lkj_corr_cholesky_lpdf(L_13 | 1);
  target += student_t_lpdf(sd_14 | 3, 0, 2.5)
    - 1 * student_t_lccdf(0 | 3, 0, 2.5);
  target += std_normal_lpdf(z_14[1]);
  target += student_t_lpdf(sd_15 | 3, 0, 2.5)
    - 6 * student_t_lccdf(0 | 3, 0, 2.5);
  target += std_normal_lpdf(to_vector(z_15));
  target += lkj_corr_cholesky_lpdf(L_15 | 1);
  target += student_t_lpdf(sd_16 | 3, 0, 2.5)
    - 1 * student_t_lccdf(0 | 3, 0, 2.5);
  target += std_normal_lpdf(z_16[1]);
}
generated quantities {
  // compute group-level correlations
  corr_matrix[M_1] Cor_1 = multiply_lower_tri_self_transpose(L_1);
  vector<lower=-1,upper=1>[NC_1] cor_1;
  // compute group-level correlations
  corr_matrix[M_3] Cor_3 = multiply_lower_tri_self_transpose(L_3);
  vector<lower=-1,upper=1>[NC_3] cor_3;
  // compute group-level correlations
  corr_matrix[M_5] Cor_5 = multiply_lower_tri_self_transpose(L_5);
  vector<lower=-1,upper=1>[NC_5] cor_5;
  // compute group-level correlations
  corr_matrix[M_7] Cor_7 = multiply_lower_tri_self_transpose(L_7);
  vector<lower=-1,upper=1>[NC_7] cor_7;
  // compute group-level correlations
  corr_matrix[M_9] Cor_9 = multiply_lower_tri_self_transpose(L_9);
  vector<lower=-1,upper=1>[NC_9] cor_9;
  // compute group-level correlations
  corr_matrix[M_11] Cor_11 = multiply_lower_tri_self_transpose(L_11);
  vector<lower=-1,upper=1>[NC_11] cor_11;
  // compute group-level correlations
  corr_matrix[M_13] Cor_13 = multiply_lower_tri_self_transpose(L_13);
  vector<lower=-1,upper=1>[NC_13] cor_13;
  // compute group-level correlations
  corr_matrix[M_15] Cor_15 = multiply_lower_tri_self_transpose(L_15);
  vector<lower=-1,upper=1>[NC_15] cor_15;
  // extract upper diagonal of correlation matrix
  for (k in 1:M_1) {
    for (j in 1:(k - 1)) {
      cor_1[choose(k - 1, 2) + j] = Cor_1[j, k];
    }
  }
  // extract upper diagonal of correlation matrix
  for (k in 1:M_3) {
    for (j in 1:(k - 1)) {
      cor_3[choose(k - 1, 2) + j] = Cor_3[j, k];
    }
  }
  // extract upper diagonal of correlation matrix
  for (k in 1:M_5) {
    for (j in 1:(k - 1)) {
      cor_5[choose(k - 1, 2) + j] = Cor_5[j, k];
    }
  }
  // extract upper diagonal of correlation matrix
  for (k in 1:M_7) {
    for (j in 1:(k - 1)) {
      cor_7[choose(k - 1, 2) + j] = Cor_7[j, k];
    }
  }
  // extract upper diagonal of correlation matrix
  for (k in 1:M_9) {
    for (j in 1:(k - 1)) {
      cor_9[choose(k - 1, 2) + j] = Cor_9[j, k];
    }
  }
  // extract upper diagonal of correlation matrix
  for (k in 1:M_11) {
    for (j in 1:(k - 1)) {
      cor_11[choose(k - 1, 2) + j] = Cor_11[j, k];
    }
  }
  // extract upper diagonal of correlation matrix
  for (k in 1:M_13) {
    for (j in 1:(k - 1)) {
      cor_13[choose(k - 1, 2) + j] = Cor_13[j, k];
    }
  }
  // extract upper diagonal of correlation matrix
  for (k in 1:M_15) {
    for (j in 1:(k - 1)) {
      cor_15[choose(k - 1, 2) + j] = Cor_15[j, k];
    }
  }
}

Quick update: After finishing writing the post, the model finished 100 iterations on 1 chain (at 18 hours 1 minute).

Some quick pointers here. If you don't need the normalising constants for things like Bayes factor comparisons, you can change your distribution statements to the unnormalised _lupdf/_lupmf versions, which saves some computation. For the same reason, you can drop the student_t_lccdf correction terms on the sd priors.
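
For example, the prior statements in the posted model block could be switched over like this (a sketch of just the first few statements; the same pattern applies to the rest, and categorical_logit_lpmf in the likelihood becomes categorical_logit_lupmf):

// priors without normalising constants (sketch of the changed statements only)
target += normal_lupdf(b_btc | 0, 2.5);
target += student_t_lupdf(sd_1 | 3, 0, 2.5);  // the "- 6 * student_t_lccdf(...)" term is dropped
target += std_normal_lupdf(to_vector(z_1));
target += lkj_corr_cholesky_lupdf(L_1 | 1);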

You can also optimise the partial_log_lik function a little by using a single loop over N rather than several separate loops, as sketched below. There's also some further optimisation available in the construction of the non-centered random effects.
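
Roughly like this (a sketch using the variable names from the posted model; the elided terms are identical to the original loop bodies):

for (n in 1:N) {
  int nn = n + start - 1;
  nlp_btc[n] += exp(r_1_btc_1[J_1[nn]] * Z_1_btc_1[nn]
                    // ... remaining r_1_btc_* terms as in the original loop ...
                    + r_2_btc_1[J_2[nn]] * Z_2_btc_1[nn]);
  nlp_a1[n] += r_3_a1_1[J_3[nn]] * Z_3_a1_1[nn]
               // ... remaining r_3_a1_* terms as in the original loop ...
               + r_4_a1_1[J_4[nn]] * Z_4_a1_1[nn];
  // ... analogous updates for nlp_a2, nlp_a3, nlp_b1tt, nlp_b2tt, nlp_b3tt, nlp_b4tt ...
  mu[n] = [nlp_btc[n] * (nlp_a1[n] + nlp_b1tt[n] * C_mu0_1[nn] + C_mu0_2[nn]),
           nlp_btc[n] * (nlp_a2[n] + nlp_b2tt[n] * C_mu1_1[nn]),
           nlp_btc[n] * (nlp_a3[n] + nlp_b3tt[n] * C_mu2_1[nn]),
           nlp_btc[n] * (nlp_b4tt[n] * C_mu3_1[nn])]';
  ptarget += categorical_logit_lupmf(Y[nn] | mu[n]);
}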

This unfortunately looks like one of those situations where the brms-generated code isn’t going to be efficient enough for this kind of model complexity and dataset size.

I’d strongly recommend first simplifying and optimising the Stan code without the parallelisation (using a small subset of data for testing). If you have access to a discrete GPU, then you could likely see even better acceleration than using parallelisation in this case.

If you update your syntax to:

  • Primarily work with large matrix operations, rather than looping over individual vectors
  • Use the vectorised categorical_logit_glm likelihood
  • Pass the new stanc O1 optimisation flag when compiling the model

Then you have a likelihood with several components that can be GPU-accelerated, which will generally give a bigger speedup than the reduce_sum/threading approach. A minimal sketch of what the vectorised likelihood looks like follows.
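
For reference (hypothetical names, and only for the simpler case where every alternative shares one design matrix; the alternative-specific time/cost attributes in your model need some rearranging to fit this form):

data {
  int<lower=1> N;
  int<lower=2> ncat;
  int<lower=1> K;
  matrix[N, K] X;                        // shared design matrix
  array[N] int<lower=1, upper=ncat> Y;
}
parameters {
  vector[ncat] alpha;                    // alternative-specific intercepts
  matrix[K, ncat] beta;                  // one coefficient column per alternative
}
model {
  // in practice, fix one alternative as the reference for identification
  alpha ~ normal(0, 2.5);
  to_vector(beta) ~ normal(0, 2.5);
  // single vectorised call over all N observations; this GLM lpmf is
  // OpenCL-accelerated when the model is built with stan_opencl
  target += categorical_logit_glm_lupmf(Y | X, alpha, beta);
}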


Thanks. That’s very helpful. I’ll take a look and see what I can do to optimize it. I do have GPU access.

Thanks again to @andrjohns for the great feedback. I refactored my code to use matrices/vectors where possible, and I now use categorical_logit_glm_lupmf. It's not the prettiest code, but it runs on my local machine and gives expected outputs. The trouble I'm facing now is that the GPU implementation on our HPC seems to be much slower than running locally. My R/Stan code is provided below for 100 data points. Locally, it runs in 10-15 minutes; on the HPC, it times out after an hour without finishing 100 iterations. Reviewing the earlier responses, I haven't yet switched the priors to the _lupdf versions, but I did use the unnormalised _lupmf for the likelihood. There's probably a more elegant way to set up the likelihood and the group-level effects r_1 etc., but the current version let me build up the pieces. The multiplication by b[n,1] is because I'm doing something slightly different from a standard model setup.

The run script is as follows:

#!/bin/bash
#SBATCH --nodes=1
#SBATCH --ntasks=1 
#SBATCH --partition=gpu
#SBATCH --mem=32000
#SBATCH --gres=gpu
#SBATCH --constraint=gpu_32gb
#SBATCH --time=1:00:00
#SBATCH --job-name=R_mtc_base
#SBATCH --error=R_mtc_gpu.%J.err
#SBATCH --output=R_mtc_gpu.%J.out

module load R/4.1
module load cuda/11.4

Rscript mtc_equity_2.R

R script:

library(cmdstanr)
library(dplyr)   # for %>% and select()

model.1 = cmdstan_model("mtc_equity_2.stan", cpp_options = list(stan_opencl = TRUE), stanc_options = list("O1"))

samp_ct = 100L
mod = mod[1:samp_ct,]

data_list = list(N = samp_ct,
            ncat = 4L,
            nvar = 2L,
            ntotvar = 5L,
            Y = as.integer(mod$trptrans),
            x1 = mod %>% select(travelTimeDrive, travelCostDrive),
            x2 = cbind(mod$travelTimeWalk, rep(0, samp_ct)),
            x3 = cbind(mod$travelTimeBike, rep(0, samp_ct)),
            x4 = cbind(mod$travelTimeTransit, rep(-1, samp_ct)),
            N_1 = 62L,
            M_1 = 6L,
            J_1 = as.integer(mod$personid),
            Z_1 = cbind(rep(1, samp_ct), mod %>% select(race_2, race_3, race_4, race_5, hhfaminc)),
            NC_1 = 15L,
            N_2 = 5L,
            M_2 = 1L,
            J_2 = as.integer(mod$trippurp),
            prior_only = 0L)

model_fit = model.1$sample(data = data_list,
                           seed = 24567,
                           iter_warmup = 500,
                           iter_sampling = 500,
                           chains = 4,
                           parallel_chains = 4)

Stan code:

functions {
 /* compute correlated group-level effects
  * Args: 
  *   z: matrix of unscaled group-level effects
  *   SD: vector of standard deviation parameters
  *   L: cholesky factor correlation matrix
  * Returns: 
  *   matrix of scaled group-level effects
  */ 
  matrix scale_r_cor(matrix z, vector SD, matrix L) {
    // r is stored in another dimension order than z
    return transpose(diag_pre_multiply(SD, L) * z);
  }
}
data {
  int<lower=1> N;  // total number of observations
  int<lower=2> ncat;  // number of categories
  int<lower=2> nvar;  // number of variables for each alternative
  int<lower=2> ntotvar;  // number of total variables
  array[N] int Y;  // response variable
  // covariate matrix for alt1
  matrix[N,nvar] x1;
  // covariate matrix for alt2
  matrix[N,nvar] x2;
  // covariate matrix for alt3
  matrix[N,nvar] x3;
  // covariate matrix for alt4
  matrix[N,nvar] x4;
  // data for group-level effects of ID 1
  int<lower=1> N_1;  // number of grouping levels
  int<lower=1> M_1;  // number of coefficients per level
  int<lower=1> J_1[N];  // grouping indicator per observation
  // group-level predictor values
  matrix[N, M_1] Z_1;
  int<lower=1> NC_1;  // number of group-level correlations
  // data for group-level effects of ID 2
  int<lower=1> N_2;  // number of grouping levels
  int<lower=1> M_2;  // number of coefficients per level
  int<lower=1> J_2[N];  // grouping indicator per observation
  int prior_only;  // should the likelihood be ignored?
}
transformed data {
  int<lower=1> M_3 = (M_1+M_2)*(ncat*2);  // total number of sd
}
parameters {
  real b_btc;  // population-level effects
  real b_a2;  // population-level effects
  real b_a3;  // population-level effects
  real b_a4;  // population-level effects
  real b_b1tt;  // population-level effects
  real b_b2tt;  // population-level effects
  real b_b3tt;  // population-level effects
  real b_b4tt;  // population-level effects
  vector<lower=0>[M_3] sd;  // group-level standard deviations
  matrix[M_1, N_1] z_1;  // standardized group-level effects
  cholesky_factor_corr[M_1] L_1;  // cholesky factor of correlation matrix
  vector[N_2] z_2[M_2];  // standardized group-level effects
  matrix[M_1, N_1] z_3;  // standardized group-level effects
  cholesky_factor_corr[M_1] L_3;  // cholesky factor of correlation matrix
  vector[N_2] z_4[M_2];  // standardized group-level effects
  matrix[M_1, N_1] z_5;  // standardized group-level effects
  cholesky_factor_corr[M_1] L_5;  // cholesky factor of correlation matrix
  vector[N_2] z_6[M_2];  // standardized group-level effects
  matrix[M_1, N_1] z_7;  // standardized group-level effects
  cholesky_factor_corr[M_1] L_7;  // cholesky factor of correlation matrix
  vector[N_2] z_8[M_2];  // standardized group-level effects
  matrix[M_1, N_1] z_9;  // standardized group-level effects
  cholesky_factor_corr[M_1] L_9;  // cholesky factor of correlation matrix
  vector[N_2] z_10[M_2];  // standardized group-level effects
  matrix[M_1, N_1] z_11;  // standardized group-level effects
  cholesky_factor_corr[M_1] L_11;  // cholesky factor of correlation matrix
  vector[N_2] z_12[M_2];  // standardized group-level effects
  matrix[M_1, N_1] z_13;  // standardized group-level effects
  cholesky_factor_corr[M_1] L_13;  // cholesky factor of correlation matrix
  vector[N_2] z_14[M_2];  // standardized group-level effects
  matrix[M_1, N_1] z_15;  // standardized group-level effects
  cholesky_factor_corr[M_1] L_15;  // cholesky factor of correlation matrix
  vector[N_2] z_16[M_2];  // standardized group-level effects
}
transformed parameters {
  matrix[N_1, M_1] r_1;  // actual group-level effects
  vector[N_2] r_2;  // actual group-level effects
  matrix[N_1, M_1] r_3;  // actual group-level effects
  vector[N_2] r_4;  // actual group-level effects
  matrix[N_1, M_1] r_5;  // actual group-level effects
  vector[N_2] r_6;  // actual group-level effects
  matrix[N_1, M_1] r_7;  // actual group-level effects
  vector[N_2] r_8;  // actual group-level effects
  matrix[N_1, M_1] r_9;  // actual group-level effects
  vector[N_2] r_10;  // actual group-level effects
  matrix[N_1, M_1] r_11;  // actual group-level effects
  vector[N_2] r_12;  // actual group-level effects
  matrix[N_1, M_1] r_13;  // actual group-level effects
  vector[N_2] r_14;  // actual group-level effects
  matrix[N_1, M_1] r_15;  // actual group-level effects
  vector[N_2] r_16;  // actual group-level effects
  
  // compute actual group-level effects
  r_1 = scale_r_cor(z_1, sd[1:M_1], L_1);
  r_2 = (sd[M_1+1] * (z_2[1]));
  // compute actual group-level effects
  r_3 = scale_r_cor(z_3, sd[M_1+2:2*M_1+1], L_3);
  r_4 = (sd[2*M_1+2] * (z_4[1]));
  // compute actual group-level effects
  r_5 = scale_r_cor(z_5, sd[2*M_1+3:3*M_1+2], L_5);
  r_6 = (sd[3*M_1+3] * (z_6[1]));
  // compute actual group-level effects
  r_7 = scale_r_cor(z_7, sd[3*M_1+4:4*M_1+3], L_7);
  r_8 = (sd[4*M_1+4] * (z_8[1]));
  // compute actual group-level effects
  r_9 = scale_r_cor(z_9, sd[4*M_1+5:5*M_1+4], L_9);
  r_10 = (sd[5*M_1+5] * (z_10[1]));
  // compute actual group-level effects
  r_11 = scale_r_cor(z_11, sd[5*M_1+6:6*M_1+5], L_11);
  r_12 = (sd[6*M_1+6] * (z_12[1]));
  // compute actual group-level effects
  r_13 = scale_r_cor(z_13, sd[6*M_1+7:7*M_1+6], L_13);
  r_14 = (sd[7*M_1+7] * (z_14[1]));
  // compute actual group-level effects
  r_15 = scale_r_cor(z_15, sd[7*M_1+8:8*M_1+7], L_15);
  r_16 = (sd[8*M_1+8] * (z_16[1]));
}
model {
  // likelihood including constants
  if (!prior_only) {
    // Define matrices/vector for x, alpha, beta
    matrix[ncat,nvar] x;
    matrix[ncat,nvar] beta;
    vector[ncat] alpha;
    matrix[N,ncat] a = rep_matrix(0,N,ncat);
    a[,2] = rep_vector(b_a2,N);
    a[,3] = rep_vector(b_a3,N);
    a[,4] = rep_vector(b_a4,N);
      
    // initialize linear predictor term
    // Terms are btc, b1tt, b2tt, b3tt, b4tt
    matrix[N,ntotvar] b = rep_matrix(b_btc,N,ntotvar);
    b[,2] = rep_vector(b_b1tt,N);
    b[,3] = rep_vector(b_b2tt,N);
    b[,4] = rep_vector(b_b3tt,N);
    b[,5] = rep_vector(b_b4tt,N);
    
    for (n in 1:N) {
      // add to linear predictor term
      a[n] += [0,r_3[J_1[n]]*Z_1[n]' + r_4[J_2[n]], r_5[J_1[n]]*Z_1[n]' + r_6[J_2[n]], r_7[J_1[n]]*Z_1[n]' + r_8[J_2[n]]];
      // add to linear predictor term
      // Terms are btc, b1tt, b2tt, b3tt, b4tt
      b[n] += [r_1[J_1[n]]*Z_1[n]' + r_2[J_2[n]], r_9[J_1[n]]*Z_1[n]' + r_10[J_2[n]], r_11[J_1[n]]*Z_1[n]' + r_12[J_2[n]],r_13[J_1[n]]*Z_1[n]' + r_14[J_2[n]],r_15[J_1[n]]*Z_1[n]' + r_16[J_2[n]]];
      b[n] = exp(b[n]);
      // Each x and beta is a matrix with dimensions alts x variables
      // Our x will be the time/cost coming in as inputs
      x[1,] = x1[n];
      x[2,] = x2[n];
      x[3,] = x3[n];
      x[4,] = x4[n];
      // Our betas will be the hierarchical slope parameters
      beta[1,] = [b[n,1] * b[n,2], b[n,1]];
      beta[2,] = [b[n,1] * b[n,3], b[n,1]];
      beta[3,] = [b[n,1] * b[n,4], b[n,1]];
      beta[4,] = [b[n,1] * b[n,5], b[n,1]];
      // Our alphas will be the hierarchical intercept parameters
      alpha = [a[n,1], b[n,1] * a[n,2], b[n,1] * a[n,3], b[n,1] * a[n,4]]';
      target += categorical_logit_glm_lupmf(Y[n] | x, alpha, beta');
    }
  }
  // priors including constants
  target += normal_lpdf(b_btc | 0, 2.5);
  target += normal_lpdf(b_a2 | 0, 2.5);
  target += normal_lpdf(b_a3 | 0, 2.5);
  target += normal_lpdf(b_a4 | 0, 2.5);
  target += normal_lpdf(b_b1tt | 0, 2.5);
  target += normal_lpdf(b_b2tt | 0, 2.5);
  target += normal_lpdf(b_b3tt | 0, 2.5);
  target += normal_lpdf(b_b4tt | 0, 2.5);
  target += student_t_lpdf(sd | 3, 0, 2.5);
  target += std_normal_lpdf(to_vector(z_1));
  target += lkj_corr_cholesky_lpdf(L_1 | 1);
  target += std_normal_lpdf(z_2[1]);
  target += std_normal_lpdf(to_vector(z_3));
  target += lkj_corr_cholesky_lpdf(L_3 | 1);
  target += std_normal_lpdf(z_4[1]);
  target += std_normal_lpdf(to_vector(z_5));
  target += lkj_corr_cholesky_lpdf(L_5 | 1);
  target += std_normal_lpdf(z_6[1]);
  target += std_normal_lpdf(to_vector(z_7));
  target += lkj_corr_cholesky_lpdf(L_7 | 1);
  target += std_normal_lpdf(z_8[1]);
  target += std_normal_lpdf(to_vector(z_9));
  target += lkj_corr_cholesky_lpdf(L_9 | 1);
  target += std_normal_lpdf(z_10[1]);
  target += std_normal_lpdf(to_vector(z_11));
  target += lkj_corr_cholesky_lpdf(L_11 | 1);
  target += std_normal_lpdf(z_12[1]);
  target += std_normal_lpdf(to_vector(z_13));
  target += lkj_corr_cholesky_lpdf(L_13 | 1);
  target += std_normal_lpdf(z_14[1]);
  target += std_normal_lpdf(to_vector(z_15));
  target += lkj_corr_cholesky_lpdf(L_15 | 1);
  target += std_normal_lpdf(z_16[1]);
}

I ran the test GPU Stan model given here.

Oddly, it runs much slower on the GPU than on the CPU (and compared with the time reported in the vignette). The runtime (for 1/10 of the data used in the vignette) is 748.3 seconds, whereas on the CPU it takes roughly a tenth of that. clinfo gives me back the right result for using opencl_ids = c(0, 0). I've asked my HPC admin about this problem too, but I thought someone here might be able to help. Even once things are set up properly for the test model, I'll probably still need to do some further optimization on my real model to improve its runtime on the GPU.

Platform #0: NVIDIA CUDA
 `-- Device #0: Tesla K20m