Sampling in multi-level model

I have a data matrix with ~300 observations and ~300 features, with a significant fraction of the entries missing. The observations are distributed across 3 cohorts, and each observation is marked as belonging to one of 2 groups.

Having previously cut myself on unnecessarily complicated models, I thought I would try something exceedingly basic: a hierarchical Gaussian model. The model assumes that each feature of each observation is drawn from a Gaussian with its own mean and scale for each cohort x group combination; each of those means is drawn from a Gaussian that depends on the group alone; and the group means are in turn drawn from a global population distribution. So there is variation across the groups, and then across the cohorts within each group.
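
In symbols, writing c(n) and g(n) for the cohort and group of observation n (these are scale parameters, not variances, matching the code below):

$$
\begin{aligned}
y_{n,p} &\sim \mathcal{N}\!\left(\mu_{c(n),g(n),p},\ \sigma_{c(n),g(n),p}\right) \\
\mu_{c,g,p} &\sim \mathcal{N}\!\left(m_{g,p},\ \tau_g\right) \\
m_{g,p} &\sim \mathcal{N}\!\left(M_p,\ \tau_0\right)
\end{aligned}
$$

Here $\tau_g$ is the per-group scale and $\tau_0$ the population scale; everything gets standard (half-)normal priors.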

The problem is that, despite its apparent simplicity, this model completely fails to mix. Sampling is painfully slow, produces a significant number of divergences, and yields "step-like" traces: each chain is close to constant, but stuck in a different location.

The Stan code below applies the likelihood only to the observed elements of the data matrix, i.e. those for which the binary iscorrupt array is zero.

/*
Generative model with missing values.
The mean is an additive combination of
    a global population mean component,
    a group-level hierarchical mean, and
    a cohort x group-level hierarchical mean.
*/

data {
    // switch off/on prior and likelihood
    int include_prior;
    int include_likelihood;
    
    int<lower=0> num_observations; 
    int<lower=0> num_features;
    int<lower=0> num_groups;
    int<lower=0> num_cohorts;


    vector[num_features] observations[num_observations];
    int<lower=1,upper=num_groups> group[num_observations];
    int<lower=1,upper=num_cohorts> cohort[num_observations];

    int iscorrupt[num_observations,num_features]; // 1 = corrupted/missing, 0 = observed
}

transformed data {
    int num_corrupt_features[num_observations];
    int total_corrupt = 0;

    // count corrupted features
    for (n in 1:num_observations) {
        num_corrupt_features[n] = 0;
        for (p in 1:num_features) {
            num_corrupt_features[n] += iscorrupt[n,p];
        }
        total_corrupt += num_corrupt_features[n];
    } 
}

parameters {
    vector[num_features] population_mean;
    real<lower=0> population_scale;
    
    // group level (non-centered)
    vector[num_features] group_mean_tilde[num_groups];
    real<lower=0> group_scale[num_groups];

    // cohort x group level (non-centered); per-feature likelihood scales
    vector[num_features] cohort_group_mean_tilde[num_cohorts, num_groups];
    vector<lower=0>[num_features] cohort_group_scale[num_cohorts, num_groups];
}

transformed parameters {
    vector[num_features] group_mean[num_groups];
    vector[num_features] cohort_group_mean[num_cohorts, num_groups];

    matrix[num_features, num_observations] group_mean_component;
    matrix[num_features, num_observations] cohort_group_mean_component;
    matrix[num_features, num_observations] cohort_group_scale_component;
    matrix[num_features, num_observations] mean_component;
    matrix[num_features, num_observations] completion;

    for (g in 1:num_groups) {
        group_mean[g] = population_mean + population_scale * group_mean_tilde[g];
        for (c in 1:num_cohorts) {
            cohort_group_mean[c,g] = group_mean[g] + group_scale[g] * cohort_group_mean_tilde[c,g];
        }
    }

    mean_component = rep_matrix(population_mean, num_observations);
    for (n in 1:num_observations) {
        // broadcast the per-(cohort x group) parameters to one column per observation
        group_mean_component[,n] = group_mean[group[n]];
        cohort_group_mean_component[,n] = cohort_group_mean[cohort[n], group[n]];
        cohort_group_scale_component[,n] = cohort_group_scale[cohort[n], group[n]];
    }

    // model mean for every entry of the data matrix (observed and missing alike)
    completion = cohort_group_mean_component;
}

model {
    if (include_prior) {        
        population_mean ~ normal(0,1);        
        population_scale ~ normal(0,1);
        for (g in 1:num_groups) {
            group_mean_tilde[g] ~ normal(0,1);
            group_scale[g] ~ normal(0,1);
            for (c in 1:num_cohorts) {
                cohort_group_mean_tilde[c,g] ~ normal(0,1);
                cohort_group_scale[c,g] ~ normal(0,1);
            }
        }
    }

    if (include_likelihood) {
        for (n in 1:num_observations) {
            // build index arrays of observed / corrupted features for this row
            // (note: these are rebuilt at every evaluation of the log density)
            int observed[num_features-num_corrupt_features[n]];
            int corrupted[num_corrupt_features[n]];
            int pos_corrupt = 1;
            int pos_obs = 1;
            for (p in 1:num_features) {
                if (iscorrupt[n,p]) {
                    corrupted[pos_corrupt] = p;
                    pos_corrupt += 1;
                } else {
                    observed[pos_obs] = p;
                    pos_obs += 1;
                }
            }
            observations[n, observed] ~ normal(completion[observed, n],
                                               cohort_group_scale_component[observed, n]);
        }
    } 
}

I also tried removing one or both hierarchical levels, as well as a centered parameterization (applied to both levels; I am unsure which level matters more), but with no luck. Sampling from the prior alone works well enough, but conditioning on data simulated from one of the prior draws does not seem to help matters.
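
For concreteness, the centered variant of the group level looks roughly like this (only the changed pieces are shown; the cohort x group level is analogous):

parameters {
    vector[num_features] group_mean[num_groups];  // declared directly, no group_mean_tilde
    // ... remaining parameters unchanged ...
}
model {
    for (g in 1:num_groups) {
        group_mean[g] ~ normal(population_mean, population_scale);  // centered draw
    }
    // ... rest of the model unchanged ...
}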

This seems to indicate that there might be a problem with the likelihood.

The likelihood was indeed the problem. Instead of evaluating the likelihood row by row, I vectorized the observations, the means, and the scales into single flat vectors, and precomputed the observed indices into the complete vector once, instead of rebuilding them per observation (and per log-density evaluation).
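
Schematically, the fix looks like the sketch below. It reuses the data and parameter names from above, while obs_matrix, obs_flat, observed_idx, and num_observed are illustrative names of mine; the exact bookkeeping in my code differs a little. The index array is built once in transformed data, and the masked likelihood collapses to a single vectorized statement:

transformed data {
    // flatten the data matrix column-major (one column per observation),
    // matching the ordering of to_vector() applied to the matrices built
    // in transformed parameters
    matrix[num_features, num_observations] obs_matrix;
    vector[num_features * num_observations] obs_flat;
    int observed_idx[num_features * num_observations]; // oversized buffer
    int num_observed = 0;

    for (n in 1:num_observations) {
        obs_matrix[,n] = observations[n];
        for (p in 1:num_features) {
            if (!iscorrupt[n,p]) {
                num_observed += 1;
                observed_idx[num_observed] = (n - 1) * num_features + p; // column-major index
            }
        }
    }
    obs_flat = to_vector(obs_matrix);
}

model {
    // ... priors as before ...
    if (include_likelihood) {
        obs_flat[observed_idx[1:num_observed]]
            ~ normal(to_vector(completion)[observed_idx[1:num_observed]],
                     to_vector(cohort_group_scale_component)[observed_idx[1:num_observed]]);
    }
}

The oversized buffer plus the 1:num_observed slice is only there to avoid sizing an array with a quantity computed later in the same block.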

Now it is a lot faster and actually produces credible samples. I understand why redundantly recomputing the observed array might hurt computation time, but can anyone explain why it would affect the sampling behavior? It seems like I am computing exactly the same log density, just with varying degrees of vectorization.