Background and Simple Case
Consider a univariate linear regression model

y_i \sim \text{Normal}\left(a + b x_i, \sigma^2\right)

with n outcomes y_i, intercept a, slope b, predictors x_i, and known observation variance \sigma^2. If the predictors have non-zero mean \bar x and the priors are weak, the intercept a competes with the slope b because the term b \bar x can shift all outcomes just like the intercept does. Subtracting the mean and repeating the analysis with de-meaned predictors z_i = x_i - \bar x resolves the issue, as discussed in the Stan docs.
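To make the competition concrete, the linear predictor can be rewritten using the identity

a + b x_i = \left(a + b \bar x\right) + b \left(x_i - \bar x\right),

so the data only identify the combination a + b \bar x as the effective intercept. After de-meaning, the constant column and z_i are orthogonal, which removes (to a good approximation) the posterior correlation between intercept and slope.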
Question: How to Generalize for Hierarchical Regression
What is the best way to generalize de-meaning predictors to hierarchical regression models? I’m particularly interested in models where the hierarchy is not nested but am sticking to a single level in the discussion below.
Consider

y_i \sim \text{Normal}\left(a + \alpha_{g_i} + \left(b + \beta_{g_i}\right) x_i, \sigma^2\right),

where g_i is the group of subject i and \alpha and \beta are group-specific intercepts and slopes, respectively. Here's some synthetic data for three groups and five subjects per group.
Python code to generate data and figure.
import matplotlib.pyplot as plt
import numpy as np

# Generate synthetic data.
np.random.seed(0)
n_subjects_per_group = 5
sigma = 0.5
a = np.array(1.0)
b = np.array(3.0)
alpha = np.array([-1.0, 1.0, 4.0])
beta = np.array([-4.0, 0.0, 2.0])
n_groups, = alpha.shape
n_subjects = n_groups * n_subjects_per_group
idx = np.repeat(np.arange(n_groups), n_subjects_per_group)
x = (
    np.array([0, 5, 10])[idx]
    + np.tile(np.linspace(-1, 1, n_subjects_per_group), n_groups)
)
y = np.random.normal(a + alpha[idx] + (b + beta[idx]) * x, sigma)

# Show data.
fig, ax = plt.subplots()
# Invisible white marker adds a text-only legend entry for the global parameters.
ax.scatter([], [], label=f"global: a = {a}, b = {b}, sigma = {sigma}", c="w")
for i in range(n_groups):
    fltr = idx == i
    ax.scatter(x[fltr], y[fltr], label=f"group {i + 1}: alpha = {alpha[i]}, beta = {beta[i]}")
ax.set_xlabel("predictor $x$")
ax.set_ylabel("outcome $y$")
ax.legend()
Here is a sequence of models I've played with to fit this very simple dataset and reduce the condition number of the posterior correlation matrix. In short,
- Shrinkage priors on group-level variables with globally de-meaned features are a starting point.
- Adding group-level de-meaning improves fitting but still leaves much to be desired.
- Using a centered parameterization for the coefficients works reasonably well because the data are very informative. But I'm primarily interested in non-nested effects, which makes a centered parameterization difficult to apply.
I’d love to get your input on how to further improve the fitting.
Model 1: Regularizing with Shrinkage Priors
Without priors, the likelihood above is not sufficient to identify the parameters; e.g., a and \alpha are degenerate because only their sum enters the likelihood. We might add weak priors for the global parameters and shrinkage priors for the group-specific parameters, as in the Stan model below.
Stan model with shrinkage priors.
data {
    int n_subjects, n_groups;
    // The predictor z already has the global mean subtracted.
    vector [n_subjects] z, y;
    array [n_subjects] int<lower=1, upper=n_groups> idx;
    real<lower=0> sigma;
}
parameters {
    real a, b;
    vector [n_groups] alpha, beta;
    real<lower=0> lmbd_alpha, lmbd_beta;
}
transformed parameters {
    vector [n_subjects] y_hat = a + b * z + alpha[idx] + beta[idx] .* z;
}
model {
    lmbd_alpha ~ normal(0, 100);
    lmbd_beta ~ normal(0, 100);
    a ~ normal(0, 100);
    b ~ normal(0, 100);
    alpha ~ normal(0, lmbd_alpha);
    beta ~ normal(0, lmbd_beta);
    y ~ normal(y_hat, sigma);
}
I drew some posterior samples using cmdstanpy.CmdStanModel.laplace_sample because full HMC is prohibitively slow here. Plotting a heatmap of the posterior correlation matrix gives the following.
The structure looks pretty much like what we'd expect: a is anti-correlated with the \alpha_g because they can compensate for each other, and the \alpha_g are correlated with one another because they must move in unison to compensate for changes in a. The same idea applies to b and \beta. Sampling this posterior with HMC is hard because the condition number is large: 3.4\times10^4.
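For reference, here is a minimal sketch of the sampling and diagnostic steps, assuming a recent cmdstanpy where laplace_sample is available and returns draws via draws_pd; the file name model1.stan and the number of draws are placeholders.

import cmdstanpy
import numpy as np

# Pack the data, globally de-meaning the predictor. Stan needs
# one-based group indices, so shift the zero-based Python idx.
data = {
    "n_subjects": n_subjects,
    "n_groups": n_groups,
    "z": x - x.mean(),
    "y": y,
    "idx": idx + 1,
    "sigma": sigma,
}
model = cmdstanpy.CmdStanModel(stan_file="model1.stan")
fit = model.laplace_sample(data=data, draws=4000)

# Estimate the posterior correlation matrix over the model parameters
# (dropping diagnostic columns and the deterministic y_hat) and
# evaluate its condition number.
df = fit.draws_pd()
params = [
    c for c in df.columns
    if c.split("[")[0] in {"a", "b", "alpha", "beta", "lmbd_alpha", "lmbd_beta"}
]
corr = np.corrcoef(df[params].to_numpy(), rowvar=False)
print(np.linalg.cond(corr))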
Model 2: Removing Group-Level Means
While the degeneracy is dominated by the a-\alpha and b-\beta pairs, there is also competition between the group-level intercepts \alpha_g and the terms \beta_g \bar x_g, where \bar x_g is the mean of the predictor within group g. Subtracting the group-specific mean in the regression yields the following model.
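The identity behind this step, writing \bar z_g for the group mean of the globally de-meaned predictor, is

\alpha_g + \beta_g z_i = \left(\alpha_g + \beta_g \bar z_g\right) + \beta_g \left(z_i - \bar z_g\right),

so the data only pin down the combination \alpha_g + \beta_g \bar z_g as the group's effective intercept; regressing on z_i - \bar z_g makes the group intercept and slope approximately orthogonal.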
Stan model with shrinkage priors and group-level de-meaning.
data {
    int n_subjects, n_groups;
    // The predictor z already has the global mean subtracted.
    vector [n_subjects] z, y;
    array [n_subjects] int<lower=1, upper=n_groups> idx;
    real<lower=0> sigma;
}
transformed data {
    // Accumulate the group-wise means of the de-meaned predictor.
    vector [n_groups] z_group_mean = zeros_vector(n_groups);
    vector [n_groups] z_group_count = zeros_vector(n_groups);
    for (i in 1:n_subjects) {
        z_group_mean[idx[i]] += z[i];
        z_group_count[idx[i]] += 1;
    }
    z_group_mean ./= z_group_count;
}
parameters {
    real a, b;
    vector [n_groups] alpha, beta;
    real<lower=0> lmbd_alpha, lmbd_beta;
}
transformed parameters {
    vector [n_subjects] y_hat = a + b * z + alpha[idx]
        + beta[idx] .* (z - z_group_mean[idx]);
}
model {
    lmbd_alpha ~ normal(0, 100);
    lmbd_beta ~ normal(0, 100);
    a ~ normal(0, 100);
    b ~ normal(0, 100);
    alpha ~ normal(0, lmbd_alpha);
    beta ~ normal(0, lmbd_beta);
    y ~ normal(y_hat, sigma);
}
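As a quick check that the transformed data block does what I intend, the same group means can be computed in Python from the earlier synthetic-data variables (a sketch; idx here is the zero-based Python index):

import numpy as np

# Group means of the globally de-meaned predictor, mirroring the
# transformed data block above.
z = x - x.mean()
z_group_mean = np.bincount(idx, weights=z) / np.bincount(idx)
print(z_group_mean)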
The resulting posterior correlation matrix has a smaller condition number and exhibits more structure. We still have the a-\alpha and b-\beta competition, but there is now also structure in the \alpha-b block: when b increases, the group-level intercepts \alpha_g for groups whose mean \bar x_g is larger than the global mean \bar x must decrease, and vice versa.
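To spell out the mechanism with one extra step (my paraphrase of the same effect): perturbing b by \delta shifts group g's mean prediction by \delta \bar z_g, so preserving the fit requires \alpha_g to move by roughly -\delta \bar z_g, which is negative precisely when \bar z_g > 0, i.e., when \bar x_g exceeds the global mean.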
Model 3: Centered Parameterization
Combined with group-level de-meaning, we can also use a centered parameterization because the data are very informative:
Model with centered parameterization.
data {
    int n_subjects, n_groups;
    // The predictor z already has the global mean subtracted.
    vector [n_subjects] z, y;
    array [n_subjects] int<lower=1, upper=n_groups> idx;
    real<lower=0> sigma;
}
transformed data {
    // Accumulate the group-wise means of the de-meaned predictor.
    vector [n_groups] z_group_mean = zeros_vector(n_groups);
    vector [n_groups] z_group_count = zeros_vector(n_groups);
    for (i in 1:n_subjects) {
        z_group_mean[idx[i]] += z[i];
        z_group_count[idx[i]] += 1;
    }
    z_group_mean ./= z_group_count;
}
parameters {
    real a, b;
    vector [n_groups] alpha, beta;
    real<lower=0> lmbd_alpha, lmbd_beta;
}
transformed parameters {
    vector [n_subjects] y_hat = alpha[idx] + beta[idx] .* (z - z_group_mean[idx]);
}
model {
    lmbd_alpha ~ normal(0, 100);
    lmbd_beta ~ normal(0, 100);
    a ~ normal(0, 100);
    b ~ normal(0, 100);
    alpha ~ normal(a, lmbd_alpha);
    beta ~ normal(b, lmbd_beta);
    y ~ normal(y_hat, sigma);
}
This gives a nice condition number for the posterior correlation matrix, although there are some funnels that still need working out.
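If the funnels are the usual coupling between the scales lmbd_alpha, lmbd_beta and the group-level coefficients (my assumption; I haven't verified which pairs are funnel-shaped here), the textbook remedy is to non-center the scale as well, e.g.

\alpha_g = a + \lambda_\alpha \tilde\alpha_g, \quad \tilde\alpha_g \sim \text{Normal}(0, 1),

where \lambda_\alpha corresponds to lmbd_alpha, and analogously for \beta_g. Whether that helps depends on how informative the data are per group, which is exactly the centered-versus-non-centered trade-off mentioned above.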
Thanks for making it this far. What’s a good approach for this problem with informative data but a non-nested hierarchy?