I’m having trouble generalizing my multivariate cumulative logit regression model. The basic structure of the model is atypical, so I’m not sure what I’m doing wrong. I managed to write a working non-hierarchical version of the model, which has no convergence issues for (at least) up to 20 predictor-moderator pairs. Here it is if you want to try it yourself:
test_v2.stan (2.6 KB)
My problem is that the hierarchical version (see below) is really unstable. From what I can tell, the hierarchical model does work as intended, but I may be missing something. Here’s the code:
functions {
  real induced_dirichlet_lpdf(vector kappa, vector alpha, real phi) {
    int K = num_elements(kappa) + 1;
    vector[K - 1] sigma = inv_logit(phi - kappa);
    vector[K] p;
    matrix[K, K] J = rep_matrix(0, K, K);
    // Induced ordinal probabilities
    p[1] = 1 - sigma[1];
    for (k in 2:(K - 1))
      p[k] = sigma[k - 1] - sigma[k];
    p[K] = sigma[K - 1];
    // Baseline column of Jacobian
    for (k in 1:K) J[k, 1] = 1;
    // Diagonal entries of Jacobian
    for (k in 2:K) {
      real rho = sigma[k - 1] * (1 - sigma[k - 1]);
      J[k, k] = - rho;
      J[k - 1, k] = rho;
    }
    return dirichlet_lpdf(p | alpha)
           + log_determinant(J);
  }
  // Cumulative sum of the first (i - 1) normalized distances (0 for category 1)
  real mo(vector scale, int i) {
    if (i == 1) {
      return 0;
    } else {
      return sum(scale[1:(i - 1)]);
    }
  }
}
data {
  int<lower=0> N; // Number of observations (patients)
  int<lower=1> J; // Number of levels (evaluators)
  int<lower=0> K; // Number of predictor-moderator pairs
  int<lower=3> D_y; // Number of ordinal categories of the outcome
  int<lower=3> D_x; // Number of ordinal categories of the predictors
  int<lower=3> D_w; // Number of ordinal categories of the moderators
  array[N] int<lower=1, upper=D_y> y; // Observed ordinal outcome (risk estimates)
  array[N] int<lower=1, upper=J> j; // Level (evaluator) index
  array[N, K] int<lower=1, upper=D_x> X; // Predictor matrix
  array[N, K] int<lower=1, upper=D_w> W; // Moderator matrix
}
parameters {
  array[J] ordered[D_y - 1] kappa; // (Internal) cut points for the outcome per level
  array[K] real mu_beta; // Means of the predictors' latent effects
  array[K] real<lower=0> tau_beta; // Scales of the predictors' latent effects
  array[J] vector[K] beta; // Latent effects of the predictors per level
  array[K] real mu_lambda; // Means of the moderators' latent effects
  array[K] real<lower=0> tau_lambda; // Scales of the moderators' latent effects
  array[J] vector<lower=0>[K] lambda; // Latent effects of the moderators per level
  array[K] vector<lower=0>[D_x - 1] alpha_delta; // Prior sample sizes across categories of the predictors
  array[J, K] simplex[D_x - 1] delta; // Normalized distances across categories of the predictors per level
  array[K] vector<lower=0>[D_w - 1] alpha_zeta; // Prior sample sizes across categories of the moderators
  array[J, K] simplex[D_w - 1] zeta; // Normalized distances across categories of the moderators per level
}
model {
  // Prior model
  for (i in 1:J) {
    kappa[i] ~ induced_dirichlet(rep_vector(1, D_y), 0);
  }
  for (k in 1:K) {
    mu_beta[k] ~ normal(0, 1);
    tau_beta[k] ~ normal(0, 1);
    beta[,k] ~ normal(mu_beta[k], tau_beta[k]);
    mu_lambda[k] ~ normal(0, 1);
    tau_lambda[k] ~ normal(0, 1);
    lambda[,k] ~ normal(mu_lambda[k], tau_lambda[k]);
    for (d in 1:(D_x - 1)) {
      alpha_delta[k, d] ~ normal(0, 1);
    }
    delta[,k] ~ dirichlet(alpha_delta[k]);
    for (d in 1:(D_w - 1)) {
      alpha_zeta[k, d] ~ normal(0, 1);
    }
    zeta[,k] ~ dirichlet(alpha_zeta[k]);
  }
  // Observed model
  matrix[N, K] eta_x;
  matrix[N, K] eta_w;
  vector[N] phi;
  for (k in 1:K) {
    for (n in 1:N) {
      eta_x[n, k] = mo(delta[j[n], k], X[n, k]);
      eta_w[n, k] = mo(zeta[j[n], k], W[n, k]);
    }
  }
  for (n in 1:N) {
    phi[n] = eta_x[n] * beta[j[n]] .* eta_w[n] * lambda[j[n]];
  }
  for (n in 1:N) {
    y[n] ~ ordered_logistic(phi[n], kappa[j[n]]);
  }
}
Note. The induced_dirichlet() function is taken from this vignette by @betanalpha.
I think the problem lies in the independent Dirichlet hyperpriors declared on the arrays of simplexes delta and zeta. I’m currently trying to implement a reparametrization as described in this section of the Stan user manual, but I haven’t quite figured it out yet, so help on that would be great.
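To make the idea concrete, here is a minimal sketch of what I understand that reparametrization to be, assuming the relevant section is the one on Dirichlet priors: each concentration vector is written as a positive scale times a mean simplex. It only covers delta (zeta would be analogous) and leaves out the rest of the model. The names theta_delta and conc_delta and both priors are placeholders I made up for illustration, not something taken from the manual or from my working model.

```stan
data {
  int<lower=1> J;   // Number of levels (evaluators)
  int<lower=0> K;   // Number of predictor-moderator pairs
  int<lower=3> D_x; // Number of ordinal categories of the predictors
}
parameters {
  // Hypothetical names: population mean simplex and concentration per predictor
  array[K] simplex[D_x - 1] theta_delta;
  array[K] real<lower=0> conc_delta;
  // Normalized distances across categories of the predictors per level
  array[J, K] simplex[D_x - 1] delta;
}
model {
  for (k in 1:K) {
    // Placeholder hyperpriors (assumptions, to be tuned)
    theta_delta[k] ~ dirichlet(rep_vector(1, D_x - 1));
    conc_delta[k] ~ exponential(0.1);
    for (i in 1:J) {
      // Concentration vector = scale * mean simplex, replacing alpha_delta[k]
      delta[i, k] ~ dirichlet(conc_delta[k] * theta_delta[k]);
    }
  }
}
```

The appeal, as far as I can tell, is that the mean of the level-specific simplexes and how tightly they cluster around it become separate parameters, instead of both being tangled up in the D_x - 1 independent entries of alpha_delta[k].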
Of course, the problem could also come from somewhere else in the code so any suggestions regarding statistical or computational optimization would also be greatly appreciated.