In an attempt to better understand centred vs. non-centred parametrisation, I implemented a simple hierarchical model in which I estimate the group-level location parameters (means on the log scale) of samples drawn from a lognormal distribution.
Models
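In the centred parametrisation the model reads

$$
\begin{aligned}
y_i &\sim \text{Lognormal}(\theta_{j[i]}, \sigma), & i &= 1, \dots, N,\\
\theta_j &\sim \mathcal{N}(\mu_\theta, \sigma_\theta), & j &= 1, \dots, J,
\end{aligned}
$$

where $j[i]$ denotes the group of observation $i$, with priors

$$
\begin{aligned}
\sigma &\sim \text{Half-Cauchy}(0, 2.5),\\
\mu_\theta &\sim \mathcal{N}\big(\overline{\log y}, 5\big),\\
\sigma_\theta &\sim \text{Half-Cauchy}(0, 2.5),
\end{aligned}
$$

where $\overline{\log y}$ is the sample mean of $\log y$ and normal distributions are parametrised by their standard deviation, as in Stan.

In the non-centred parametrisation the model is

$$
\begin{aligned}
y_i &\sim \text{Lognormal}(\theta_{j[i]}, \sigma),\\
\theta_j &= \mu_\theta + \sigma_\theta\, \tilde\theta_j,\\
\tilde\theta_j &\sim \mathcal{N}(0, 1),
\end{aligned}
$$

with the same priors as in the centred parametrisation model ($\tilde\theta_j$ is `theta_raw` in the Stan code).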
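As a quick check that the two formulations imply the same distribution for θ_j given the hyperparameters, here is a base-R simulation sketch. The values `mu_theta = 3` and `sigma_theta = 2.5` are arbitrary illustrative choices, not estimates from the fits below.

```r
# Sketch: theta = mu_theta + sigma_theta * theta_raw with theta_raw ~ N(0, 1)
# gives theta ~ N(mu_theta, sigma_theta).
# mu_theta and sigma_theta are arbitrary illustrative values (assumption).
set.seed(1)
n_draws <- 1e5
mu_theta <- 3
sigma_theta <- 2.5

# Centred: draw theta directly from its population distribution
theta_centred <- rnorm(n_draws, mean = mu_theta, sd = sigma_theta)

# Non-centred: draw a standard normal, then shift and scale it
theta_raw <- rnorm(n_draws)
theta_noncentred <- mu_theta + sigma_theta * theta_raw

# Both sets of draws should have approximately the same mean and sd
rbind(
  centred     = c(mean = mean(theta_centred),    sd = sd(theta_centred)),
  non_centred = c(mean = mean(theta_noncentred), sd = sd(theta_noncentred))
)
```

The two formulations therefore define the same model; what changes is which variables the sampler moves in (θ_j versus θ̃_j), and hence the shape of the posterior it has to explore.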
Question
I am trying to understand why the model based on the non-centred parametrisation leads to divergent transitions, whereas the centred parametrisation model does not. Parameter estimates from the two parametrisations are essentially the same.
Is this something to worry about? I was under the impression that a non-centred parametrisation of hierarchical models helps avoid divergent transitions. Here, however, the non-centred parametrisation produces divergent transitions that are absent from the centred one.
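A rule of thumb often repeated in Stan discussions is that the non-centred form helps when the data say little about each θ_j relative to the population scale σ_θ, while the centred form tends to sample better when each group's likelihood is informative. A rough way to gauge which regime applies here is to compare the approximate data-only standard error of each θ_j, σ/√n_j, with σ_θ. The sketch below plugs in the posterior means of σ and σ_θ from the centred fit reported further down; treating them as fixed constants is a simplifying assumption.

```r
# Rough heuristic sketch (assumption): how tightly does each group's data
# pin down theta_j, compared with the population-level spread sigma_theta?
n_j <- c(5, 10, 20)            # samples per group, as in the data below
sigma_hat <- 1.2734530         # posterior mean of sigma (centred fit)
sigma_theta_hat <- 3.2568763   # posterior mean of sigma_theta (centred fit)

# Approximate standard error of theta_j from the data alone
se_theta <- sigma_hat / sqrt(n_j)
# Data uncertainty relative to the population spread; values well below 1
# indicate groups whose likelihood dominates the hierarchical prior
ratio <- se_theta / sigma_theta_hat

round(rbind(se_theta = se_theta, ratio = ratio), 3)
```

This gives standard errors of about 0.57, 0.40 and 0.28, close to the posterior standard deviations of θ_j in both fits, and ratios of roughly 0.17, 0.12 and 0.09.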
Reproducible code
Data
We draw samples from a lognormal distribution with `sdlog = 1` (the `rlnorm()` default). Each group has a different number of samples and a different location parameter (`meanlog`).
```r
# Data
N <- c(5, 10, 20)  # Number of samples per group
mu <- c(1, 5, 2.7) # Location parameter (meanlog) per group
set.seed(2020)
grp <- rep(seq_along(N), times = N)  # Group index of each observation
y <- unlist(Map(function(n, meanlog) rlnorm(n, meanlog), N, mu))  # sdlog = 1
stan_data <- list(N = sum(N), J = length(N), y = y, grp = grp)
```
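Since log(y) is normal within each group, the per-group sample means of log(y) are the no-pooling estimates of θ_j and give a reference point for the partially pooled estimates below. A short base-R sketch (it regenerates the same data so that it runs on its own):

```r
# Regenerate the simulated data exactly as above
N <- c(5, 10, 20)
mu <- c(1, 5, 2.7)
set.seed(2020)
grp <- rep(seq_along(N), times = N)
y <- unlist(Map(function(n, meanlog) rlnorm(n, meanlog), N, mu))

# No-pooling estimate of theta_j: sample mean of log(y) within each group
theta_no_pool <- tapply(log(y), grp, mean)
# Observations per group, as a check that grp lines up with N
n_per_group <- tabulate(grp)

cbind(n = n_per_group, true_mu = mu, mean_log_y = theta_no_pool)
```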
Model 1 - Centred parametrisation
Define the Stan model and store it in a file
```r
model_code <- "
data {
  int<lower=1> N;
  int<lower=1> J;
  vector<lower=0>[N] y;
  int<lower=1,upper=J> grp[N];
}
parameters {
  vector[J] theta;
  real<lower=0> sigma;
  // Hyperparameters
  real mu_theta;
  real<lower=0> sigma_theta;
}
model {
  // Partial pooling
  theta ~ normal(mu_theta, sigma_theta);
  sigma ~ cauchy(0, 2.5);
  // Priors on the hyperparameters
  mu_theta ~ normal(mean(log(y)), 5);
  sigma_theta ~ cauchy(0, 2.5);
  for (i in 1:N) {
    y[i] ~ lognormal(theta[grp[i]], sigma);
  }
}
"
con <- file("cp_lognormal.stan")
writeLines(model_code, con)
```
Fit the model in RStan
```r
library(rstan)
mod1 <- stan_model("cp_lognormal.stan")
fit1 <- sampling(object = mod1, data = stan_data, seed = 2020)
summary(fit1)$summary
#                    mean     se_mean          sd        2.5%         25%
#theta[1]       0.2855719 0.008559706   0.5751184  -0.8627703  -0.1015152
#theta[2]       5.3411934 0.006657086   0.4006931   4.5501807   5.0812342
#theta[3]       3.0614557 0.004367888   0.2888357   2.4904444   2.8674964
#sigma          1.2734530 0.002633512   0.1657430   0.9968682   1.1525125
#mu_theta       2.9793557 0.032845289   1.7468227  -0.6286331   1.9852084
#sigma_theta    3.2568763 0.040588213   1.9494459   1.2781818   2.0260763
#lp__         -30.0883335 0.044299687   1.8676741 -34.4225582 -31.1043682
#                     50%         75%       97.5%       n_eff        Rhat
#theta[1]       0.2798689    0.655243    1.453193    4514.366   0.9997062
#theta[2]       5.3461704    5.608261    6.144940    3622.893   1.0004306
#theta[3]       3.0618461    3.252180    3.631591    4372.793   1.0007773
#sigma          1.2575959    1.377929    1.637246    3960.955   0.9995801
#mu_theta       2.9571910    3.947534    6.629189    2828.469   1.0000863
#sigma_theta    2.7328053    3.870513    8.219941    2306.867   0.9997357
#lp__         -29.7625227  -28.707047  -27.478843    1777.464   1.0003966
```
Model 2 - Non-centred parametrisation
Define the Stan model and store it in a file
```r
model_code <- "
data {
  int<lower=1> N;
  int<lower=1> J;
  vector<lower=0>[N] y;
  int<lower=1,upper=J> grp[N];
}
parameters {
  vector[J] theta_raw;
  real<lower=0> sigma;
  // Hyperparameters
  real mu_theta;
  real<lower=0> sigma_theta;
}
transformed parameters {
  vector[J] theta;
  // Non-centred parametrisation: together with theta_raw ~ std_normal(),
  // this is equivalent to theta ~ normal(mu_theta, sigma_theta).
  // The assignment is vectorised, so no loop over j is needed.
  theta = mu_theta + sigma_theta * theta_raw;
}
model {
  // Prior on non-centred theta and sigma
  theta_raw ~ std_normal();
  sigma ~ cauchy(0, 2.5);
  // Priors on the hyperparameters
  mu_theta ~ normal(mean(log(y)), 5);
  sigma_theta ~ cauchy(0, 2.5);
  for (i in 1:N) {
    y[i] ~ lognormal(theta[grp[i]], sigma);
  }
}
"
con <- file("ncp_lognormal.stan")
writeLines(model_code, con)
```
Fit the model in RStan
```r
library(rstan)
mod2 <- stan_model("ncp_lognormal.stan")
fit2 <- sampling(object = mod2, data = stan_data, seed = 2020)
#Warning messages:
#1: There were 12 divergent transitions after warmup. Increasing adapt_delta above 0.8 may help. See
#http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
#2: Examine the pairs() plot to diagnose sampling problems
```
```r
summary(fit2)$summary
#                     mean      se_mean           sd         2.5%          25%
#theta_raw[1]  -1.00815665  0.021350370    0.6672576   -2.3856518   -1.4463231
#theta_raw[2]   0.93633840  0.020740896    0.6737323   -0.2997736    0.4682227
#theta_raw[3]   0.06121563  0.017203945    0.5451971   -1.0072525   -0.3035223
#sigma          1.27923073  0.003732861    0.1653593    0.9983752    1.1625441
#mu_theta       2.93517303  0.059544351    1.7298036   -0.5358579    1.9131715
#sigma_theta    3.15949936  0.052453351    1.5961028    1.2744281    2.0412726
#theta[1]       0.26959400  0.009110810    0.5840765   -0.8713866   -0.1143934
#theta[2]       5.33212480  0.006373734    0.4052442    4.5246934    5.0646429
#theta[3]       3.06909945  0.004577154    0.2876905    2.4925307    2.8755302
#lp__         -26.93885437  0.054902828    1.8838247  -31.5670625  -27.9909609
#                      50%          75%        97.5%        n_eff         Rhat
#theta_raw[1]  -0.96611338   -0.5365614    0.1931393     976.7340    1.0064084
#theta_raw[2]   0.89889151    1.3947660    2.3316565    1055.1635    1.0034082
#theta_raw[3]   0.06335142    0.4242516    1.1427822    1004.2710    1.0024077
#sigma          1.26518435    1.3844068    1.6494535    1962.3369    1.0038053
#mu_theta       2.89729905    3.9290596    6.5242073     843.9417    1.0027692
#sigma_theta    2.71693128    3.8349804    7.6267323     925.9238    1.0071633
#theta[1]       0.25730853    0.6466723    1.4239908    4109.8458    1.0006812
#theta[2]       5.33603946    5.5971617    6.1530808    4042.4589    1.0010087
#theta[3]       3.06805385    3.2614220    3.6325209    3950.5722    0.9994564
#lp__         -26.59301849  -25.5633259  -24.2447598    1177.3120    1.0043606
```
I get 12 divergent transitions with the seed specified above.
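The divergence count reported in the warning can also be read off the sampler diagnostics. In rstan, `get_sampler_params(fit, inc_warmup = FALSE)` returns a list with one matrix per chain, including a `divergent__` column (rstan also offers `get_num_divergent()` and `check_hmc_diagnostics()`). The helper `count_divergent()` below is a hypothetical name, not an rstan function; so that the sketch runs without a fitted model, it is exercised on a small mock list with the same structure.

```r
# Hypothetical helper: count post-warmup divergent transitions from the list
# returned by rstan::get_sampler_params(fit, inc_warmup = FALSE),
# i.e. one matrix per chain with a 0/1 "divergent__" column.
count_divergent <- function(sampler_params) {
  per_chain <- vapply(sampler_params,
                      function(chain) sum(chain[, "divergent__"]),
                      numeric(1))
  list(per_chain = per_chain, total = sum(per_chain))
}

# Mock input mimicking two chains of 5 iterations (illustrative values only)
mock_params <- list(
  cbind(accept_stat__ = runif(5), divergent__ = c(0, 1, 0, 0, 1)),
  cbind(accept_stat__ = runif(5), divergent__ = c(0, 0, 0, 1, 0))
)
count_divergent(mock_params)

# With a real fit (requires rstan), one could instead run
# count_divergent(rstan::get_sampler_params(fit2, inc_warmup = FALSE))
# and refit with a smaller adapted step size, e.g.
# sampling(mod2, data = stan_data, seed = 2020, control = list(adapt_delta = 0.95))
```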
Pairs plots
For the centred parametrisation model

```r
pairs(fit1, pars = c("theta", "mu_theta", "sigma_theta"))
```

For the non-centred parametrisation model

```r
pairs(fit2, pars = c("theta", "mu_theta", "sigma_theta"))
```
Model in brms
Interestingly, when I fit the model in brms, I also end up with divergent transitions:
```r
library(brms)
fit3 <- brm(
  y ~ 1 + (1 | grp),  # population-level intercept plus varying intercept per group
  family = lognormal(),
  data = data.frame(y = y, grp = grp),
  seed = 2020
)
#Warning messages:
#1: There were 14 divergent transitions after warmup. Increasing adapt_delta above 0.8 may help. See
#http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
#2: Examine the pairs() plot to diagnose sampling problems
```
Group-level estimates agree with those from the rstan models:
```r
fixef(fit3)[, "Estimate"] + ranef(fit3)$grp[, "Estimate", 1]
#        1         2         3
#0.2416648 5.3570510 3.0643358
```
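To put a number on this agreement, the sketch below takes the posterior means of θ_j printed above for the two rstan fits and the brms fit (values copied from the outputs, not recomputed) and expresses the largest difference between fits as a fraction of the centred fit's posterior standard deviation.

```r
# Posterior means of theta[j], copied from the printed outputs above
est <- rbind(
  centred     = c(0.2855719,  5.3411934,  3.0614557),
  non_centred = c(0.26959400, 5.33212480, 3.06909945),
  brms        = c(0.2416648,  5.3570510,  3.0643358)
)
colnames(est) <- paste0("theta[", 1:3, "]")

# Posterior sd of theta[j] from the centred fit, used as a yardstick
post_sd <- c(0.5751184, 0.4006931, 0.2888357)

# Largest difference between any two fits, per group, relative to posterior sd
max_diff <- apply(est, 2, function(x) diff(range(x)))
round(max_diff / post_sd, 3)
```

All differences come out below a tenth of a posterior standard deviation (about 0.08, 0.06 and 0.03).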
I remember reading somewhere on the Stan Discourse forum that brms may already use a non-centred parametrisation by default.