Proper use of sum_to_zero_vector in nested multilevel models

@spinkney If you add a normal_sum_to_zero(sigma), I’d consider defining it such that the marginal variance is \sigma^2 (1 - \tfrac{1}{N}). That is how pymc and numpyro define the ZeroSumNormal distribution, and having slightly different definitions floating around in the libraries could be annoying in the long term. I think that’s also usually what’s needed in applications of it (at least it is in pretty much all cases where I’m using it; I’m not sure how well that generalizes). It also feels quite natural if you look at the eigenvalues of the covariance matrix: one of those eigenvalues is zero, and all others are \sigma^2 in this parametrization.
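Spelled out, that definition corresponds (as far as I can tell) to the covariance matrix

\Sigma = \sigma^2 \left( I_N - \tfrac{1}{N} \mathbf{1}\mathbf{1}^\top \right),

which has marginal variances \sigma^2 (1 - \tfrac{1}{N}), one eigenvalue equal to zero (in the direction of \mathbf{1}), and N - 1 eigenvalues equal to \sigma^2.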

@lukaseamus
With a little bit of work, you can get exactly the same results with the zero-sum-normal parametrization as with the original model. To me, that feels a lot cleaner.

Unless I’m missing something, we can simplify your model, because the two groups don’t share any parameters: essentially, it fits two separate models, one for group1 and one for group2. If we only fit one of those models at a time, it becomes easier to see how this reparametrization works. The model then reduces to a GLM data ~ 1 + (1|unit) with log link and gamma likelihood:

data {
    int n_observations;
    int n_units;
    vector[n_observations] values;
    array[n_observations] int<lower=1, upper=n_units> unit_idx;
}
parameters {
    // Intercepts in log space
    real intercept;
    vector[n_units] unit_effect;

    // Inter-unit variability
    real<lower=0> unit_sigma;

    // Likelihood uncertainty
    real<lower=0> sigma;
}
transformed parameters {
    vector[n_observations] mu = exp(
        intercept
        + unit_effect[unit_idx]
    );
}
model {
    intercept ~ normal(log(1), 0.2);
    unit_sigma ~ normal(0, 1);
    unit_effect ~ normal(0, unit_sigma);
    sigma ~ normal(0, 1);

    // Gamma likelihood
    values ~ gamma(
        square(mu) ./ square(sigma),
        mu ./ square(sigma)
    );
}
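For reference, the shape and rate used in the likelihood are just the gamma distribution written in terms of its mean \mu and standard deviation \sigma:

\alpha = \frac{\mu^2}{\sigma^2}, \quad \beta = \frac{\mu}{\sigma^2} \quad\Longrightarrow\quad \mathbb{E}[y] = \frac{\alpha}{\beta} = \mu, \qquad \mathrm{Var}[y] = \frac{\alpha}{\beta^2} = \sigma^2.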

The problem that the zero-sum-normal reparametrization fixes also isn’t all that pronounced in your dataset, so I’ll change the number of observations a bit to make it easier to see (sorry for switching to Python, that’s just way easier for me…):

import polars as pl
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import arviz
import nutpie

def make_data(rng, n_units, n_obs_per_unit, *, intercept, unit_sigma, sigma):
    units = [f"unit_{i}" for i in range(n_units)]
    # Unit-level effects, one per unit
    unit_effects = pl.DataFrame({
        'unit': units,
        'unit_effect': rng.normal(0, unit_sigma, size=n_units),
    })

    # Observation-level structure
    observations = pl.DataFrame({
        "observation": [f"obs_{i}" for i in range(n_units * n_obs_per_unit)],
        'unit': np.repeat(units, n_obs_per_unit),
    })

    df = observations.join(unit_effects, on='unit')

    # Generate final values
    adjusted_mean = np.exp(intercept + df['unit_effect'])
    shape = adjusted_mean**2 / sigma**2
    rate = adjusted_mean / sigma**2

    df = df.with_columns([
        pl.Series('value', rng.gamma(shape, 1/rate))
    ])

    # Keep just what we need
    data = df.select(["observation", 'unit', 'value', "unit_effect"])
    return data.to_pandas()

true_values = {
    "intercept": 0.15,
    "unit_sigma": 0.2,
    "sigma": 0.1,
}

rng = np.random.default_rng(42)
data = make_data(rng, n_units=5, n_obs_per_unit=500, **true_values)

sns.catplot(
    data=data,
    kind="strip",
    x='unit',
    y='value',
)

If we sample this, the sampler (here nutpie, but it is similar with Stan) will struggle quite a bit and won’t be particularly efficient. It often uses very long trajectories (a depth of 8, for instance, means that it used about 2^8 = 256 gradient evaluations for a single draw).

unit_idx, units = pd.factorize(data["unit"])

coords = {
    "observation": data["observation"],
    "unit": units,
}

data_stan = {
    "n_observations": len(coords["observation"]),
    "n_units": len(coords["unit"]),
    "values": data["value"].values,
    "unit_idx": unit_idx + 1,
}

dims = {
    "unit_effect": ["unit"],
    "mu": ["observation"],
}

# Compile the Stan model from above with nutpie
# (assuming the Stan code is stored as a string in `model_code`)
compiled = nutpie.compile_stan_model(code=model_code)

trace = nutpie.sample(
    compiled
    .with_data(**data_stan)
    .with_coords(**coords)
    .with_dims(**dims),
    tune=1000,
    draws=1000,
    chains=10,
    seed=42,
)

sns.countplot(trace.sample_stats.depth.to_dataframe(), x="depth", hue="chain")

So why does the sampler struggle with such a simple model and a still pretty small dataset? It is not the usual centered vs. non-centered parametrization issue: if we were to switch to a non-centered parametrization, it would perform even worse.

The problem is that all parameters except sigma are still quite uncertain, but some linear combinations of the parameters are pinned down very precisely, which leads to high correlations in the posterior.
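You can see these correlations directly in the posterior draws. For example (using the trace from the run above; the exact number will vary, but it should come out strongly negative, close to -1):

import numpy as np

# Correlation between the intercept and the first unit effect across all draws
intercept_draws = trace.posterior.intercept.values.ravel()
unit0_draws = trace.posterior.unit_effect.sel(unit="unit_0").values.ravel()
print(np.corrcoef(intercept_draws, unit0_draws)[0, 1])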

Those linear combinations actually have quite nice interpretations here. The way I chose the parameters, we have very precise information about the expected value within each unit. But since we only ever observe 5 units, we never learn much about the distribution of the units.

For instance, we know quite well how big the expected outcome in the first unit (unit_0) is:

(trace.posterior.intercept + trace.posterior.unit_effect).sel(unit="unit_0").std()
# 0.0036

But at the same time, we are very uncertain about what the unit effect of the first unit itself is:

trace.posterior.unit_effect.sel(unit="unit_0").std()
# 0.13

This is because unit_effect measures how much larger the expected value of the first unit is relative to the population mean of the units! So if, for instance, each unit were a medical treatment, it would specify how much better the first treatment is compared to some (slightly hypothetical) infinite population of treatments. But since we only ever observed 5 units, most of the uncertainty in unit_effect is due to the uncertainty in the population mean:

trace.posterior.intercept.std()
# 0.13

We can make use of a quantity that’s closely related to the population mean of the units, though: the sample mean of the 5 units that we did observe. The basic idea of the zero-sum-normal parametrization is to parameterize our model in terms of differences to this sample mean, instead of differences to the population mean:

unit_effect_sample_mean = trace.posterior.unit_effect.mean("unit")
unit_effect_rel = trace.posterior.unit_effect - unit_effect_sample_mean

Instead of using unit_effect as a parameter, we would like to use unit_effect_rel instead: the difference between the mean within a unit and the sample mean of the 5 observed units. As a small side note: I think there are many cases where this difference is what we want to know in the first place, rather than the difference to the population mean. This is almost always the case if the population in question isn’t infinite to begin with. So if you have a model where the unit is a US state, I think we should really look at the difference to the other existing states, not the difference to some hypothetical (and arguably nonsensical) infinite population of states.

For some more details about how that works, see for instance here.

Here is an implementation of this reparametrization that also reconstructs the original values for unit_effect and intercept:

data {
    int n_observations;
    int n_units;
    vector[n_observations] values;
    array[n_observations] int<lower=1, upper=n_units> unit_idx;
}
parameters {
    // Intercepts in log space
    real intercept_plus_pop_mean;
    sum_to_zero_vector[n_units] unit_effect_rel;

    // Inter-unit variability
    real<lower=0> unit_sigma;

    // Likelihood uncertainty
    real<lower=0> sigma;
}
transformed parameters {
    vector[n_observations] mu = exp(
        intercept_plus_pop_mean
        + unit_effect_rel[unit_idx]
    );

    // Prior scale of the original intercept, and the implied prior scale of
    // intercept + (sample mean of the unit effects)
    real intercept_sigma = 0.2;
    real intercept_plus_pop_mean_sigma = sqrt(intercept_sigma ^ 2 + unit_sigma ^ 2 / n_units);
}
model {
    // Priors
    intercept_plus_pop_mean ~ normal(0, intercept_plus_pop_mean_sigma);
    unit_sigma ~ normal(0, 1);

    unit_effect_rel ~ normal(0, unit_sigma);
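    // normal_lpdf above adds n_units factors of 1 / unit_sigma, but the
    // zero-sum constraint leaves only n_units - 1 free dimensions, so we add
    // one log(unit_sigma) back to get the correct density on the constrained
    // space (and hence the same posterior as the original model):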
    target += log(unit_sigma);

    sigma ~ normal(0, 1);

    // Gamma likelihood
    values ~ gamma(
        square(mu) ./ square(sigma),
        mu ./ square(sigma)
    );
}
generated quantities {
    // Reconstruct the original intercept by drawing from its conditional
    // distribution given intercept_plus_pop_mean
    real intercept;
    {
        real var1 = intercept_sigma ^ 2;
        real var2 = unit_sigma ^ 2 / n_units;
        real total_var = var1 + var2;
        intercept = normal_rng(
            intercept_plus_pop_mean * var1 / total_var,
            sqrt(var1 * var2 / total_var)
        );
    }

    real unit_sample_mean = intercept_plus_pop_mean - intercept;
    vector[n_units] unit_effect = unit_effect_rel + unit_sample_mean;
}
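To see where the prior scale and the generated quantities come from: write \sigma_a = 0.2 for the prior scale of the original intercept and \sigma_u for unit_sigma. In the original model,

\text{intercept} \sim N(0, \sigma_a^2), \qquad \bar{u} = \tfrac{1}{N} \sum_i \text{unit\_effect}_i \sim N\!\left(0, \tfrac{\sigma_u^2}{N}\right),

and the two are independent, so their sum s = \text{intercept} + \bar{u} (which is intercept_plus_pop_mean) has prior N(0, \sigma_a^2 + \sigma_u^2 / N); that is exactly intercept_plus_pop_mean_sigma squared. Conditioning two independent zero-mean normals on their sum gives

\text{intercept} \mid s \sim N\!\left( s \, \frac{\sigma_a^2}{\sigma_a^2 + \sigma_u^2 / N}, \; \frac{\sigma_a^2 \, (\sigma_u^2 / N)}{\sigma_a^2 + \sigma_u^2 / N} \right),

where the second argument is the variance. That is exactly what the normal_rng call in generated quantities draws from (Stan’s normal_rng takes the standard deviation, hence the sqrt).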

This samples very well, with much smaller tree depth and pretty close to ideal effective sample size. It also gives the exact same posterior as the first model.
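If you want to check that numerically, you can compare summaries of the two fits, for instance like this (assuming the fit of the original model is stored as trace and the reparametrized fit as trace_zsn; those names are just placeholders):

import arviz

print(arviz.summary(trace, var_names=["intercept", "unit_sigma", "sigma", "unit_effect"]))
print(arviz.summary(trace_zsn, var_names=["intercept", "unit_sigma", "sigma", "unit_effect"]))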
