I have a data matrix with ~300 observations and ~300 features, and a substantial fraction of the entries are missing. The observations are spread across 3 cohorts, and each observation belongs to one of 2 groups.
Having been burned before by unnecessarily complicated models, I thought I would try something very basic: a hierarchical Gaussian model. It assumes that each feature of each observation is drawn from a Gaussian whose mean and scale are specific to that feature and to the observation's cohort × group combination. Each such mean is drawn from a Gaussian that depends on the group alone, and those group means are in turn drawn from a global population distribution. So there is variation across the groups, and then across the cohorts within each group.
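Written out, the model as coded below (non-centered at both hierarchical levels) is as follows. The symbols are shorthand for the Stan names: $\mu^{\text{pop}}$ = `population_mean`, $\sigma^{\text{pop}}$ = `population_scale`, $z$ = `group_mean_tilde`, $\sigma^{\text{grp}}$ = `group_scale`, $w$ = `cohort_group_mean_tilde`, $\sigma$ = `cohort_group_scale`, and $\mathcal{N}^+$ is the half-normal implied by the `<lower=0>` bounds. Note that $\sigma^{\text{pop}}$ and $\sigma^{\text{grp}}_g$ are shared across features, while the observation-level scale $\sigma_{c,g,p}$ is per feature.

$$
\begin{aligned}
\mu^{\text{pop}}_p &\sim \mathcal{N}(0, 1), \qquad \sigma^{\text{pop}} \sim \mathcal{N}^+(0, 1),\\
\mu^{\text{grp}}_{g,p} &= \mu^{\text{pop}}_p + \sigma^{\text{pop}}\, z_{g,p}, \qquad z_{g,p} \sim \mathcal{N}(0, 1), \qquad \sigma^{\text{grp}}_g \sim \mathcal{N}^+(0, 1),\\
\mu_{c,g,p} &= \mu^{\text{grp}}_{g,p} + \sigma^{\text{grp}}_g\, w_{c,g,p}, \qquad w_{c,g,p} \sim \mathcal{N}(0, 1), \qquad \sigma_{c,g,p} \sim \mathcal{N}^+(0, 1),\\
y_{n,p} &\sim \mathcal{N}\big(\mu_{c[n],\,g[n],\,p},\ \sigma_{c[n],\,g[n],\,p}\big) \quad \text{for every observed entry } (n, p).
\end{aligned}
$$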
Despite its apparent simplicity, the model completely fails to mix. Sampling is painfully slow, there are a significant number of divergences, and the traces are “step-like”: each chain stays close to constant, but each one is stuck in a different location.
The Stan code below applies the likelihood only to the observed elements of the data matrix, as indicated by the binary `iscorrupt` array (1 = missing, excluded from the likelihood).
```stan
/*
  Generative model with missing values.
  The generative model is an additive combination of
    - a global population mean component,
    - a group hierarchical mean,
    - a cohort x group hierarchical mean.
*/
data {
  // switch the prior and the likelihood off/on
  int<lower=0, upper=1> include_prior;
  int<lower=0, upper=1> include_likelihood;
  int<lower=0> num_observations;
  int<lower=0> num_features;
  int<lower=0> num_groups;
  int<lower=0> num_cohorts;
  // values at missing positions are never used by the likelihood
  array[num_observations] vector[num_features] observations;
  array[num_observations] int<lower=1, upper=num_groups> group;
  array[num_observations] int<lower=1, upper=num_cohorts> cohort;
  // 1 = missing ("corrupt") entry, 0 = observed entry
  array[num_observations, num_features] int<lower=0, upper=1> iscorrupt;
}
transformed data {
  array[num_observations] int num_corrupt_features;
  int total_corrupt = 0;
  // count the missing features of each observation
  for (n in 1:num_observations) {
    num_corrupt_features[n] = sum(iscorrupt[n]);
    total_corrupt += num_corrupt_features[n];
  }
}
parameters {
  vector[num_features] population_mean;
  real<lower=0> population_scale;
  // group level (non-centered)
  array[num_groups] vector[num_features] group_mean_tilde;
  array[num_groups] real<lower=0> group_scale;
  // cohort x group level (non-centered)
  array[num_cohorts, num_groups] vector[num_features] cohort_group_mean_tilde;
  array[num_cohorts, num_groups] vector<lower=0>[num_features] cohort_group_scale;
}
transformed parameters {
  array[num_groups] vector[num_features] group_mean;
  array[num_cohorts, num_groups] vector[num_features] cohort_group_mean;
  // per-observation copies of the parameters, one column per observation
  matrix[num_features, num_observations] group_mean_component;  // not used by the likelihood
  matrix[num_features, num_observations] cohort_group_mean_component;
  matrix[num_features, num_observations] cohort_group_scale_component;
  matrix[num_features, num_observations] mean_component;  // not used by the likelihood
  matrix[num_features, num_observations] completion;
  for (g in 1:num_groups) {
    group_mean[g] = population_mean + population_scale * group_mean_tilde[g];
    for (c in 1:num_cohorts) {
      cohort_group_mean[c, g] = group_mean[g] + group_scale[g] * cohort_group_mean_tilde[c, g];
    }
  }
  mean_component = rep_matrix(population_mean, num_observations);
  for (n in 1:num_observations) {
    group_mean_component[:, n] = group_mean[group[n]];
    cohort_group_mean_component[:, n] = cohort_group_mean[cohort[n], group[n]];
    cohort_group_scale_component[:, n] = cohort_group_scale[cohort[n], group[n]];
  }
  completion = cohort_group_mean_component;
}
model {
  if (include_prior) {
    population_mean ~ normal(0, 1);
    population_scale ~ normal(0, 1);  // half-normal through the lower=0 bound
    for (g in 1:num_groups) {
      group_mean_tilde[g] ~ normal(0, 1);
      group_scale[g] ~ normal(0, 1);
      for (c in 1:num_cohorts) {
        cohort_group_mean_tilde[c, g] ~ normal(0, 1);
        cohort_group_scale[c, g] ~ normal(0, 1);
      }
    }
  }
  if (include_likelihood) {
    for (n in 1:num_observations) {
      // construct arrays with the feature indices of the (un)corrupted features
      array[num_features - num_corrupt_features[n]] int observed;
      array[num_corrupt_features[n]] int corrupted;
      int pos_corrupt = 1;
      int pos_obs = 1;
      for (p in 1:num_features) {
        if (iscorrupt[n, p]) {
          corrupted[pos_corrupt] = p;
          pos_corrupt += 1;
        } else {
          observed[pos_obs] = p;
          pos_obs += 1;
        }
      }
      // likelihood for the observed entries of observation n only
      observations[n, observed] ~ normal(completion[observed, n],
                                         cohort_group_scale_component[observed, n]);
    }
  }
}
```
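For the “condition on a prior draw” check, one way to produce a complete synthetic data matrix is a `generated quantities` block appended to the program above. This is a sketch, not part of the original model: `observations_sim` is a hypothetical name, and the block reuses the `completion` and `cohort_group_scale_component` transformed parameters so the simulated data follow exactly the same cohort × group means and scales as the likelihood. Running with the likelihood switched off gives prior predictive draws; one of them, with the same `iscorrupt` mask, could then be passed back in as `observations`.

```stan
generated quantities {
  // Hypothetical addition: one complete synthetic data matrix per draw,
  // simulated from the same cohort x group means and scales as the likelihood.
  array[num_observations] vector[num_features] observations_sim;
  for (n in 1:num_observations) {
    // the vectorized normal_rng returns an array of reals, hence to_vector()
    observations_sim[n] = to_vector(normal_rng(completion[:, n],
                                               cohort_group_scale_component[:, n]));
  }
}
```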
I also tried removing one or both hierarchical levels, and switching to a centered parameterization (applied to both levels, since I am unsure which level is more relevant), but with no luck. Sampling from the prior alone works well enough, but conditioning on one of the prior draws does not help matters.
This suggests there may be a problem with the likelihood.
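If the likelihood is the suspect, one way to rule out an indexing problem is to express the same model with the observed entries in long format, one (row, column, value) triple per observed cell, so that no per-row masking is needed and the whole likelihood is a single vectorized statement. The sketch below assumes those triples are built outside Stan from the entries where `iscorrupt` is 0; the names `K`, `obs_n`, `obs_p` and `y` are hypothetical, and the `include_*` switches are dropped for brevity. It keeps the same priors and the same non-centered hierarchy, so it should not be expected to change the mixing behavior on its own. It only avoids rebuilding the index arrays on every log-density evaluation and storing the per-observation matrices as transformed parameters.

```stan
// Sketch (hypothetical long-format variant of the model above):
// the data arrive as K observed (row, column, value) triples.
data {
  int<lower=0> num_observations;
  int<lower=0> num_features;
  int<lower=0> num_groups;
  int<lower=0> num_cohorts;
  array[num_observations] int<lower=1, upper=num_groups> group;
  array[num_observations] int<lower=1, upper=num_cohorts> cohort;
  int<lower=0> K;                                       // number of observed entries
  array[K] int<lower=1, upper=num_observations> obs_n;  // row of each observed entry
  array[K] int<lower=1, upper=num_features> obs_p;      // column of each observed entry
  vector[K] y;                                          // observed values
}
parameters {
  vector[num_features] population_mean;
  real<lower=0> population_scale;
  array[num_groups] vector[num_features] group_mean_tilde;
  array[num_groups] real<lower=0> group_scale;
  array[num_cohorts, num_groups] vector[num_features] cohort_group_mean_tilde;
  array[num_cohorts, num_groups] vector<lower=0>[num_features] cohort_group_scale;
}
model {
  vector[K] mu;     // cohort x group mean of each observed entry
  vector[K] sigma;  // cohort x group scale of each observed entry
  population_mean ~ normal(0, 1);
  population_scale ~ normal(0, 1);
  for (g in 1:num_groups) {
    group_mean_tilde[g] ~ normal(0, 1);
    group_scale[g] ~ normal(0, 1);
    for (c in 1:num_cohorts) {
      cohort_group_mean_tilde[c, g] ~ normal(0, 1);
      cohort_group_scale[c, g] ~ normal(0, 1);
    }
  }
  for (k in 1:K) {
    int n = obs_n[k];
    int p = obs_p[k];
    int g = group[n];
    int c = cohort[n];
    // non-centered group mean, then cohort x group mean, for feature p
    real group_mean_p = population_mean[p] + population_scale * group_mean_tilde[g, p];
    mu[k] = group_mean_p + group_scale[g] * cohort_group_mean_tilde[c, g, p];
    sigma[k] = cohort_group_scale[c, g, p];
  }
  // one vectorized likelihood statement over all observed entries
  y ~ normal(mu, sigma);
}
```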