Hi,
so here’s an issue I’ve run into in a couple of analyses where simple uses of varying (random) intercepts resulted in problematic posterior distributions. In particular, the posterior for the overall intercept is negatively correlated with the posteriors for all the varying intercepts. The problem seems quite fundamental and appears even in simple models, but I have not yet seen it discussed elsewhere. Maybe I am missing something basic?
It appears one can get rid of the issue by enforcing a sum-to-zero constraint on the varying intercepts. Is there another/better solution? I don’t think I currently understand the problem well, so I’d be happy for any feedback (or links to other people describing/solving the problem). If you want a single file with all the models and code discussed below, you can get it at: https://github.com/martinmodrak/blog/blob/master/content/post/2021-intercept-correlations.Rmd
A similar problem is also described here: Low bulk ess simple random intercept model - #2 by martinmodrak
Tagging @paul.buerkner and @jonah as this appears to affect quite a lot of brms/rstanarm models, so maybe you’ve seen it before and can link me to resources.
Thanks for any hints.
The setup
The problem can be demonstrated with a very simple Poisson regression with one varying intercept and just two observations per group. Let’s simulate some data:
set.seed(5422335)
# 40 observations in 20 groups, i.e. two observations per group
N <- 40
N_groups <- 20
groups <- rep(1:N_groups, length.out = N)
# True parameter values
intercept <- 3
group_sd <- 1
group_r <- rnorm(N_groups, sd = group_sd)
# Poisson outcome on the log scale
mu <- intercept + group_r[groups]
y <- rpois(N, exp(mu))
data_stan <- list(N = N,
                  N_groups = N_groups,
                  groups = groups,
                  y = y)
Note that with an intercept of 3 (on the log scale), even a single Poisson observation is already quite informative. In my experience, the problem only appears when the data provide at least a moderate amount of information about the parameters, and its consequences for convergence/sampling efficiency become more pronounced as the amount of information grows (e.g. when increasing the number of observations per group in the code above).
I’ll also note that the problem manifests with fewer groups as well, but I find it more puzzling with more groups.
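Just as an aside (the rest of the post uses the 40-observation simulation above), this is what increasing the number of observations per group would look like:

# 10 observations per group instead of 2 (N_groups stays at 20)
N <- 200
groups <- rep(1:N_groups, length.out = N)
mu <- intercept + group_r[groups]
y <- rpois(N, exp(mu))
data_stan <- list(N = N, N_groups = N_groups, groups = groups, y = y)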
Standard non-centered parametrization
Since the non-centered parametrization is often the default, let’s start there. My Stan code for this case is:
data {
  int<lower=0> N;
  int<lower=1> N_groups;
  int<lower=1,upper=N_groups> groups[N];
  int<lower=0> y[N];
}
parameters {
  real intercept;
  vector[N_groups] group_z;
  real<lower=0> group_sd;
}
transformed parameters {
  vector[N_groups] group_r = group_z * group_sd;
}
model {
  intercept ~ normal(3, 1);
  group_z ~ std_normal();
  group_sd ~ normal(0, 1);
  y ~ poisson_log(intercept + group_r[groups]);
}
The model fits without issues, with bulk ESSs between 500 and 1000, but there is definitely weird structure in the posterior. Here’s a pairs plot:
Note the negative correlation between intercept and all the group_z parameters (and the positive correlations among the group_z themselves). In my experience this can cause actual convergence problems if the varying intercepts are embedded in a more complex model than the simple example presented here.
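For reference, here is roughly how the model can be fit and the pairs plot produced. This is just a sketch using cmdstanr and bayesplot with an illustrative file name (the linked Rmd has the actual code I used):

library(cmdstanr)
library(bayesplot)

# Compile and fit the non-centered model (the file name here is only illustrative)
model_nc <- cmdstan_model("intercept_noncentered.stan")
fit_nc <- model_nc$sample(data = data_stan, refresh = 0)

# Pairs plot of the overall intercept, the group sd and a few of the group_z parameters
mcmc_pairs(fit_nc$draws(),
           pars = c("intercept", "group_sd", "group_z[1]", "group_z[2]"))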
Centered parametrization
We remember from Mike Betancourt’s case study (Hierarchical Modeling) that when the data inform the varying intercepts very well (which might be the case here), a centered parametrization might work better. So let’s try that:
data {
  int<lower=0> N;
  int<lower=1> N_groups;
  int<lower=1,upper=N_groups> groups[N];
  int<lower=0> y[N];
}
parameters {
  real intercept;
  vector[N_groups] group_r;
  real<lower=0> group_sd;
}
model {
  intercept ~ normal(3, 1);
  group_r ~ normal(0, group_sd);
  group_sd ~ normal(0, 1);
  y ~ poisson_log(intercept + group_r[groups]);
}
Once again the model fits without big issues (although there are a couple of Rhats just above 1.01); however, the bulk ESSs decrease to 300-800 for most parameters and the correlations don’t go away:
Sum to zero, non-centered
This is where things get interesting. Maybe the real problem is that increasing all the varying intercepts by a small amount while decreasing the overall intercept by the same amount produces exactly the same likelihood, so the only thing preventing this from becoming full-blown non-identifiability is the prior. I realized a sum-to-zero constraint might help after seeing that R-INLA/inlabru offer it as an option for all types of structures, including varying intercepts.
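To make that intuition concrete, here is a quick check with the simulated data from above: shifting all group intercepts up by some delta and the overall intercept down by the same delta leaves the Poisson log likelihood untouched, so only the priors can tell the two configurations apart.

delta <- 0.5
# Log likelihood at the true parameter values
ll_original <- sum(dpois(y, exp(intercept + group_r[groups]), log = TRUE))
# Log likelihood after shifting the group intercepts up and the overall intercept down
ll_shifted <- sum(dpois(y, exp((intercept - delta) + (group_r + delta)[groups]), log = TRUE))
all.equal(ll_original, ll_shifted) # TRUE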
So we may change the model to force the vector of varying intercepts to sum exactly to zero, effectively removing one degree of freedom from the model. Here, I am using the QR parametrization from Test: Soft vs Hard sum-to-zero constrain + choosing the right prior for soft constrain - #31 by andre.pfeuffer (via @aaronjg and @andre.pfeuffer), but even a soft constraint (as discussed in the linked thread) seems to noticeably help the real models I’ve worked with.
I also switch back to the non-centered parametrization. The Stan code is:
functions {
  // Precompute the 2*N coefficients of the QR-based map from N-1 free parameters
  // to an N-vector that sums to zero. The final rescaling makes the marginal
  // standard deviation of the resulting elements equal to 1.
  vector Q_sum_to_zero_QR(int N) {
    vector[2 * N] Q_r;
    for(i in 1:N) {
      Q_r[i] = -sqrt((N - i) / (N - i + 1.0));
      Q_r[i + N] = inv_sqrt((N - i) * (N - i + 1));
    }
    Q_r = Q_r * inv_sqrt(1 - inv(N));
    return Q_r;
  }

  // Map the N-1 raw parameters to an N-vector that sums exactly to zero.
  vector sum_to_zero_QR(vector x_raw, vector Q_r) {
    int N = num_elements(x_raw) + 1;
    vector[N] x;
    real x_aux = 0;
    for(i in 1:N - 1){
      x[i] = x_aux + x_raw[i] * Q_r[i];
      x_aux = x_aux + x_raw[i] * Q_r[i + N];
    }
    x[N] = x_aux;
    return x;
  }
}
data {
  int<lower=0> N;
  int<lower=1> N_groups;
  int<lower=1,upper=N_groups> groups[N];
  int<lower=0> y[N];
}
transformed data {
  vector[2 * N_groups] groups_Q_r = Q_sum_to_zero_QR(N_groups);
}
parameters {
  real intercept;
  vector[N_groups - 1] group_r_raw;
  real<lower=0> group_sd;
}
transformed parameters {
  vector[N_groups] group_r = sum_to_zero_QR(group_r_raw, groups_Q_r) * group_sd;
}
model {
  intercept ~ normal(3, 1);
  group_r_raw ~ normal(0, 1);
  group_sd ~ normal(0, 1);
  y ~ poisson_log(intercept + group_r[groups]);
}
No fitting issues, bulk ESSs increase to 600-7000, and the pairs plot looks nice:
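As a quick self-check that the constraint really holds in the posterior, one can look at the row sums of the group_r draws. A sketch, assuming a cmdstanr fit object fit_stz for the model above:

# Every posterior draw of group_r should sum to (numerically) zero
group_r_draws <- fit_stz$draws("group_r", format = "matrix")
summary(abs(rowSums(group_r_draws)))  # the maximum should be on the order of 1e-15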
Now, this changes the interpretation of the model. If I get it right, the fitted intercept is no longer the mean of the assumed population of groups, but rather the mean of the actual observed set of groups. This means that making predictions for new, unobserved groups will be a bit trickier. And maybe there are other potential problems I am missing?
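For concreteness, the usual recipe for a new group is sketched below: for each posterior draw, sample a fresh group effect from the population distribution. Under the sum-to-zero constraint this treats the fitted intercept as if it were the population mean, which is exactly the approximation I am unsure about. The fit_stz object and the use of the posterior package are only assumptions for the sketch.

library(posterior)

# Posterior draws of the overall intercept and the group-level sd
draws <- as_draws_df(fit_stz$draws(c("intercept", "group_sd")))

# For each draw, sample an effect for a hypothetical new group and simulate
# a single Poisson observation for it
new_group_r <- rnorm(nrow(draws), mean = 0, sd = draws$group_sd)
y_new <- rpois(nrow(draws), exp(draws$intercept + new_group_r))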
Also, the real question is: this looks like it should be quite a common problem, so why isn’t it a well-known thing? The simplest explanation seems to be that it is me who is missing something basic, so I’ll be happy to learn what it is.
It also feels like the problem should become less pronounced as I add more groups (as increasing all the group intercepts should incur a larger prior penalty), but I can easily have 200 groups in this example and still see basically the same issue.
Sum to zero, centered
For this problem, the centered version of the sum-to-zero model also works well, with even better bulk ESS (1500-8000). I’m not showing the details for brevity (the code link has this model as well). For the real models where I encountered this, I needed both the centered parametrization and the sum-to-zero constraint to get rid of convergence issues, but I assume those are just two orthogonal problems and the data I’ve worked with simply happen to demonstrate both.
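Roughly, the centered variant keeps the functions, data and transformed data blocks from above and moves group_sd from the scaling in transformed parameters into the prior on group_r_raw. The linked Rmd has the exact version, but the idea is:

parameters {
  real intercept;
  vector[N_groups - 1] group_r_raw;
  real<lower=0> group_sd;
}
transformed parameters {
  // The QR transform enforces the sum-to-zero constraint; the scale is now in the prior
  vector[N_groups] group_r = sum_to_zero_QR(group_r_raw, groups_Q_r);
}
model {
  intercept ~ normal(3, 1);
  // Centered: group_sd enters the prior on the raw parameters directly
  group_r_raw ~ normal(0, group_sd);
  group_sd ~ normal(0, 1);
  y ~ poisson_log(intercept + group_r[groups]);
}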