Different families for mean and sigma in brms distributional models

Hello,

First-time user of both brms and Stan here.

I am trying to fit hierarchical models of protein binding affinities in brms. The models have forms similar to

affinity ~ 1 + (1|protein) + (1|protein:organism)

so that they can be used for out-of-sample prediction. The dataset contains 30k+ combinations of protein:organism, each with only a few measurements (approx. 1 to 10). Using default settings, a gaussian family, and reasonable priors, the models fit smoothly: no warnings are shown and Rhat/ESS look healthy.
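Concretely, the baseline call looks roughly like this (priors omitted, as below):

fit_simple <- brm(
  affinity ~ 1 + (1|protein) + (1|protein:organism),
  df, family = gaussian(),
  prior = c(...), ...
)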

I then noticed that the residual spread differs between groups, so I decided to model sigma as well (Estimating Distributional Models with brms (r-project.org)) to capture the uncertainty more precisely.

The updated model looks like this:

model <- brm(
  bf(
    affinity ~ 1 + (1|protein) + (1|protein:organism),
    sigma ~ 1 + (1|protein) + (1|protein:organism),  # note: in brms this is actually the log of sigma
    decomp = "QR"
  ),
  df, prior =  c(...), ...
)

Suddenly, fitting the model becomes much harder. I now get ~10% divergent transitions, ~90% max treedepth hits, and moderately high Rhat (~1.5). Setting adapt_delta = 0.99 and increasing the maximum treedepth reduces divergent transitions to 2 and Rhat to 1.15. I don’t doubt that by pushing these settings further I could eventually get a decent fit, but the runtime is already at 2 days with the current settings, so this option is not ideal.
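For reference, these settings go through brm()'s control argument, roughly like this (the exact treedepth value is illustrative):

model <- brm(
  ...,  # same formula, data, and priors as above
  control = list(adapt_delta = 0.99, max_treedepth = 15)  # defaults: 0.8 and 10
)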

Digging through the forum and the Stan docs, it really looks like there is an issue with my model (or its parametrization). I tried to investigate the divergent transitions to see if this is a funnel problem (Diagnosing Biased Inference with Divergences (mc-stan.org)), but I couldn’t spot anything interesting.
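Roughly, I looked at pairs plots along these lines (a sketch; the parameter names are placeholders, see variables(model) for the real ones):

library(bayesplot)

np <- nuts_params(model)  # per-iteration sampler diagnostics, incl. divergences
mcmc_pairs(
  as_draws_array(model), np = np,  # divergent iterations get highlighted
  pars = c("b_Intercept", "sd_protein__Intercept")  # placeholder names
)

So I then checked whether the gaussian family is actually a good choice, using QQ plots: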

means <- fitted(model)                   # per-observation posterior summaries of mu
sigmas <- fitted(model, dpar = "sigma")  # same for sigma; column 1 is the Estimate

qqnorm(means[,1], main="QQ plot means")
qqline(means[,1])
qqnorm(log(sigmas[,1]), main="QQ plot log-sigmas")
qqline(log(sigmas[,1]))
hist(log(sigmas[,1]), main="log-sigmas distribution")

Now this looks like a problem: my data is normally distributed, as I assumed, but the log-sigmas are not. They look more like a Student-t distribution.

And here finally come my questions:

  1. Can a poorly chosen family have such a strong impact on the fitting difficulty/time?
  2. If yes: I could fix this by using a student family for the sigmas (while keeping a gaussian for the means), but I could not find any information on how to use multiple families in brms. Is this possible?
  3. If not: what other paths could I take to improve fitting performance? Should I consider a different parametrization of the problem? (How?)

Thank you!
Mattia

Hi Mattia, welcome to The Stan Forums!

As far as I understand brms’s distributional models, the predicted distributional parameters themselves have no distribution in the model. There is only a distribution for the response values, given the distributional parameters. You can see this in the Stan code, e.g., for the following example from the brms vignette “Estimating Distributional Models with brms”:

set.seed(69238)
group <- rep(c("treat", "placebo"), each = 30)
symptom_post <- c(rnorm(30, mean = 1, sd = 2), rnorm(30, mean = 0, sd = 1))
dat1 <- data.frame(group, symptom_post)

library(brms)
options(mc.cores = parallel::detectCores(logical = FALSE))
fit1 <- brm(bf(symptom_post ~ group, sigma ~ group),
            data = dat1, family = gaussian(),
            seed = 45732)
stancode(fit1)

which gives

// generated with brms 2.17.0
functions {
}
data {
  int<lower=1> N;  // total number of observations
  vector[N] Y;  // response variable
  int<lower=1> K;  // number of population-level effects
  matrix[N, K] X;  // population-level design matrix
  int<lower=1> K_sigma;  // number of population-level effects
  matrix[N, K_sigma] X_sigma;  // population-level design matrix
  int prior_only;  // should the likelihood be ignored?
}
transformed data {
  int Kc = K - 1;
  matrix[N, Kc] Xc;  // centered version of X without an intercept
  vector[Kc] means_X;  // column means of X before centering
  int Kc_sigma = K_sigma - 1;
  matrix[N, Kc_sigma] Xc_sigma;  // centered version of X_sigma without an intercept
  vector[Kc_sigma] means_X_sigma;  // column means of X_sigma before centering
  for (i in 2:K) {
    means_X[i - 1] = mean(X[, i]);
    Xc[, i - 1] = X[, i] - means_X[i - 1];
  }
  for (i in 2:K_sigma) {
    means_X_sigma[i - 1] = mean(X_sigma[, i]);
    Xc_sigma[, i - 1] = X_sigma[, i] - means_X_sigma[i - 1];
  }
}
parameters {
  vector[Kc] b;  // population-level effects
  real Intercept;  // temporary intercept for centered predictors
  vector[Kc_sigma] b_sigma;  // population-level effects
  real Intercept_sigma;  // temporary intercept for centered predictors
}
transformed parameters {
  real lprior = 0;  // prior contributions to the log posterior
  lprior += student_t_lpdf(Intercept | 3, 0.2, 2.5);
  lprior += student_t_lpdf(Intercept_sigma | 3, 0, 2.5);
}
model {
  // likelihood including constants
  if (!prior_only) {
    // initialize linear predictor term
    vector[N] mu = Intercept + Xc * b;
    // initialize linear predictor term
    vector[N] sigma = Intercept_sigma + Xc_sigma * b_sigma;
    for (n in 1:N) {
      // apply the inverse link function
      sigma[n] = exp(sigma[n]);
    }
    target += normal_lpdf(Y | mu, sigma);
  }
  // priors including constants
  target += lprior;
}
generated quantities {
  // actual population-level intercept
  real b_Intercept = Intercept - dot_product(means_X, b);
  // actual population-level intercept
  real b_sigma_Intercept = Intercept_sigma - dot_product(means_X_sigma, b_sigma);
}

So I guess this answers your questions 1 and 2 from above: since the predicted sigmas have no distribution of their own in the model, there is no “family for sigma” that could be poorly chosen, and consequently no way (or need) to specify a second family.

For the third question, my best (quick) guess is that this model is just overly flexible, and I’m not sure if there is a way to reparameterize it so that you have fewer convergence issues. Your reason for the more complex model was the group-wise differences in the residuals. But are these differences in the residuals really so bad that you need the more complex model? Have you performed posterior predictive checks (PPCs)?

Thank you Frank!

Right, mean and sigma are parameters of the gaussian family, but they don’t have a family themselves. Also, my plots above could be misleading, as they show the posterior predictions and not the original data.

I made some crude plots that show the distribution of the mean and SD in each group (showing only groups with n >= 3, but the plots look similar when including all groups as well):
[two histograms: distribution of per-group means and per-group SDs]

These are simply df %>% group_by(protein, organism) %>% summarise(s = sd(affinity), m = mean(affinity)). The means of each group look normally distributed, while the log-sigmas look skew-normal.
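In full, something along these lines (using ggplot2 here; the binning is arbitrary):

library(dplyr)
library(ggplot2)

group_stats <- df %>%
  group_by(protein, organism) %>%
  summarise(n = n(), m = mean(affinity), s = sd(affinity), .groups = "drop") %>%
  filter(n >= 3)  # keep groups with at least 3 measurements

ggplot(group_stats, aes(m)) + geom_histogram(bins = 50)       # per-group means
ggplot(group_stats, aes(log(s))) + geom_histogram(bins = 50)  # per-group log-SDs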

I think what could help is changing the distribution of the random effects of sigma (via gr(..., dist = ...)), but at the moment only gaussian and student are available, no skew-normal.
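I.e., something like this (a sketch; I’m assuming gr() accepts the interaction term the same way):

bf(
  affinity ~ 1 + (1|protein) + (1|protein:organism),
  sigma ~ 1 + (1 | gr(protein, dist = "student")) +
          (1 | gr(protein:organism, dist = "student"))  # t-distributed random effects
)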

but are these differences in the residuals really so bad that you need the more complex model? Have you performed posterior predictive checks (PPCs)?

Looking at the plot above, I believe so, and this also makes sense: some affinities are nicely conserved, most of them show larger differences because of different experimental conditions (which are poorly annotated), and there is a good number of cases where values/annotations are mixed up. I forgot to mention that I model affinities on the log scale, so we are talking about orders-of-magnitude differences. For predictive applications a single sigma parameter is a nightmare, since the misannotated entries would push its value up and the uncertainty of well-characterized entries would be overestimated.

I tried PPCs as you suggested, comparing the simple (left) and the distributional (right) model. I’m new to PPCs, but here is what I see:
pp_check(..., type = "error_hist")
[two error histograms: simple model (left) vs. distributional model (right)]

The distributional model seems to improve the fit a lot. Are there other checks you would suggest?

For the comparison I used my current best fit for the distributional model, which has Rhat < 1.05 and ~0.5% divergent transitions. As mentioned, I struggle to bring them to zero. There seems to be a consensus on the forum/Stan docs that even a single divergence is bad. However, if I had the choice between a simple model which is easy to fit but poorly represents the data and a distributional model that is potentially biased but clearly performs better in PPCs, wouldn’t it be better to take the second one?

To me, that sounds like your data is lacking important aspects (the experimental conditions), so I’m not sure if you will be able to solve the problem by tweaking the model, given that the data might already be insufficient.

It seems that your plots only correspond to a single posterior draw per model. I often find type = "dens_overlay" or type = "ecdf_overlay" helpful, in your case perhaps also type = "dens_overlay_grouped" or type = "ecdf_overlay_grouped". With those types, it’s easier to inspect multiple posterior draws at once.
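For example (a sketch; the ndraws values and the grouping variable are just illustrations):

pp_check(model, type = "ecdf_overlay", ndraws = 100)
pp_check(model, type = "dens_overlay_grouped", group = "protein", ndraws = 50)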

Furthermore, since you mentioned the predictive point of view on several occasions, you could try to compare the two models in terms of predictive performance (see the loo package and the corresponding methods like loo.brmsfit() in brms). It’s possible that the distributional model overfits the observed data and that its out-of-sample predictive performance is not that good.
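For example (a sketch, assuming the two fits are stored as fit_simple and fit_dist):

loo_simple <- loo(fit_simple)
loo_dist <- loo(fit_dist)
loo_compare(loo_simple, loo_dist)  # higher elpd means better expected predictions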

Also, have you tried the model

model <- brm(
  bf(
    affinity ~ 1 + (1|protein) + (1|protein:organism),
    sigma ~ 1 + (1|protein),  # note: in brms this is actually the log of sigma
    decomp = "QR"
  ),
  df, prior =  c(...), ...
)

i.e., leaving out the term (1|protein:organism) for sigma? However, I don’t know whether this model makes sense from your biological domain-knowledge perspective.

Hi Frank,

Simplifying the model for sigma as you suggested helps. I think I tried this at the beginning and then ruled it out, but I cannot remember why. Anyway, now it looks better!

I think I still have problems with the sparsity of the data, but that’s not really a Stan issue. I will try to find more data sources to integrate into the model.

Thanks for your help!
