In an attempt to better understand centred vs. non-centred parametrisation, I implemented a simple hierarchical model where I estimate group-level means of samples from a lognormal distribution.
Models
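In centred parametrisation the model reads
$$y_i \sim \text{LogNormal}(\theta_{j[i]}, \sigma), \qquad \theta_j \sim \text{Normal}(\mu_\theta, \sigma_\theta), \qquad i = 1, \dots, N, \quad j = 1, \dots, J,$$
with priors
$$\sigma \sim \text{Half-Cauchy}(0, 2.5), \qquad \mu_\theta \sim \text{Normal}(\overline{\log y}, 5), \qquad \sigma_\theta \sim \text{Half-Cauchy}(0, 2.5).$$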
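In non-centred parametrisation the model is
$$y_i \sim \text{LogNormal}(\theta_{j[i]}, \sigma), \qquad \theta_j = \mu_\theta + \sigma_\theta \, \tilde{\theta}_j, \qquad \tilde{\theta}_j \sim \text{Normal}(0, 1),$$
with the same priors as in the centred parametrisation model.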
Question
I am trying to understand why the model based on non-centred parametrisation leads to divergent transitions, whereas the centred parametrisation model does not. Parameter estimates from both parametrisations are the same.
Is this something to worry about? I was under the impression that non-centred parametrisation in hierarchical models may help avoid divergent transitions. Here, non-centred parametrisation leads to divergent transitions that are not present in centred parametrisation.
Reproducible code
Data
We draw samples from a lognormal distribution. Every group has a different number of samples, as well as a different location parameter.
# Simulated data ----
# Group sizes and the per-group lognormal location (meanlog) parameters
N <- c(5, 10, 20)
mu <- c(1, 5, 2.7)
set.seed(2020)
# Group label of every observation: label j is repeated N[j] times
grp <- rep.int(seq_along(N), N)
# Draw group by group, in order, so the random number stream is consumed one
# group at a time; sdlog is left at the rlnorm() default
y <- unlist(lapply(seq_along(N), function(j) rlnorm(N[j], meanlog = mu[j])))
# Data list whose names match the data block of the Stan programs
stan_data <- list(N = sum(N), J = length(N), y = y, grp = grp)
Model 1 - Centred parametrisation
Define the Stan model and store in a file
# Stan program for the centred parametrisation: the group-level locations
# theta are themselves parameters with a normal(mu_theta, sigma_theta) prior.
model_code <- "
data {
  int<lower=1> N;
  int<lower=1> J;
  vector<lower=0>[N] y;
  int<lower=1,upper=J> grp[N];
}
parameters {
  vector[J] theta;
  real<lower=0> sigma;
  // Hyperparameters
  real mu_theta;
  real<lower=0> sigma_theta;
}
model {
  // Partial pooling
  theta ~ normal(mu_theta, sigma_theta);
  sigma ~ cauchy(0, 2.5);
  // Priors on the Hyperparameters
  mu_theta ~ normal(mean(log(y)), 5);
  sigma_theta ~ cauchy(0, 2.5);
  for (i in 1:N) {
    y[i] ~ lognormal(theta[grp[i]], sigma);
  }
}
"
# Hand the path straight to writeLines(): it opens the file, writes, and
# closes the connection again. A connection made with file() and never
# close()d stays registered until it is garbage collected.
writeLines(model_code, "cp_lognormal.stan")
Fit the model in RStan
library(rstan)
# Compile the centred Stan program, then sample with a fixed seed
mod1 <- stan_model(file = "cp_lognormal.stan")
fit1 <- sampling(mod1, data = stan_data, seed = 2020)
# Print the posterior summary table
summary(fit1)[["summary"]]
#                   mean     se_mean        sd        2.5%         25%
#theta[1]      0.2855719 0.008559706 0.5751184  -0.8627703  -0.1015152
#theta[2]      5.3411934 0.006657086 0.4006931   4.5501807   5.0812342
#theta[3]      3.0614557 0.004367888 0.2888357   2.4904444   2.8674964
#sigma         1.2734530 0.002633512 0.1657430   0.9968682   1.1525125
#mu_theta      2.9793557 0.032845289 1.7468227  -0.6286331   1.9852084
#sigma_theta   3.2568763 0.040588213 1.9494459   1.2781818   2.0260763
#lp__        -30.0883335 0.044299687 1.8676741 -34.4225582 -31.1043682
#                    50%        75%      97.5%    n_eff      Rhat
#theta[1]      0.2798689   0.655243   1.453193 4514.366 0.9997062
#theta[2]      5.3461704   5.608261   6.144940 3622.893 1.0004306
#theta[3]      3.0618461   3.252180   3.631591 4372.793 1.0007773
#sigma         1.2575959   1.377929   1.637246 3960.955 0.9995801
#mu_theta      2.9571910   3.947534   6.629189 2828.469 1.0000863
#sigma_theta   2.7328053   3.870513   8.219941 2306.867 0.9997357
#lp__        -29.7625227 -28.707047 -27.478843 1777.464 1.0003966
Model 2 - Non-centred parametrisation
Define the Stan model and store in a file
# Stan program for the non-centred parametrisation: standard-normal
# theta_raw is sampled and theta is derived from it in transformed parameters.
model_code <- "
data {
  int<lower=1> N;
  int<lower=1> J;
  vector<lower=0>[N] y;
  int<lower=1,upper=J> grp[N];
}
parameters {
  vector[J] theta_raw;
  real<lower=0> sigma;
  // Hyperparameters
  real mu_theta;
  real<lower=0> sigma_theta;
}
transformed parameters {
  vector[J] theta;
  // Non-centred parametrisation
  // This is the same as theta ~ normal(mu_theta, sigma_theta)
  // The assignment is vectorised over all J groups, so no loop is needed
  theta = mu_theta + sigma_theta * theta_raw;
}
model {
  // Prior on non-centred theta and sigma
  theta_raw ~ std_normal();
  sigma ~ cauchy(0, 2.5);
  // Priors on the Hyperparameters
  mu_theta ~ normal(mean(log(y)), 5);
  sigma_theta ~ cauchy(0, 2.5);
  for (i in 1:N) {
    y[i] ~ lognormal(theta[grp[i]], sigma);
  }
}
"
# Hand the path straight to writeLines(): it opens the file, writes, and
# closes the connection again. A connection made with file() and never
# close()d stays registered until it is garbage collected.
writeLines(model_code, "ncp_lognormal.stan")
Fit the model in RStan
library(rstan)
# Compile the non-centred Stan program, then sample with a fixed seed
mod2 <- stan_model(file = "ncp_lognormal.stan")
fit2 <- sampling(mod2, data = stan_data, seed = 2020)
#Warning messages:
#1: There were 12 divergent transitions after warmup. Increasing adapt_delta above 0.8 may help. See
#http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
#2: Examine the pairs() plot to diagnose sampling problems
# Print the posterior summary table
summary(fit2)[["summary"]]
#                     mean     se_mean        sd        2.5%         25%
#theta_raw[1]  -1.00815665 0.021350370 0.6672576  -2.3856518  -1.4463231
#theta_raw[2]   0.93633840 0.020740896 0.6737323  -0.2997736   0.4682227
#theta_raw[3]   0.06121563 0.017203945 0.5451971  -1.0072525  -0.3035223
#sigma          1.27923073 0.003732861 0.1653593   0.9983752   1.1625441
#mu_theta       2.93517303 0.059544351 1.7298036  -0.5358579   1.9131715
#sigma_theta    3.15949936 0.052453351 1.5961028   1.2744281   2.0412726
#theta[1]       0.26959400 0.009110810 0.5840765  -0.8713866  -0.1143934
#theta[2]       5.33212480 0.006373734 0.4052442   4.5246934   5.0646429
#theta[3]       3.06909945 0.004577154 0.2876905   2.4925307   2.8755302
#lp__         -26.93885437 0.054902828 1.8838247 -31.5670625 -27.9909609
#                      50%         75%       97.5%     n_eff      Rhat
#theta_raw[1]  -0.96611338  -0.5365614   0.1931393  976.7340 1.0064084
#theta_raw[2]   0.89889151   1.3947660   2.3316565 1055.1635 1.0034082
#theta_raw[3]   0.06335142   0.4242516   1.1427822 1004.2710 1.0024077
#sigma          1.26518435   1.3844068   1.6494535 1962.3369 1.0038053
#mu_theta       2.89729905   3.9290596   6.5242073  843.9417 1.0027692
#sigma_theta    2.71693128   3.8349804   7.6267323  925.9238 1.0071633
#theta[1]       0.25730853   0.6466723   1.4239908 4109.8458 1.0006812
#theta[2]       5.33603946   5.5971617   6.1530808 4042.4589 1.0010087
#theta[3]       3.06805385   3.2614220   3.6325209 3950.5722 0.9994564
#lp__         -26.59301849 -25.5633259 -24.2447598 1177.3120 1.0043606
I get 12 divergent transitions with the seed specified above.
Pairs plots
For the centred parametrisation model
# Pairs plot of the group-level locations and their hyperparameters
cp_pars <- c("theta", "mu_theta", "sigma_theta")
pairs(fit1, pars = cp_pars)
For the non-centred parametrisation model
# Pairs plot of the derived group-level locations and their hyperparameters
ncp_pars <- c("theta", "mu_theta", "sigma_theta")
pairs(fit2, pars = ncp_pars)
Model in brms
Interestingly, when I fit the model in brms I also end up with divergent transitions
library(brms)
# Lognormal model with a grouping term on grp, fitted with a fixed seed.
# NOTE(review): brms documents group-level terms in the parenthesised form
# (1 | grp); the formula is kept exactly as posted -- confirm equivalence.
fit3 <- brm(
  formula = y ~ 1 | grp,
  data = data.frame(y = y, grp = grp),
  family = lognormal(),
  seed = 2020
)
#Warning messages:
#1: There were 14 divergent transitions after warmup. Increasing adapt_delta above 0.8 may help. See
#http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
#2: Examine the pairs() plot to diagnose sampling problems
Group-level estimates agree with those from the rstan models.
# Group-level estimate = population-level estimate + group-specific deviation
pop_estimate <- fixef(fit3)[, "Estimate"]
grp_deviation <- ranef(fit3)[["grp"]][, "Estimate", 1]
pop_estimate + grp_deviation
#        1         2         3
#0.2416648 5.3570510 3.0643358
I remember reading somewhere on the discourse that brms may already use non-centred parametrisation by default.






