Building a hierarchical multivariate cumulative logit model

I’m having trouble generalizing my multivariate cumulative logit regression model to a hierarchical version. The structure of the model is atypical, so I’m not sure what I’m doing wrong. I managed to write a working non-hierarchical version of the model, which has no convergence issues with up to at least 20 predictor-moderator pairs. Here it is if you want to try it yourself:
test_v2.stan (2.6 KB)

The problem is that the hierarchical version is very unstable. As far as I can tell, it does work as intended, but I may be missing something. Here’s the code:

functions {
    real induced_dirichlet_lpdf(vector kappa, vector alpha, real phi) {
        int K = num_elements(kappa) + 1;
        vector[K - 1] sigma = inv_logit(phi - kappa);
        vector[K] p;
        matrix[K, K] J = rep_matrix(0, K, K);
        
        // Induced ordinal probabilities
        p[1] = 1 - sigma[1];
        for (k in 2:(K - 1))
            p[k] = sigma[k - 1] - sigma[k];
        p[K] = sigma[K - 1];
        
        // Baseline column of Jacobian
        for (k in 1:K) J[k, 1] = 1;
        
        // Diagonal entries of Jacobian
        for (k in 2:K) {
            real rho = sigma[k - 1] * (1 - sigma[k - 1]);
            J[k, k] = - rho;
            J[k - 1, k] = rho;
        }
        
        return   dirichlet_lpdf(p | alpha)
               + log_determinant(J);
    }
    real mo(vector scale, int i) {
        if (i == 1) {
            return 0;
        } else {
            return sum(scale[1:(i - 1)]);
        }
    }
}

data {
    int<lower=0> N;                 // Number of observations (patients)
    int<lower = 1> J;               // Number of levels (evaluators)
    int<lower = 0> K;               // Number of predictor-moderator pairs
    int<lower = 3> D_y;             // Number of ordinal categories of the outcome
    int<lower = 3> D_x;             // Number of ordinal categories of the predictors
    int<lower = 3> D_w;             // Number of ordinal categories of the moderators
    array[N] int<lower=1, upper=D_y> y;     // Observed ordinal outcome (risk estimates)
    array[N] int<lower=1, upper=J> j;       // Level (evaluator) index
    array[N, K] int<lower=1, upper=D_x> X;  // Predictor matrix (bounded by D_x, not a hard-coded 3)
    array[N, K] int<lower=1, upper=D_w> W;  // Moderator matrix (bounded by D_w, not a hard-coded 3)
}
parameters {
    array[J] ordered[D_y - 1] kappa;                // (Internal) cut points for the outcome per level
    array[K] real mu_beta;                          // Means of the predictors' latent effects
    array[K] real<lower=0> tau_beta;                // Scales of the predictors' latent effects
    array[J] vector[K] beta;                        // Latent effects of the predictors per level
    array[K] real mu_lambda;                        // Means of the moderators' latent effects
    array[K] real<lower=0> tau_lambda;              // Scales of the moderators' latent effects
    array[J] vector<lower=0>[K] lambda;             // Latent effects of the moderators per level
    array[K] vector<lower=0>[D_x - 1] alpha_delta;  // Prior sample sizes across categories of the predictors
    array[J,K] simplex[D_x - 1] delta;              // Normalized distances across categories of the predictors per level
    array[K] vector<lower=0>[D_w - 1] alpha_zeta;   // Prior sample sizes across categories of the moderators
    array[J,K] simplex[D_w - 1] zeta;               // Normalized distances across categories of the moderators per level
}

model {
    // Prior model
    for (i in 1:J) {
        kappa[i] ~ induced_dirichlet(rep_vector(1, D_y), 0);
    }
    for (k in 1:K) {
        mu_beta[k] ~ normal(0, 1);
        tau_beta[k] ~ normal(0, 1);
        beta[,k] ~ normal(mu_beta[k], tau_beta[k]);        
        mu_lambda[k] ~ normal(0, 1);
        tau_lambda[k] ~ normal(0, 1);
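        lambda[,k] ~ normal(mu_lambda[k], tau_lambda[k]);
        // lambda is declared <lower=0> and mu_lambda/tau_lambda are parameters,
        // so each of the J truncated normals needs its normalizing term
        target += -J * normal_lccdf(0 | mu_lambda[k], tau_lambda[k]);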
        for (d in 1:(D_x - 1)) {
            alpha_delta[k, d] ~ normal(0, 1);
        }
        delta[,k] ~ dirichlet(alpha_delta[k]);
        for (d in 1:(D_w - 1)) {
            alpha_zeta[k, d] ~ normal(0, 1);
        }
        zeta[,k] ~ dirichlet(alpha_zeta[k]);
    }
    
    // Observed model
    matrix[N, K] eta_x;
    matrix[N, K] eta_w;
    vector[N] phi;
    
    for(k in 1:K) {
        for (n in 1:N) {
            eta_x[n, k] = mo(delta[j[n], k], X[n, k]);
            eta_w[n, k] = mo(zeta[j[n], k], W[n, k]);
        }
    }
    
    for(n in 1:N) {
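        // Explicit parentheses: product of the predictor and moderator linear predictors,
        // so the result does not depend on how * and .* are ordered by precedence
        phi[n] = (eta_x[n] * beta[j[n]]) * (eta_w[n] * lambda[j[n]]);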
    }
    
    for(n in 1:N) {
        y[n] ~ ordered_logistic(phi[n], kappa[j[n]]);
    }
}
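For reference, the helper mo() maps an observed category i in {1, ..., D} and a simplex δ of D − 1 normalized distances to the cumulative distance up to that category, so the lowest category sits at 0 and the highest at 1:

$$
\operatorname{mo}(\delta, i) = \sum_{c=1}^{i-1} \delta_c,
\qquad \operatorname{mo}(\delta, 1) = 0,
\qquad \operatorname{mo}(\delta, D) = \sum_{c=1}^{D-1} \delta_c = 1.
$$

Each simplex delta[j, k] or zeta[j, k] therefore only sets how that unit range is split across adjacent categories, while the size of the effect comes from beta and lambda.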


Note: the induced_dirichlet() function is taken from this vignette by @betanalpha.
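As implemented above, the function maps the K − 1 ordered cut points κ to K ordinal probabilities through the logistic function σ, relative to an anchor φ (0 in the prior on kappa), and adds the log Jacobian of that map. Inside the function, K is the number of outcome categories (D_y in the model), not the number of predictor-moderator pairs:

$$
\begin{aligned}
p_1 &= 1 - \sigma(\phi - \kappa_1), \\
p_k &= \sigma(\phi - \kappa_{k-1}) - \sigma(\phi - \kappa_k), \quad 1 < k < K, \\
p_K &= \sigma(\phi - \kappa_{K-1}), \\
\log \pi(\kappa) &= \log \operatorname{Dirichlet}\big(p(\kappa) \mid \alpha\big) + \log \lvert \det J \rvert,
\end{aligned}
$$

where J has a first column of ones and, for k = 2, ..., K, column k holds ρ_{k−1} in row k − 1 and −ρ_{k−1} in row k, with ρ_{k−1} = σ(φ − κ_{k−1})(1 − σ(φ − κ_{k−1})). With α = 1, as in the prior on every kappa[i], this prior is uniform on the induced probabilities rather than on the cut points themselves.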


I suspect the problem lies in the hierarchical Dirichlet priors on the arrays of simplexes delta and zeta, whose concentration vectors (alpha_delta and alpha_zeta) get independent half-normal hyperpriors. I’m currently trying to implement a reparametrization as described in this section of the Stan user manual, but I haven’t figured it out yet, so help with that would be great.
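One common alternative is sketched below under stated assumptions. It may or may not match the reparametrization in the linked manual section. Each predictor gets a population mean simplex and a single total concentration, and the Dirichlet concentration vector is their product. This replaces D_x − 1 free half-normal concentrations per predictor, which allow values near 0 where Dirichlet draws pile up at the corners of the simplex, with one concentration whose prior puts little mass near 0. The names delta_mean and delta_conc and both hyperpriors are hypothetical choices, not part of the original model. Only the hierarchy for delta is shown; zeta would be handled the same way.

```stan
// Hypothetical sketch: mean/concentration hierarchy for the predictor
// distances delta. delta_mean, delta_conc and their priors are assumptions,
// not part of the original model; zeta would be treated the same way.
data {
  int<lower=1> J;    // Number of levels (evaluators)
  int<lower=0> K;    // Number of predictor-moderator pairs
  int<lower=3> D_x;  // Number of ordinal categories of the predictors
}
parameters {
  array[K] simplex[D_x - 1] delta_mean;  // Population-level mean distances
  array[K] real<lower=0> delta_conc;     // Total concentration per predictor
  array[J, K] simplex[D_x - 1] delta;    // Per-evaluator distances
}
model {
  for (k in 1:K) {
    // Mildly regularizing prior on the shared mean simplex (assumed)
    delta_mean[k] ~ dirichlet(rep_vector(2, D_x - 1));
    // Gamma(shape 2, rate 0.1): zero density at 0 and little mass near it (assumed)
    delta_conc[k] ~ gamma(2, 0.1);
    for (i in 1:J) {
      // Concentration vector = total concentration * mean simplex
      delta[i, k] ~ dirichlet(delta_conc[k] * delta_mean[k]);
    }
  }
}
```

In the full model, the per-evaluator delta[j[n], k] would still feed mo() exactly as before. Only the prior on the simplexes changes.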

Of course, the problem could also come from elsewhere in the code, so any suggestions for statistical or computational improvements would also be greatly appreciated.
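On the computational side, the centered hierarchies in the model (for example beta[,k] ~ normal(mu_beta[k], tau_beta[k])) are prone to funnel-shaped posteriors when there are few evaluators or little data per evaluator. A non-centered parameterization is the standard remedy. Below is a minimal sketch of just the hierarchical part, using hypothetical names beta_raw, mu_log_lambda, tau_log_lambda and log_lambda_raw. Putting lambda on the log scale is a swapped-in modeling choice (a lognormal hierarchy in place of the positive-truncated normal), not an exact reparameterization of the original.

```stan
// Hypothetical sketch of non-centered hierarchies for beta and lambda.
// The log-scale hierarchy for lambda is an assumption that replaces the
// original truncated normal; all *_raw and *_log_lambda names are illustrative.
data {
  int<lower=1> J;  // Number of levels (evaluators)
  int<lower=0> K;  // Number of predictor-moderator pairs
}
parameters {
  vector[K] mu_beta;                  // Means of the predictors' effects
  vector<lower=0>[K] tau_beta;        // Scales of the predictors' effects
  array[J] vector[K] beta_raw;        // Standardized per-evaluator effects
  vector[K] mu_log_lambda;            // Means of the log moderator effects
  vector<lower=0>[K] tau_log_lambda;  // Scales of the log moderator effects
  array[J] vector[K] log_lambda_raw;  // Standardized per-evaluator log effects
}
transformed parameters {
  array[J] vector[K] beta;
  array[J] vector<lower=0>[K] lambda;
  for (i in 1:J) {
    // Implies beta[i] ~ normal(mu_beta, tau_beta) elementwise
    beta[i] = mu_beta + tau_beta .* beta_raw[i];
    // Implies log(lambda[i]) ~ normal(mu_log_lambda, tau_log_lambda) elementwise
    lambda[i] = exp(mu_log_lambda + tau_log_lambda .* log_lambda_raw[i]);
  }
}
model {
  mu_beta ~ normal(0, 1);
  tau_beta ~ normal(0, 1);  // Half-normal because of the lower bound
  mu_log_lambda ~ normal(0, 1);
  tau_log_lambda ~ normal(0, 1);
  for (i in 1:J) {
    beta_raw[i] ~ std_normal();
    log_lambda_raw[i] ~ std_normal();
  }
}
```

The likelihood would keep using beta[j[n]] and lambda[j[n]] as before. Separately, because beta and lambda enter phi only through products, multiplying beta by some c > 0 and lambda by 1/c leaves phi unchanged. Only the priors separate their scales, which may also contribute to the instability.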