De-meaning predictors for hierarchical regression

Background and Simple Case

Consider a univariate linear regression model

y_i\sim\mathsf{Normal}\left(a + b x_i,\sigma^2\right)

with n outcomes y_i, intercept a, slope b, predictors x_i, and known observation variance \sigma^2. If the predictors have non-zero mean \bar x and the priors are weak, the intercept a competes with b because b\bar x can shift all outcomes just like the intercept. Subtracting the mean and repeating the analysis with predictors z_i = x_i - \bar x resolves the issue, as discussed in the Stan docs.
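
For intuition, here is a quick NumPy illustration: centering makes the intercept column of the design matrix orthogonal to the predictor column, which is exactly what removes the competition between a and b.

import numpy as np

# Predictor with a large non-zero mean.
x = np.linspace(5, 15, 20)

# Design matrices with an intercept column, before and after centering.
X_raw = np.column_stack([np.ones_like(x), x])
X_centered = np.column_stack([np.ones_like(x), x - x.mean()])

# X'X is proportional to the posterior precision of (a, b) under flat priors
# and known sigma, so its condition number tracks how hard sampling is.
print(np.linalg.cond(X_raw.T @ X_raw))            # large (around 1e3 here)
print(np.linalg.cond(X_centered.T @ X_centered))  # small (around 9 here)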

Question: How to Generalize for Hierarchical Regression

What is the best way to generalize de-meaning predictors to hierarchical regression models? I’m particularly interested in models where the hierarchy is not nested but am sticking to a single level in the discussion below.

Consider

y_i\sim\mathsf{Normal}\left(a+b x_i + \alpha_{g_i} + \beta_{g_i} x_i, \sigma^2\right),

where g_i is the group of subject i and \alpha and \beta are group-specific intercepts and slopes, respectively. Here’s some synthetic data for three groups and five subjects per group.

Python code to generate data and figure.
import matplotlib.pyplot as plt
import numpy as np

# Generate synthetic data.
np.random.seed(0)
n_subjects_per_group = 5
sigma = 0.5

a = np.array(1.0)
b = np.array(3.0)
alpha = np.array([-1.0, 1.0, 4.0])
beta = np.array([-4.0, 0.0, 2.0])
n_groups, = alpha.shape
n_subjects = n_groups * n_subjects_per_group

idx = np.repeat(np.arange(n_groups), n_subjects_per_group)
x = (
    np.array([0, 5, 10])[idx]
    + np.tile(np.linspace(-1, 1, n_subjects_per_group), n_groups)
)
y = np.random.normal(a + alpha[idx] + (b + beta[idx]) * x, sigma)

# Show data.
fig, ax = plt.subplots()
ax.scatter([], [], label=f"global: a = {a}, b = {b}, sigma = {sigma}", c="w")
for i in range(n_groups):
    fltr = idx == i
    ax.scatter(x[fltr], y[fltr], label=f"group {i + 1}: alpha = {alpha[i]}, beta = {beta[i]}")
ax.set_xlabel("predictor $x$")
ax.set_ylabel("outcome $y$")
ax.legend()
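
Continuing the snippet above, the data for the Stan models below can be packaged as follows (a sketch: the keys mirror the data blocks, z is the globally de-meaned predictor, and the group labels are shifted to Stan’s 1-based indexing).

# Assemble the input for the Stan models (names mirror their data blocks).
z = x - x.mean()  # globally de-meaned predictor
data = {
    "n_subjects": n_subjects,
    "n_groups": n_groups,
    "z": z,
    "y": y,
    "idx": idx + 1,  # Stan indices are 1-based
    "sigma": sigma,
}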

Here is a sequence of models I’ve played with to fit this simple synthetic dataset and reduce the condition number of the posterior correlation matrix. In short,

  1. Shrinkage priors on group-level variables with globally de-meaned features are a starting point.
  2. Adding group-level de-meaning improves fitting but still leaves much to be desired.
  3. Using a centered parameterization for the coefficients works reasonably well because the data are very informative. But I’m primarily interested in non-nested effects, which makes a centered parameterization difficult to apply.

I’d love to get your input on how to further improve the fitting.

Model 1: Regularizing with Shrinkage Priors

Without priors, the likelihood above is not sufficient to identify parameters, e.g., a and \alpha are degenerate. We might add weak priors for the global parameters and shrinkage priors for the group-specific parameters as in the Stan model below.

Stan model with shrinkage priors.
data {
    int n_subjects, n_groups;
    // Predictors z already have the global mean subtracted.
    vector [n_subjects] z, y;
    array [n_subjects] int<lower=1, upper=n_groups> idx;
    real<lower=0> sigma;
}

parameters {
    real a, b;
    vector [n_groups] alpha, beta;
    real <lower=0> lmbd_alpha, lmbd_beta;
}

transformed parameters {
    vector [n_subjects] y_hat = a + b * z + alpha[idx] + beta[idx] .* z;
}

model {
    lmbd_alpha ~ normal(0, 100);
    lmbd_beta ~ normal(0, 100);
    a ~ normal(0, 100);
    b ~ normal(0, 100);
    alpha ~ normal(0, lmbd_alpha);
    beta ~ normal(0, lmbd_beta);
    y ~ normal(y_hat, sigma);
}

I drew some posterior samples using cmdstanpy.CmdStanModel.laplace_sample because HMC is prohibitively slow. Plotting a heatmap of the posterior correlation matrix gives us the following.

The structure looks pretty much like what we’d expect: a is anti-correlated with \alpha because they can compensate for each other. All the \alpha_g are correlated with one another because they need to move in unison to compensate for changes in a. The same idea applies to b and \beta. Sampling this with HMC is hard because the condition number is large: 3.4\times10^4.
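
Roughly, the condition number can be computed from the Laplace draws like this (a sketch; it assumes the model above is saved as model1.stan and that data is the dictionary assembled earlier).

import cmdstanpy
import numpy as np

# Approximate the posterior with a Laplace approximation and compute the
# condition number of the posterior correlation matrix of all parameters.
model = cmdstanpy.CmdStanModel(stan_file="model1.stan")
fit = model.laplace_sample(data=data, draws=1000)

variables = fit.stan_variables()
n_draws = variables["a"].shape[0]
samples = np.column_stack([
    variables[name].reshape(n_draws, -1)
    for name in ["a", "b", "alpha", "beta", "lmbd_alpha", "lmbd_beta"]
])
correlation = np.corrcoef(samples, rowvar=False)
print(np.linalg.cond(correlation))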

Model 2: Removing Group-Level Means

While the degeneracy is dominated by the a-\alpha and b-\beta pairs, there is also competition between the group-level intercepts \alpha_g and the terms \beta_g \bar x_g, where \bar x_g is the mean of the predictor within group g. Subtracting the group-specific mean from the predictor in the slope term yields the following model.
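
The competition is easiest to see by rewriting the group-level contribution to the mean as

\alpha_{g_i} + \beta_{g_i} x_i = \left(\alpha_{g_i} + \beta_{g_i} \bar x_{g_i}\right) + \beta_{g_i}\left(x_i - \bar x_{g_i}\right),

so \beta_g \bar x_g acts like an extra group-level intercept unless the within-group mean is subtracted.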

Stan model with shrinkage priors and group-level de-meaning.
data {
    int n_subjects, n_groups;
    // Predictors z already have the global mean subtracted.
    vector [n_subjects] z, y;
    array [n_subjects] int<lower=1, upper=n_groups> idx;
    real<lower=0> sigma;
}

transformed data {
    vector [n_groups] z_group_mean = zeros_vector(n_groups);
    vector [n_groups] z_group_count = zeros_vector(n_groups);
    for (i in 1:n_subjects) {
        z_group_mean[idx[i]] += z[i];
        z_group_count[idx[i]] += 1;
    }
    z_group_mean ./= z_group_count;
}

parameters {
    real a, b;
    vector [n_groups] alpha, beta;
    real <lower=0> lmbd_alpha, lmbd_beta;
}

transformed parameters {
    vector [n_subjects] y_hat = a + b * z + alpha[idx] 
        + beta[idx] .* (z - z_group_mean[idx]);
}

model {
    lmbd_alpha ~ normal(0, 100);
    lmbd_beta ~ normal(0, 100);
    a ~ normal(0, 100);
    b ~ normal(0, 100);
    alpha ~ normal(0, lmbd_alpha);
    beta ~ normal(0, lmbd_beta);
    y ~ normal(y_hat, sigma);
}

The resulting posterior correlation matrix has a smaller condition number and exhibits more structure. We still have the a-\alpha and b-\beta competition, but there is now also structure in the \alpha-b block: when b increases, the group-level intercepts \alpha_g for groups whose mean \bar x_g is larger than the global mean \bar x must decrease, and vice versa.
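
To see why, average the model’s mean over the subjects in group g:

\frac{1}{n_g}\sum_{i:g_i=g}\left(a + b z_i + \alpha_g + \beta_g\left(z_i - \bar z_g\right)\right) = a + b\bar z_g + \alpha_g,

where n_g is the number of subjects in group g and \bar z_g = \bar x_g - \bar x. Keeping the fitted group means fixed under a change \delta b therefore requires \delta\alpha_g = -\bar z_g\,\delta b, which produces \alpha-b correlations whose signs depend on whether \bar x_g lies above or below the global mean.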

Model 3: Centered Parameterization

Combined with group-level de-meaning, we can also use a centered parameterization here because the data are very informative, i.e.,

\begin{aligned} y_i&\sim\mathsf{Normal}\left(\alpha_{g_i} + \beta_{g_i} \left(x_i-\bar x_{g_i}\right), \sigma^2\right)\\ \alpha_g&\sim\mathsf{Normal}\left(a, \lambda_\alpha^2\right)\\ \beta_g&\sim\mathsf{Normal}\left(b, \lambda_\beta^2\right)\\ a&\sim\mathsf{Normal}\left(0, 100^2\right)\\ b&\sim\mathsf{Normal}\left(0, 100^2\right) \end{aligned}

Stan model with centered parameterization.
data {
    int n_subjects, n_groups;
    // Predictors z already have the global mean subtracted.
    vector [n_subjects] z, y;
    array [n_subjects] int<lower=1, upper=n_groups> idx;
    real<lower=0> sigma;
}

transformed data {
    vector [n_groups] z_group_mean = zeros_vector(n_groups);
    vector [n_groups] z_group_count = zeros_vector(n_groups);
    for (i in 1:n_subjects) {
        z_group_mean[idx[i]] += z[i];
        z_group_count[idx[i]] += 1;
    }
    z_group_mean ./= z_group_count;
}

parameters {
    real a, b;
    vector [n_groups] alpha, beta;
    real <lower=0> lmbd_alpha, lmbd_beta;
}

transformed parameters {
    vector [n_subjects] y_hat = alpha[idx] + beta[idx] .* (z - z_group_mean[idx]);
}

model {
    lmbd_alpha ~ normal(0, 100);
    lmbd_beta ~ normal(0, 100);
    a ~ normal(0, 100);
    b ~ normal(0, 100);
    alpha ~ normal(a, lmbd_alpha);
    beta ~ normal(b, lmbd_beta);
    y ~ normal(y_hat, sigma);
}

This gives a nice condition number for the posterior correlation matrix (although there are some funnels that still need working out).

Thanks for making it this far. What’s a good approach for this problem with informative data but a non-nested hierarchy?

Your de-meaning is usually called “centering,” I believe. If you also divide by the standard deviation, it’s called “standardizing.” If you standardize covariates, it’s easier to formulate the scale for default priors. If you have vector-based covariates, you can also use a QR decomposition to get an orthonormal representation of the covariates. You can then reverse it later to report the results on the original scale. This is what rstanarm and, I believe, brms do in order to provide default priors and reduce poor conditioning.
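
As a NumPy sketch of the QR idea (an illustration of the general technique, not rstanarm’s actual implementation): regress on the orthonormal factor Q of the centered covariates, then map the coefficients back to the original scale with R.

import numpy as np

rng = np.random.default_rng(0)

# Correlated covariates with non-zero means.
X = rng.normal(size=(100, 3)) @ np.array([[1.0, 0.9, 0.0],
                                          [0.0, 1.0, 0.9],
                                          [0.0, 0.0, 1.0]]) + 5.0
X_centered = X - X.mean(axis=0)

# Thin QR: Q has orthonormal columns, so a regression on Q is well conditioned.
Q, R = np.linalg.qr(X_centered, mode="reduced")

# Simulate outcomes and fit theta = R beta on the orthonormal representation.
beta_true = np.array([0.5, -1.0, 2.0])
y = X_centered @ beta_true + rng.normal(scale=0.1, size=100)
theta_hat, *_ = np.linalg.lstsq(Q, y, rcond=None)

# Reverse the decomposition to report coefficients on the original scale.
beta_hat = np.linalg.solve(R, theta_hat)
print(beta_hat)  # close to beta_true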

I would suggest using a sum-to-zero constraint on the random effects \alpha and \beta. Without that, you get an extra source of non-identifiability. The new sum-to-zero implementation in the latest Stan is nice because it addresses the conditioning to try to make things isotropic.
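
Here is a NumPy sketch of the idea behind such a constraint (an illustration only, not Stan’s actual implementation): map K - 1 unconstrained parameters through an orthonormal basis of the sum-to-zero subspace, which enforces the constraint without distorting the geometry the sampler sees.

import numpy as np

def sum_to_zero_basis(k: int) -> np.ndarray:
    # Complete QR of the ones vector: the first column of Q is proportional
    # to ones, and the remaining k - 1 orthonormal columns span the subspace
    # {v : sum(v) = 0}.
    q, _ = np.linalg.qr(np.ones((k, 1)), mode="complete")
    return q[:, 1:]

basis = sum_to_zero_basis(3)
raw = np.array([0.7, -1.2])   # k - 1 unconstrained parameters
alpha = basis @ raw           # k group effects that sum to (numerically) zero
print(alpha, alpha.sum())

# Because the basis is orthonormal, the map is an isometry: independent
# standard-normal raw parameters stay isotropic after the transform, which
# is what keeps the conditioning well behaved.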