Priors in cmdstan not working properly (brms categorical model)

mydata.txt (193.9 KB)
mypriors.txt (592 Bytes)

I’ve been super impressed with the speedups offered by cmdstan, as well as the fact that it hasn’t crashed on me even once…

But unfortunately, there seems to be something wrong with the way priors work in cmdstan. Basically, they’re not shrinking the posterior means the way they should. My anonymized dataset (attached) has a quaternary categorical response y, and 28 predictors $x_1, \dots, x_{28}$. Most predictors have N(0, 2.5) priors (see mypriors.txt, attached). The fixed-effect parameters muB_x19B and muC_x24E have maximum-likelihood estimates of negative infinity, because predictors x19B and x24E have sampling zeroes for response categories B and C, respectively. Hence, a frequentist model “converges”, after a large number of iterations, to logits of -14.45 and -10.5 for those two parameters:

TestMod.freq <- nnet::multinom(y ~ ., data = mydata, maxit = 5000) 

But the likelihood is quite flat at these estimates because the associated predictor values have few observations. Hence, the N(0, 2.5) prior should shrink the posterior mean to a reasonable single-digit value. And in regular rstan, it indeed does:

TestMod.stan <- brm(y ~ ., family = categorical, prior = mypriors, data = mydata,
               chains = 2, cores = 2, warmup = 1000, iter = 6000, seed = 2022, control = list(adapt_delta = 0.90))

^ The posteriors of muB_x19B and muC_x24E are now centered around -2, just with wider spreads than the other parameters. The priors have done their job beautifully.

BUT, look what happens when I try the same using cmdstanr:

TestMod.cmdstan <- brm(bf(y ~ ., decomp = "QR"), family = categorical, prior = mypriors, data = mydata,
  threads = threading(2), chains = 2,  cores = 4, warmup = 1000, iter = 3500, seed = 2022, control = list(adapt_delta = 0.90), backend = "cmdstanr")

^ The priors are no longer working as expected (if at all). The posterior means are even farther from zero than in the frequentist model. It is obvious that the N(0, 2.5) prior is imposing very little, if indeed any, shrinkage. This is despite the fact that the priors seem to have been “read” correctly. Calling prior_summary(TestMod.cmdstan) displays the correct N(0, 2.5) priors on the unruly parameters.

There just seems to be something wrong with how cmdstan interprets, or applies, those priors.

This is very unfortunate because I have been otherwise enamored with the speed and stability improvements offered by cmdstan over rstan. Is any quick fix possible?

Are you intentionally using a different formula between the two calls? The second call wraps the formula in bf(..., decomp = "QR"), while the first does not. I believe a QR decomposition changes the interpretation of the priors.
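To make the point about interpretation concrete, here is a minimal base-R sketch on simulated data (not the attached dataset). It assumes the N(0, 2.5) prior ends up on the coefficients of the Q-scale design matrix rather than on the original coefficients; whether brms does exactly this is an assumption here, not something confirmed in this thread. The names Q_ast, R_ast, theta and implied_sd_beta are illustrative only.

```r
# Base R only; simulated predictors, purely illustrative.
set.seed(2022)
n <- 500
x1 <- rnorm(n)
x2 <- 0.9 * x1 + 0.2 * rnorm(n)   # deliberately correlated with x1
X <- scale(cbind(x1, x2), center = TRUE, scale = FALSE)

# Thin QR decomposition, rescaled by sqrt(n - 1) as in the Stan User's Guide
qr_X <- qr(X)
Q_ast <- qr.Q(qr_X) * sqrt(n - 1)
R_ast <- qr.R(qr_X) / sqrt(n - 1)
R_ast_inv <- solve(R_ast)

# The linear predictor is identical under both parameterizations:
# X %*% beta equals Q_ast %*% theta when theta = R_ast %*% beta
beta <- c(1, -2)
theta <- as.vector(R_ast %*% beta)
eta_equal <- isTRUE(all.equal(as.vector(X %*% beta),
                              as.vector(Q_ast %*% theta)))

# ASSUMPTION: the prior is theta ~ N(0, 2.5) independently.
# Then beta = R_ast_inv %*% theta has covariance
# 2.5^2 * R_ast_inv %*% t(R_ast_inv), so the implied marginal SDs are:
implied_sd_beta <- 2.5 * sqrt(diag(R_ast_inv %*% t(R_ast_inv)))

print(eta_equal)
print(implied_sd_beta)
```

If the implied standard deviations come out well above 2.5, a prior stated on the Q scale would shrink the original coefficients far less than the same prior stated on the original scale, which would be consistent with the behavior described in the question.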

Thanks! I have no idea what QR decomposition does; it was just advised in this thread that it could speed up sampling in hierarchical models.

The model in my example is not hierarchical, so disabling QR decomposition solves the problem. But I will soon have to start fitting hierarchical models to this same data, with 3 to 6 group-level SDs. Does this mean that I will run into the same problem with those hierarchical models if I use QR decomposition? Or is this a problem that will only arise when using QR decomposition with non-hierarchical models?

I’m afraid that is a question I’m not too qualified to answer. The Stan User’s Guide provides a bit more information in section 1.2, “The QR reparameterization”.
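For reference, the reparameterization described in that section can be summarized as follows. This is a sketch of the User’s Guide formulation for an $n \times K$ design matrix $X$ with thin QR decomposition $X = QR$; how brms wires its priors into it is not covered here.

```latex
\begin{aligned}
Q^{*} &= Q\,\sqrt{n-1}, \qquad R^{*} = \frac{R}{\sqrt{n-1}}, \qquad X = Q^{*} R^{*},\\
\eta  &= X\beta = Q^{*} R^{*} \beta = Q^{*}\theta, \qquad \theta = R^{*}\beta,\\
\beta &= (R^{*})^{-1}\,\theta .
\end{aligned}
```

The sampler works with $\theta$, and $\beta$ is recovered afterwards. A prior placed on $\theta$ is therefore not the same statement as that prior placed on $\beta$, unless $R^{*}$ happens to be close to the identity.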

I should note that, I believe, this issue is separate from the choice of backend: QR decomposition in the RStan example should show the same behavior.


I now observe that the same problem does indeed happen when I use QR decomposition with a hierarchical model: the N(0, 2.5) priors do not shrink the infinite logits appropriately. So I’ve had to abandon QR decomposition, and don’t know how to reap its advertised benefits without having my priors cease to work.

Fortunately though, in my case most of the speed improvements conferred by cmdstan seem to be due to other things than QR decomposition (e.g. within-chain parallelization). So even without QR decomposition, I’m enjoying a reduction in fitting times.