Brms sampling in multi-level multinomial models

Hi all,

I am running a multi-level multinomial model, but I am having trouble fitting it efficiently with brms. I am aware that fitting Bayesian models can be computationally heavy and takes time, but my models have been running for more than 30 days without finishing, so I am afraid I have overlooked something important.

The model
I want to model the mapping between seven decorrelated continuous variables and seven categories. Moreover, I want to describe how this mapping differs across demographic variables (such as sex, country, or language). I am using the following intercept-free model:

model <- brm(
  category ~ 0 + RC1 + RC2 + RC3 + RC4 + RC5 + RC6 + RC7 + (0 + RC1 + RC2 + RC3 + RC4 + RC5 + RC6 + RC7 | sex + country:language + speaker),
  data = dat_list,
  family = categorical(link = "logit"),
  prior = set_prior("normal(0,1)", class = "b"),
  control = list(adapt_delta = 0.999, max_treedepth = 15)
)

The full dataset consists of 52k samples. Most grouping factors contain only a few levels, but one factor (speaker) contains 432 levels.

Observations
Somehow the model does not scale well:

  • The full model as specified above runs in reasonable time and without warnings on a subset of the data
  • A model with population-level slopes but only group-level intercepts also runs in reasonable time and without warnings on the full dataset (RC1 + RC2 + RC3 + RC4 + RC5 + RC6 + RC7 + (1 | sex) + (1 | country:language) + (1 | speaker))
  • A model with only population-level effects runs very fast on the full dataset (RC1 + RC2 + RC3 + RC4 + RC5 + RC6 + RC7)
  • But sampling the full model on the full dataset does not work well (it often does not even reach the 10% progress marker)

I played around with a divide-and-conquer strategy (taking only the population-level effects into account; I came across the idea here): splitting the data into equal bins, running smaller models on the subgroups, and later averaging over the posteriors. The resulting estimates were often off and the estimated distributions were too wide. I am also not sure such a strategy is advisable for a hierarchical model, which pools across groups using all the data rather than subsets, so I rejected this approach.
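For reference, the sketch I tried looked roughly like this (the number of bins and the naive pooling of draws at the end are my own simplifications):

library(brms)
library(posterior)

# Split the data into equal-sized random bins (no stratification)
bins <- split(dat_list, sample(rep(1:8, length.out = nrow(dat_list))))

# Fit the population-level model on each bin separately
fits <- lapply(bins, function(d)
  brm(category ~ 0 + RC1 + RC2 + RC3 + RC4 + RC5 + RC6 + RC7,
      data = d, family = categorical(link = "logit"),
      prior = set_prior("normal(0, 1)", class = "b")))

# Naively pool the posteriors by stacking the draws from all sub-fits
pooled <- do.call(rbind, lapply(fits, function(f) as.data.frame(as_draws_df(f))))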

A simple way to speed up the sampling is to run on more CPUs at once (I moved from 4 to 32 CPUs), use one chain per CPU, and draw fewer posterior samples per chain.
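In brm() terms, that looks roughly like this, with formula standing in for the full model formula above (the iteration counts are just the values I experimented with):

model <- brm(
  formula,
  data = dat_list,
  family = categorical(link = "logit"),
  prior = set_prior("normal(0, 1)", class = "b"),
  chains = 32, cores = 32,    # one chain per CPU
  iter = 1250, warmup = 1000  # 250 post-warmup draws per chain, 8000 in total
)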

Another way to increase sampling speed is to experiment with the number of warmup iterations, keeping this phase as short as possible while still obtaining reliable estimates. However, I noticed that with this many chains, the individual chains took very different amounts of time to complete, which shouldn't really happen, right?

Finally, using fewer variables would probably speed up the model too, but this conflicts with my scientific interest.

Work in progress
Currently, I am looking into improving the grain size and investigating whether it is possible to profile the execution of the model, to find out what is going wrong.
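In brms the grain size can be set through threading() when using the cmdstanr backend; a minimal sketch, with thread count and grain size as placeholder values I am still tuning:

model <- brm(
  formula,
  data = dat_list,
  family = categorical(link = "logit"),
  backend = "cmdstanr",
  # within-chain parallelization via reduce_sum; placeholder values
  threads = threading(threads = 8, grainsize = 100)
)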

Open questions
It might be wise to specify better priors, but I have a hard time coming up with good ones. I picked a standard normal prior because the variables are centered around 0 and roughly fall within the range of a standard normal across all categories. I would be interested in any prior you could suggest with better sampling properties.
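For reference, this is the kind of prior specification I have in mind; the exponential(1) on the group-level standard deviations is only an example of a possibly tighter alternative, not something I have tested:

priors <- c(
  set_prior("normal(0, 1)", class = "b"),    # population-level slopes (current choice)
  set_prior("exponential(1)", class = "sd")  # example: tighter prior on group-level SDs
)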

Finally, I would be interested in any suggestions for improving sampling speed, and whether I have overlooked other possibilities.

Thank you for taking the time to read this topic!

Hi!

Are you able to share the Stan code generated by brms? You may be able to get some help optimizing the code, which will in turn speed up sampling.
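You can extract the generated program without refitting, for example:

# Print the Stan program brms generates, without fitting the model
make_stancode(formula, data = dat_list, family = categorical(link = "logit"))

# Or, for an already fitted model object:
stancode(model)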

I received some fantastic help in this thread that took days off of sampling time.

@polvanrijn follow the suggestions in the thread that @JLC linked, and also (a combined example follows the list):

  1. Vectorize every for loop in your code
  2. You can apply within-chain parallelization in brms. See this vignette
  3. Tell brms to perform a QR decomposition of the design matrix (this helps to speed things up in hierarchical models): bf(formula, family = XX, decomp = "QR")
  4. Try using CmdStanR instead of RStan (for me it sometimes gives a ~3x speedup): brm(formula, data, backend = "cmdstanr").
  5. You can tell brms to use the unnormalized log probability functions (*_lupdf and *_lupmf) instead of the normalized ones (*_lpdf and *_lpmf). This is available with the latest cmdstanr and the GitHub version of brms: brm(..., backend = "cmdstanr", normalize = FALSE). See this issue and this pull request.
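Putting suggestions 2 to 5 together, the call could look roughly like this (the argument values are illustrative):

library(brms)

model <- brm(
  bf(formula, decomp = "QR"),   # 3: QR decomposition of the design matrix
  data = dat,
  family = categorical(link = "logit"),
  backend = "cmdstanr",         # 4: CmdStanR backend
  threads = threading(4),       # 2: within-chain parallelization via reduce_sum
  normalize = FALSE             # 5: skip normalizing constants (*_lupdf / *_lupmf)
)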

Those five suggestions may speed things up. Also, if you can, please share the code that brms generated, so people can help with further speedups.
