Mixed/hierarchical logit code optimization

I have a mixed logit model (also called a categorical logit or a hierarchical model with softmax, among other names in different fields) with a very long runtime. After some testing, I think it's time to ask for help. I began with synthetic data on my local machine, and everything looked fine. My real dataset has 178,000 records. With 100 data points and 500/500 warmup/sampling iterations per chain, the model runs but gives me:

Warning: 1212 of 2000 (61.0%) transitions ended with a divergence.
This may indicate insufficient exploration of the posterior distribution.
Possible remedies include: 
  * Increasing adapt_delta closer to 1 (default is 0.8) 
  * Reparameterizing the model (e.g. using a non-centered parameterization)
  * Using informative or weakly informative prior distributions 

788 of 2000 (39.0%) transitions hit the maximum treedepth limit of 10 or 2^10-1 leapfrog steps.
Trajectories that are prematurely terminated due to this limit will result in slow exploration.
Increasing the max_treedepth limit can avoid this at the expense of more computation.
If increasing max_treedepth does not remove warnings, try to reparameterize the model.

I’m running the full model on an HPC with 16 cores and 36 GB RAM allocated, so it runs 4 parallel chains with 4 threads per chain. It has been running for 18 hours and hasn’t finished 100 iterations yet. Some nodes on the HPC have 36 cores. I could modify my priors slightly, but they’re already quite tight at N(0, 2.5). The runs with 100 data points also give me the following warning:

The current Metropolis proposal is about to be rejected because of the following issue:
Chain 4 Exception: Exception: categorical_logit_lpmf: log odds parameter[1] is -nan, but must be finite!

This disappears after the first 100 iterations, so I assumed it was ok. The softmax function involves exponentials, which will inflate some of the initial values. I’ve looked at individual outputs, and only a few observations get large initial values.
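To make the overflow concern concrete, here is a small base-R sketch. The utility vectors are made up for illustration and are not taken from my model. It compares a naive softmax with the usual log-sum-exp form. As far as I can tell from the wording of the message, the check is on the log odds vector passed into categorical_logit, so a NaN there would have to come from the linear predictor itself (for example an `Inf - Inf` or `0 * Inf` somewhere in `btc * (...)`), but that is my reading and not something I have confirmed.

```r
# Naive softmax: exp() overflows to Inf for large utilities, giving Inf / Inf = NaN.
softmax_naive <- function(u) exp(u) / sum(exp(u))

# Stable log-softmax: subtract the maximum before exponentiating (log-sum-exp trick).
log_softmax <- function(u) {
  m <- max(u)
  u - (m + log(sum(exp(u - m))))
}

u_ok  <- c(1.0, 0.5, -0.3, 0.0)  # hypothetical utilities of moderate size
u_big <- c(800, 1.0, -0.3, 0.0)  # hypothetical utilities with one extreme value

p_ok   <- softmax_naive(u_ok)       # finite probabilities that sum to 1
p_big  <- softmax_naive(u_big)      # first element is NaN because exp(800) is Inf
p_safe <- exp(log_softmax(u_big))   # finite, with essentially all mass on the first alternative

# Non-finite values can also arise before any softmax is applied:
bad_terms <- c(Inf - Inf, 0 * Inf)  # both are NaN
```

If the stable form stays finite for the same inputs, the exponentials alone would not explain the message, and the individual terms of the predictor are worth inspecting.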

I’ve tried VI using meanfield on a simplified model (removing (1|trippurp) and simplifying to (1 + hhfaminc|personid)), but that ran out of RAM with about 36 GB allocated. I gave it 8 cores but did not use parallel threads on the final test, to minimize RAM use.

I’ve written mixed logit models in Stan before (partially based on discussion here), but this is a much larger model/dataset. I’ve estimated other hierarchical models in Stan where the runtime with brms was much faster than using what I wrote, so I figured it would be a good starting point here, too.

I have the following script (using brms to generate the data inputs because I generate the initial Stan code using brms, modify it a bit, and run it with cmdstanr). The brms model is based on one here. I realize this is a lot of code to review. Most of it is variable definitions, and the model is also given in the R code via a brms specification. There are two hierarchies in the model: 1) individuals, for whom I include contextual effects, and 2) trip purpose, since an individual can have multiple trip purposes.
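Written out as equations, my reading of the brms specification is roughly the following. The symbols $\gamma$, $\mathbf{u}$, $v$ and $\mathbf{z}$ are notation I am introducing here for illustration; they are not names from the generated Stan code. The index $n$ runs over records, and $\theta$ stands for any of the eight nonlinear parameters (btc, a1, a2, a3, b1tt, b2tt, b3tt, b4tt).

```latex
\begin{aligned}
\mu_{0,n} &= \beta^{tc}_{n}\,\bigl(a_{1,n} + b^{tt}_{1,n}\,\text{travelTimeDrive}_n + \text{travelCostDrive}_n\bigr) \\
\mu_{1,n} &= \beta^{tc}_{n}\,\bigl(a_{2,n} + b^{tt}_{2,n}\,\text{travelTimeWalk}_n\bigr) \\
\mu_{2,n} &= \beta^{tc}_{n}\,\bigl(a_{3,n} + b^{tt}_{3,n}\,\text{travelTimeBike}_n\bigr) \\
\mu_{3,n} &= \beta^{tc}_{n}\,b^{tt}_{4,n}\,\text{travelTimeTransit}_n \\
\Pr(\text{trptrans}_n = k) &= \frac{\exp(\mu_{k,n})}{\sum_{j=0}^{3}\exp(\mu_{j,n})} \\
\theta_{n} &= \gamma_{\theta} + \mathbf{z}_n^{\top}\mathbf{u}_{\theta,\,\text{personid}[n]} + v_{\theta,\,\text{trippurp}[n]},
\qquad \mathbf{z}_n = (1,\ \text{race\_2}_n,\ \text{race\_3}_n,\ \text{race\_4}_n,\ \text{race\_5}_n,\ \text{hhfaminc}_n)
\end{aligned}
```

Because refcat = NA, no alternative is fixed as the reference, so the softmax runs over all four predictors. Every $\mu_k$ is a product of $\beta^{tc}$ and a second linear term, and both factors carry their own person-level and trip-purpose effects.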

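library(brms)
library(cmdstanr)

# Build the Stan data list from the brms specification without fitting (empty = TRUE)
model.1.data <-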
  standata(brm(data = mod,
      family = categorical(link = logit, refcat = NA),
      bf(trptrans ~ 1,
         nlf(mu0 ~ btc * (a1 + b1tt * travelTimeDrive + travelCostDrive)),
         nlf(mu1 ~ btc * (a2 + b2tt * travelTimeWalk)),
         nlf(mu2 ~ btc * (a3 + b3tt * travelTimeBike)),
         nlf(mu3 ~ btc * (b4tt * travelTimeTransit)),
         mvbind(btc + a1 + a2 + a3 + b1tt + b2tt + b3tt + b4tt) ~ 1 + (1 + race_2 + race_3 + race_4 + race_5 + hhfaminc|personid) + (1|trippurp)),
      prior = c(prior(normal(0, 2.5), class = b, nlpar = a1),
                prior(normal(0, 2.5), class = b, nlpar = a2),
                prior(normal(0, 2.5), class = b, nlpar = a3),
                prior(normal(0, 2.5), class = b, nlpar = b1tt),
                prior(normal(0, 2.5), class = b, nlpar = btc),
                prior(normal(0, 2.5), class = b, nlpar = b2tt),
                prior(normal(0, 2.5), class = b, nlpar = b3tt),
                prior(normal(0, 2.5), class = b, nlpar = b4tt)),
      empty = TRUE,
      backend = "cmdstanr",
      threads = threading(4)))

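# Compile the brms-generated (and hand-modified) Stan program with threading enabled
model.1 <- cmdstan_model("equity_brms.stan", cpp_options = list(stan_threads = TRUE))

# 4 chains x 4 threads per chain = 16 cores
model_fit <- model.1$sample(data = model.1.data,
                            seed = 24567,
                            iter_warmup = 1000,
                            iter_sampling = 1000,
                            chains = 4,
                            parallel_chains = 4,
                            threads_per_chain = 4)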
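
For reference, the remedies named in the warnings above correspond to arguments of cmdstanr's `$sample()` method. A minimal sketch of a follow-up run on the 100-point data is below. The specific values (0.95, 12, 0.5) and the name `model_fit_tuned` are placeholders I have not tested; they are not recommendations.

```r
# Sketch only: assumes model.1 and model.1.data from the script above.
model_fit_tuned <- model.1$sample(
  data = model.1.data,
  seed = 24567,
  iter_warmup = 500,
  iter_sampling = 500,
  chains = 4,
  parallel_chains = 4,
  threads_per_chain = 4,
  adapt_delta = 0.95,   # default 0.8; higher targets lead to smaller step sizes
  max_treedepth = 12,   # default 10; each extra level can double the leapfrog steps
  init = 0.5            # initial values from uniform(-0.5, 0.5) on the unconstrained scale
)

# Count divergences and max-treedepth hits per chain
model_fit_tuned$diagnostic_summary()
```

Raising adapt_delta and max_treedepth makes each iteration more expensive, so on the full 178,000-record dataset this would lengthen the runtime further unless the geometry is improved by reparameterizing.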