Centred vs. non-centred parametrisation with lognormal likelihood

In an attempt to better understand centred vs. non-centred parametrisation, I implemented a simple hierarchical model where I estimate group-level means of samples from a lognormal distribution.

Models

In centred parametrisation the model reads

\begin{aligned} y_i &\sim \textrm{LogNormal}(\theta_{j[i]}, \sigma^2)\,,\\ \theta_j &\sim \textrm{Normal}(\mu_\theta, \sigma_\theta^2) \end{aligned}

with priors

\begin{aligned} \sigma &\sim \textrm{Cauchy}(0, 2.5)\,,\\ \mu_\theta &\sim \textrm{Normal}(\text{mean}(\log{y}), 5)\,,\\ \sigma_\theta &\sim \textrm{Cauchy}(0, 2.5)\,. \end{aligned}

In non-centred parametrisation the model is

\begin{aligned} y_i &\sim \textrm{LogNormal}(\theta_{j[i]}, \sigma^2)\,,\\ \theta_{\text{raw}, j} &\sim \textrm{Normal}(0, 1)\\ \theta_j &= \mu_\theta + \sigma_\theta \theta_{\text{raw}, j} \end{aligned}

with the same priors as in the centred parametrisation model.

Question

I am trying to understand why the model based on non-centred parametrisation leads to divergent transitions, whereas the centred parametrisation model does not. Parameter estimates from both parametrisations are the same.

Is this something to worry about? I was under the impression that non-centred parametrisation in hierarchical models may help avoid divergent transitions. Here, the non-centred parametrisation leads to divergent transitions that are not present in the centred parametrisation.

Reproducible code

Data

We draw samples from a lognormal distribution. Each group has a different number of samples as well as a different location parameter.

# Data
N <- c(5, 10, 20)         # Number of samples per group
mu <- c(1, 5, 2.7)        # Location parameter per group
set.seed(2020)
grp <- rep(seq_along(N), times = N)
y <- unlist(Map(function(n, meanlog) rlnorm(n, meanlog), N, mu))
stan_data <- list(N = sum(N), J = length(N), y = y, grp = grp)

Model 1 - Centred parametrisation

Define the Stan model and store it in a file

model_code <- "
data {
  int<lower=1> N;
  int<lower=1> J;
  vector<lower=0>[N] y;
  int<lower=1,upper=J> grp[N];
}

parameters {
  vector[J] theta;
  real<lower=0> sigma;

  // Hyperparameters
  real mu_theta;
  real<lower=0> sigma_theta;
}

model {
  // Partial pooling
  theta ~ normal(mu_theta, sigma_theta);
  sigma ~ cauchy(0, 2.5);

  // Priors on the Hyperparameters
  mu_theta ~ normal(mean(log(y)), 5);
  sigma_theta ~ cauchy(0, 2.5);

  for (i in 1:N) {
    y[i] ~ lognormal(theta[grp[i]], sigma);
  }
}
"
con <- file("cp_lognormal.stan")
writeLines(model_code, con)

Fit the model in RStan

library(rstan)
mod1 <- stan_model("cp_lognormal.stan")
fit1 <- sampling(object = mod1, data = stan_data, seed = 2020)
summary(fit1)$summary
#                   mean     se_mean        sd        2.5%         25%
#theta[1]      0.2855719 0.008559706 0.5751184  -0.8627703  -0.1015152
#theta[2]      5.3411934 0.006657086 0.4006931   4.5501807   5.0812342
#theta[3]      3.0614557 0.004367888 0.2888357   2.4904444   2.8674964
#sigma         1.2734530 0.002633512 0.1657430   0.9968682   1.1525125
#mu_theta      2.9793557 0.032845289 1.7468227  -0.6286331   1.9852084
#sigma_theta   3.2568763 0.040588213 1.9494459   1.2781818   2.0260763
#lp__        -30.0883335 0.044299687 1.8676741 -34.4225582 -31.1043682
#                    50%        75%      97.5%    n_eff      Rhat
#theta[1]      0.2798689   0.655243   1.453193 4514.366 0.9997062
#theta[2]      5.3461704   5.608261   6.144940 3622.893 1.0004306
#theta[3]      3.0618461   3.252180   3.631591 4372.793 1.0007773
#sigma         1.2575959   1.377929   1.637246 3960.955 0.9995801
#mu_theta      2.9571910   3.947534   6.629189 2828.469 1.0000863
#sigma_theta   2.7328053   3.870513   8.219941 2306.867 0.9997357
#lp__        -29.7625227 -28.707047 -27.478843 1777.464 1.0003966

Model 2 - Non-centred parametrisation

Define the Stan model and store it in a file

model_code <- "
data {
  int<lower=1> N;
  int<lower=1> J;
  vector<lower=0>[N] y;
  int<lower=1,upper=J> grp[N];
}

parameters {
  vector[J] theta_raw;
  real<lower=0> sigma;

  // Hyperparameters
  real mu_theta;
  real<lower=0> sigma_theta;
}

transformed parameters {
  vector[J] theta;

  // Non-centred parametrisation
  // This is equivalent to theta ~ normal(mu_theta, sigma_theta)
  theta = mu_theta + sigma_theta * theta_raw;
}

model {
  // Prior on non-centred theta and sigma
  theta_raw ~ std_normal();
  sigma ~ cauchy(0, 2.5);

  // Priors on the Hyperparameters
  mu_theta ~ normal(mean(log(y)), 5);
  sigma_theta ~ cauchy(0, 2.5);

  for (i in 1:N) {
    y[i] ~ lognormal(theta[grp[i]], sigma);
  }
}
"
con <- file("ncp_lognormal.stan")
writeLines(model_code, con)

Fit the model in RStan

library(rstan)
mod2 <- stan_model("ncp_lognormal.stan")
fit2 <- sampling(object = mod2, data = stan_data, seed = 2020)
#Warning messages:
#1: There were 12 divergent transitions after warmup. Increasing adapt_delta above 0.8 may help. See
#http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
#2: Examine the pairs() plot to diagnose sampling problems
summary(fit2)$summary
#                     mean     se_mean        sd        2.5%         25%
#theta_raw[1]  -1.00815665 0.021350370 0.6672576  -2.3856518  -1.4463231
#theta_raw[2]   0.93633840 0.020740896 0.6737323  -0.2997736   0.4682227
#theta_raw[3]   0.06121563 0.017203945 0.5451971  -1.0072525  -0.3035223
#sigma          1.27923073 0.003732861 0.1653593   0.9983752   1.1625441
#mu_theta       2.93517303 0.059544351 1.7298036  -0.5358579   1.9131715
#sigma_theta    3.15949936 0.052453351 1.5961028   1.2744281   2.0412726
#theta[1]       0.26959400 0.009110810 0.5840765  -0.8713866  -0.1143934
#theta[2]       5.33212480 0.006373734 0.4052442   4.5246934   5.0646429
#theta[3]       3.06909945 0.004577154 0.2876905   2.4925307   2.8755302
#lp__         -26.93885437 0.054902828 1.8838247 -31.5670625 -27.9909609
#                      50%         75%       97.5%     n_eff      Rhat
#theta_raw[1]  -0.96611338  -0.5365614   0.1931393  976.7340 1.0064084
#theta_raw[2]   0.89889151   1.3947660   2.3316565 1055.1635 1.0034082
#theta_raw[3]   0.06335142   0.4242516   1.1427822 1004.2710 1.0024077
#sigma          1.26518435   1.3844068   1.6494535 1962.3369 1.0038053
#mu_theta       2.89729905   3.9290596   6.5242073  843.9417 1.0027692
#sigma_theta    2.71693128   3.8349804   7.6267323  925.9238 1.0071633
#theta[1]       0.25730853   0.6466723   1.4239908 4109.8458 1.0006812
#theta[2]       5.33603946   5.5971617   6.1530808 4042.4589 1.0010087
#theta[3]       3.06805385   3.2614220   3.6325209 3950.5722 0.9994564
#lp__         -26.59301849 -25.5633259 -24.2447598 1177.3120 1.0043606

I get 12 divergent transitions with the seed specified above.

Pairs plots

For the centred parametrisation model

pairs(fit1, pars = c("theta", "mu_theta", "sigma_theta"))

For the non-centred parametrisation model

pairs(fit2, pars = c("theta", "mu_theta", "sigma_theta"))

Model in brms

Interestingly, when I fit the model in brms I also end up with divergent transitions

library(brms)
fit3 <- brm(
    y ~  1 | grp,
    family = lognormal(),
    data = data.frame(y = y, grp = grp),
    seed = 2020)
#Warning messages:
#1: There were 14 divergent transitions after warmup. Increasing adapt_delta above 0.8 may help. See
#http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
#2: Examine the pairs() plot to diagnose sampling problems

Group-level estimates agree with those from the rstan models.

fixef(fit3)[, "Estimate"] + ranef(fit3)$grp[, "Estimate", 1]
#        1         2         3
#0.2416648 5.3570510 3.0643358

I remember reading somewhere on the Stan Discourse that brms may already use the non-centred parametrisation by default.
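One way to check (a sketch; stancode() is the brms function for extracting the generated Stan program):

# Inspect the Stan code brms generated for fit3; a non-centred parametrisation
# shows up as standardised group-level effects (z_*) that are scaled by the
# group-level standard deviation in the transformed parameters block.
stancode(fit3)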


Looks right. This can happen for sure.

Could you plot log(sigma_theta) in the pairs plots instead of sigma_theta? The sampler itself is moving around in unconstrained space, and because of the constraint on sigma_theta it’ll be sampling on log(sigma_theta).

Thanks for the response @bbbales2. Here are the pairs plots with log(sigma_theta)

Centred parametrisation

Non-centred parametrisation
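For reference, a minimal sketch of how a pairs plot on log(sigma_theta) can be produced (the exact plotting code is not part of the original post):

# Extract posterior draws, log-transform sigma_theta, and plot the pairs.
draws <- as.data.frame(as.matrix(fit2, pars = c("theta", "mu_theta", "sigma_theta")))
draws$log_sigma_theta <- log(draws$sigma_theta)
draws$sigma_theta <- NULL
pairs(draws, pch = ".")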

Here’s my summary:

  1. Divergent transitions are not clustering in any particular area of parameter space, suggesting no major issues.
  2. There is a clear funnel in the mu_theta vs. log_sigma_theta plots; the funnel is broader/more diffuse in the non-centred parametrisation.

So am I correct in

  1. ignoring the divergent transitions (they go away if I increase adapt_delta),
  2. ignoring the funnel in the mu_theta vs. log_sigma_theta plots?

More fundamentally: is there any advantage here to using the non-centred parametrisation?

Can you plot the raw parameters? theta_raw, sigma, mu_theta, and sigma_theta?

As a separate way to get rid of divergences, could you try replacing the Cauchy priors with normals?

Using normal priors on the sigma’s

sigma ~ normal(0, 5);
sigma_theta ~ normal(0, 5);

still gives divergent transitions in the non-centred parametrisation.

As for the raw parameters:

Non-centred parametrisation with Cauchy priors on the sigma’s

Non-centred parametrisation with normal priors on the sigma’s

Not sure what to make of that.

Oh, there we go. When the data are super informative there’s a kind of non-identifiability between the group mean and the random effects, something like that. Play around with it. With less data the non-centered will probably be better.

The non-centered isn’t always better. (Edit: which is the conclusion you came to in your original post!)


Sadly I’m still missing my own “there we go” moment :-(. Could you explain what you mean by “non-identifiability between the group mean and the random effects”? The sample data shouldn’t be very informative, I would have thought. Sample sizes are small to moderate, and the differences in the means are also not that large.

I struggle to interpret the structure in the pairs plots involving the raw parameters. Wouldn’t the non-Gaussian blob nature be a cause for concern?

Yeah that’s what happens.

I’m off to bed so I won’t be responding for a while, but you can get this to happen with eight schools (Diagnosing Biased Inference with Divergences).

Take the data:

data = list(N = 8,
            y = c(28,  8, -3,  7, -1,  1, 18, 12),
            sigma = c(15, 10, 16, 11, 9, 11, 10, 18))

And you’ll find non-centered better than centered.

And then take the data:

data = list(N = 8,
            y = c(28,  8, -3,  7, -1,  1, 18, 12),
            sigma = c(0.15, 0.10, 0.16, 0.11, 0.09, 0.11, 0.10, 0.18))

And centered will work better than non-centered. The difference is that the measurement errors are really low now (and the school estimates aren’t really overlapping anymore).

I don’t have a good explanation for what’s happening, but you’ll see it there too. I suspect if you do fewer draws per group you’ll see it in your example :D.
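As a purely illustrative way to try that, one could re-generate the example data with smaller group sizes and re-fit both models (this snippet just reuses the data-generating code from above):

# Smaller groups: less informative likelihoods per group.
N <- c(2, 3, 4)           # fewer samples per group
mu <- c(1, 5, 2.7)        # same location parameters as before
set.seed(2020)
grp <- rep(seq_along(N), times = N)
y <- unlist(Map(function(n, meanlog) rlnorm(n, meanlog), N, mu))
stan_data <- list(N = sum(N), J = length(N), y = y, grp = grp)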


Oh yeah, the non-Gaussian blob is just hard for the sampler to move around in. The Gaussian blob is about the easiest thing. Whenever there are big correlations or changing curvature in a posterior it can cause problems. This has it all!

And here I thought I had a simple example: lognormal data with group-level means and unit sigma parameter.

I see what you mean with the 8 schools data. However, in my case, there is considerable overlap between the different groups:

library(ggplot2)
ggplot(data.frame(y = log(stan_data$y), grp = factor(stan_data$grp)), aes(y, fill = grp)) +
    geom_histogram(binwidth = 1, position = "dodge2")

So this should’ve corresponded to the original 8 schools data, where non-centred does better than centred. No?

I am just going to respond so Ben can get some sleep.

Well, in your example groups 1 and 2 don’t overlap at all. The thing to remember is that the decision to center or not to center is not a model-level decision but a parameter-level decision. That is, you could use one parameterization for some groups and another for the other groups, and that could be the optimal approach, unwieldy as it would be (a sketch follows below). I only realized this after this comment by Michael Betancourt. If you really want to get into the details of when to use which parametrization, Michael’s paper with Mark Girolami is my go-to guide.
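To make that concrete, here is a hypothetical Stan sketch of such a mixed parametrisation for the lognormal model above; the split into centred and non-centred groups (J_cen, idx_cen, idx_ncp) is supplied as data and is purely illustrative, not something suggested in the thread:

// Mixed (per-group) parametrisation: some groups centred, others non-centred.
data {
  int<lower=1> N;
  int<lower=1> J;
  vector<lower=0>[N] y;
  int<lower=1,upper=J> grp[N];
  int<lower=0,upper=J> J_cen;               // number of centred groups
  int<lower=1,upper=J> idx_cen[J_cen];      // indices of centred groups
  int<lower=1,upper=J> idx_ncp[J - J_cen];  // indices of non-centred groups
}

parameters {
  vector[J_cen] theta_cen;      // centred group-level parameters
  vector[J - J_cen] theta_raw;  // standardised (non-centred) group-level parameters
  real<lower=0> sigma;

  // Hyperparameters
  real mu_theta;
  real<lower=0> sigma_theta;
}

transformed parameters {
  vector[J] theta;
  theta[idx_cen] = theta_cen;
  theta[idx_ncp] = mu_theta + sigma_theta * theta_raw;
}

model {
  theta_cen ~ normal(mu_theta, sigma_theta);  // centred groups
  theta_raw ~ std_normal();                   // non-centred groups
  sigma ~ cauchy(0, 2.5);

  // Priors on the hyperparameters
  mu_theta ~ normal(mean(log(y)), 5);
  sigma_theta ~ cauchy(0, 2.5);

  y ~ lognormal(theta[grp], sigma);
}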


Just for reference, I made up the overlap thing without thinking about it much. I just know you can get the non-centering vs. centering behaviour by making y large instead of sigma small, and that part of the motivation for the original 8-schools problem is that it’s not even clear for any school individually whether students were on average scoring better or worse with or without the intervention.

But then I got confused wondering whether there’s anything special about zero or not, so I just went with overlap to avoid talking about zero effects.

The lognormal model here is a red herring, although it will complicate matters later; the log-normal location parameter has to be positive so the original model as written isn’t consistent. To ensure a consistent model one would typically exponentiate the latent normal variables, which introduces all kinds of problems of its own. To simplify matters for this example it’s easier to first fit

log_y[n] ~ normal(mu[group[n]], sigma)

to understand the subtleties of hierarchical models.
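For completeness, a minimal Stan sketch of that simplified centred model (the variable names and the weakly informative priors are illustrative, not from the original post; log_y is assumed to be log(y) computed beforehand in R):

data {
  int<lower=1> N;
  int<lower=1> J;
  vector[N] log_y;                // log-transformed observations, log(y)
  int<lower=1,upper=J> group[N];
}

parameters {
  vector[J] mu;                   // individual (group-level) locations
  real<lower=0> sigma;            // observation noise on the log scale

  // Population parameters
  real mu_pop;
  real<lower=0> tau;
}

model {
  mu ~ normal(mu_pop, tau);       // centred hierarchical prior
  mu_pop ~ normal(0, 5);
  tau ~ normal(0, 5);
  sigma ~ normal(0, 5);
  log_y ~ normal(mu[group], sigma);
}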

In a hierarchical model the individual parameters \theta_{n} are distributed according to the population parameters \mu and \tau; typically we assume a latent normal model,

\theta_{n} \sim \text{normal}(\mu, \tau).

The problem with this model is that, by construction, the \theta_{n} are strongly coupled to \mu and \tau. When \mu moves all of the \theta_{n} go with it; when \tau gets small all of the \theta_{n} have to shrink towards \mu. In other words, in this model the individual parameters are not well identified relative to the population parameters. That lack of identifiability results in a joint density that fills out a “funnel” that is hard to fit.

When we introduce likelihood functions informed by observations, however, things can change. If each individual has their own data then the individual likelihood functions \pi(\tilde{y}_{n} \mid \theta_{n}, \psi_{n}) can clamp down the \theta_{n} and effectively cut off the coupling with \mu and \tau. For this to happen, however, all of the likelihood functions need to be sufficiently narrow. This could mean lots of data for each individual or small observational variation, such as in the common 8-schools example (the 8-schools example is actually pretty bad exactly because it fixes \sigma, which isn’t particularly representative of models common in applied practice, but let’s ignore that for now).

Fortunately we have another way to specify hierarchical models – through the relative deviations

\tilde{\theta}_{n} = \frac{\theta_{n} - \mu}{\tau}.

In the population model these relative deviations are completely uncoupled from the population parameters,

\tilde{\theta}_{n} \sim \text{normal}(0, 1),

and so are easier to fit. Likelihood functions, however, inform the \theta_{n} and not these relative deviations. In fact, when the likelihood functions are narrow the relative deviations become less identified relative to the population parameters! For example, if \mu moves then the \tilde{\theta}_{n} have to move in the opposite direction to ensure that the \theta_{n} stay at the values constrained by the likelihood functions.

So when an individual parameter is strongly informed by its own likelihood function, a centered parameterization is best. When it is only weakly informed, a non-centered parameterization is best. As @stijn notes, this consideration applies to each individual parameter \theta_{n} separately. When the widths of the likelihood functions vary from individual to individual, as is common when the data are unbalanced, then you have to implement centered parameterizations for some individuals and non-centered parameterizations for others.

Keep in mind that this is only the tip of the iceberg. When the likelihood functions are wide or there are only a few groups then there are fundamental degeneracies in the posteriors for \mu and \tau themselves that need to be taken into consideration (usually with informative priors). Hierarchical models are powerful but really, really subtle and they need to be implemented with care!


Thanks @bbbales2, @stijn and @betanalpha for the insightful answers and comments. This is great and helped clarify some very fundamental concepts!

Just one last thing. Michael, you wrote:

the log normal location parameter has to be positive so the original model as written isn’t consistent.

Perhaps I misunderstood, but I had always understood the lognormal to be valid for location parameters \mu \in \mathbb{R}; the Stan documentation also seems to suggest as much. So positivity of the location parameter shouldn’t be a requirement for lognormal likelihoods: since the location parameter of the lognormal distribution is on the log scale, it shouldn’t be an issue. Or am I mistaken?

The reason I’m asking is that the bigger context is a hierarchical dose-response model that I’m trying to fit in RStan; some of the model’s parameters denote count-like (strictly positive) quantities which I assume to be lognormally distributed with location parameters that can be < 0 (on the log scale).

You are correct. In Stan the Lognormal distribution is actually implemented* something like

target += normal_lpdf(log(y) | mu, sigma) + log Jacobian of the log transform

=

target += normal_lpdf(log(y) | mu, sigma) - log(y)

cf. the Stan functions reference. Thus, \mu and \sigma refer to \log(y), and the location parameter is on the log scale. This is consistent with how the lognormal is described on Wikipedia.

*) This is not literally the implementation, but I think you see the point I’m making. ;)
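A quick numerical check of this identity with base R densities:

# lognormal_lpdf(y | mu, sigma) equals normal_lpdf(log(y) | mu, sigma) - log(y);
# both expressions below return the same value.
y0 <- 2.5; mu0 <- 0.3; sigma0 <- 1.2
dlnorm(y0, meanlog = mu0, sdlog = sigma0, log = TRUE)
dnorm(log(y0), mean = mu0, sd = sigma0, log = TRUE) - log(y0)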

Sorry, I misspoke. The location parameter \mu is unconstrained, as it refers to the location of the log of the argument, not the argument itself. Once you condition on the observed data the likelihood function will be the same (up to a constant) whether you condition on y or log(y), so the posterior geometry will be equivalent whether you fit the log-normal model to the nominal data or the normal model to the log-transformed data.
