Convergence problems (low E-BFMI, high Rhat and divergent transitions) for ordinal regression using a pR2D2M2ord prior

Hi,

I am running into convergence problems while fitting a model I have been working on for some time. I am getting the following diagnostics (with the data for context):

# Make data V1
data <- list(
  N = nrow(df), # 232
  J = 36,
  K = 20,
  D_y = max(df$crs), # 3
  D_X = max(df[1:20]), # 3
  D_W = max(df[21:40]), # 4
  D_c = max(df$contexte_lowest), # 3
  y = df$crs,
  jj = df$eid %>% factor() %>% as.integer(),
  c = df$contexte_lowest,
  X = t(df[1:20]),
  W = t(df[21:40]),
  hyper = readRDS("data/hyper_config1.RDS"),
  alpha0_mu_p = 1,
  alpha0_phi = 1,
  alpha0_delta = 1
)

# Compile the model
mod <- cmdstanr::cmdstan_model("analyses/stan/v1.stan")

# Run the model
fit <- mod$sample(
  data = data,
  seed = 1234,
  chains = 4,
  parallel_chains = 4,
  iter_warmup = 1000,
  iter_sampling = 2000,
  refresh = 100,
  adapt_delta = 0.99
)

# Diagnostics
fit$diagnostic_summary()
# Warning: 14 of 8000 (0.0%) transitions ended with a divergence.
# See https://mc-stan.org/misc/warnings for details.
# 
# Warning: 23 of 8000 (0.0%) transitions hit the maximum treedepth limit of 10.
# See https://mc-stan.org/misc/warnings for details.
# 
# Warning: 4 of 4 chains had an E-BFMI less than 0.3.
# See https://mc-stan.org/misc/warnings for details.
# $num_divergent
# [1] 7 2 3 2
# 
# $num_max_treedepth
# [1]  0  7  3 13

# $ebfmi
# [1] 0.02310780 0.00984802 0.02017209 0.01262502
fit_sum <- fit$summary(variables = c("mu_p", "tau_p", "kappa", "beta", "lambda", "delta_X", "delta_W", "tau", "xi", "phi"))
fit_sum[fit_sum$rhat > 1.01,] |>
  arrange(desc(rhat))
# A tibble: 1,095 × 10
#    variable     mean   median     sd     mad       q5    q95  rhat ess_bulk ess_tail
#     <chr>       <dbl>    <dbl>  <dbl>   <dbl>    <dbl>  <dbl> <dbl>    <dbl>    <dbl>
#   1 tau       2.92e+0  2.68e+0 1.36   1.14     1.20e+0 5.50    1.14     19.9     49.3
#   2 phi[68]   1.21e-2  8.57e-3 0.0117 0.00901  1.40e-4 0.0351  1.08     34.8     21.3
#   3 lambda[… -3.21e-4  1.68e-4 0.184  0.129   -2.96e-1 0.291   1.06   9995.    1977. 
#   4 phi[54]   1.14e-2  7.88e-3 0.0114 0.00773  6.74e-4 0.0340  1.06    182.     184. 
#   5 lambda[…  6.83e-3  2.03e-3 0.186  0.130   -2.87e-1 0.310   1.05   9742.    1756. 
#   6 phi[64]   1.17e-2  7.83e-3 0.0123 0.00799  7.05e-4 0.0372  1.05    119.     198. 
#   7 lambda[… -5.67e-3 -1.43e-3 0.182  0.129   -3.06e-1 0.288   1.05   9224.    2634. 
#   8 lambda[… -1.11e-2 -2.92e-3 0.191  0.129   -3.20e-1 0.278   1.05   9453.    1882. 
#   9 lambda[…  7.79e-3  2.22e-3 0.184  0.126   -2.84e-1 0.316   1.05  10809.    2079. 
#  10 lambda[… -8.16e-3 -1.79e-3 0.181  0.124   -3.10e-1 0.280   1.05   9429.    2284. 
# ℹ 1,085 more rows
# ℹ Use `print(n = ...)` to see more rows


The \hat{R} values seem to indicate that the \tau parameter is at the center of the problem. This makes sense, since many parameters depend on it as the global variance of the linear predictor. Am I correct to interpret the low E-BFMI values as an indicator that there is not enough data to identify the multi-level parameters? Many levels only have one observation.

Here’s the Stan code for the model:

functions {
  
  // Induced distribution on cutpoints
  real induced_dirichlet_lpdf(vector kappa, real tau, vector alpha) {
    int K = num_elements(kappa) + 1;
    vector[K - 1] Pi = Phi(kappa / sqrt(1 + tau));
    
    // Induced ordinal probabilities
    vector[K] p = append_row(Pi, [1]') - append_row([0]', Pi);
    
    // Baseline column of Jacobian
    matrix[K, K] J = rep_matrix(0, K, K);
    for (k in 1:K) J[k, 1] = 1;
    
    // Diagonal entries of Jacobian
    for (k in 2:K) {
      real rho = exp(normal_lpdf(kappa[k-1] | 0, sqrt(1 + tau)));
      J[k, k] = - rho;
      J[k - 1, k] = rho;
    }
    
    return dirichlet_lpdf(p | alpha) + log_determinant(J);
  }
  
  // Inverse Gaussian distribution for global variance tau
  real inv_gauss_lpdf(real x, real mu, real lambda) {
    real lpdf = 0.5 * log(lambda) - 0.5 * log(2 * pi()) - 1.5 * log(x) - lambda * (x - mu)^2 / (2 * mu^2 * x);
    return lpdf;
  }
  
  // Monotonic transform
  row_vector mo(array[] int x, vector scale) {
    int N = num_elements(x);
    int D = num_elements(scale) + 1;
    vector[D] cumsum_scale = append_row([0]', cumulative_sum(scale));
    row_vector[N] eta;
    
    for (n in 1:N) {
      eta[n] = cumsum_scale[x[n]];
    }
    return eta;
  }
}
data {
    int<lower = 0> N;    // Number of observations (patients)
    int<lower = 1> J;    // Number of levels (evaluators)
    int<lower = 2> K;    // Number of risk factors
    int<lower = 2> D_y;  // Number of categories of the outcome
    int<lower = 2> D_X;  // Number of categories of the risk factors
    int<lower = 2> D_W;  // Number of categories of the moderator
    int<lower = 2> D_c;  // Number of contexts
    array[N] int<lower = 1, upper = D_y> y;     // Observed ordinal outcome (risk estimates)
    array[N] int<lower = 1, upper = J> jj;      // Level index (evaluator)
    array[N] int<lower = 1, upper = D_c> c;     // Context 
    array[K, N] int<lower = 1, upper = D_X> X;  // Risk factor arrays
    array[K, N] int<lower = 1, upper = D_W> W;  // Moderator arrays

    // Hyperparameters
    vector[3] hyper;                           
    real<lower = 0> alpha0_mu_p; 
    real<lower = 0> alpha0_phi;
    real<lower = 0> alpha0_delta;
}
transformed data {
  // Number of regression coefficients
  // 1 for context
  // K for risk factors
  // K for interaction between risk factors and moderators
  int D = 1 + K * 2;
  
}
parameters {
  simplex[D_y] mu_p;    // Population simplex baseline
  real<lower=0> tau_p;  // Population simplex scale
  array[J] ordered[D_y - 1] kappa;  // Individual cut points
  vector[D] beta;             // Regression coefficients (global)
  array[J] vector[D] lambda;  // Regression coefficients (variable)
  simplex[D_c - 1] delta_c;   // Variance allocation parameters for contexts
  array[K] simplex[D_X - 1] delta_X;  // Variance allocation parameters for risk factors' categories 
  array[K] simplex[D_W - 1] delta_W;  // Variance allocation parameters for moderators' categories
  real<lower=0> tau;   // Global variance
  real<lower=0> xi;    // Latent term for global variance
  simplex[D * 2] phi;  // Variance allocation parameters for regression coefficients
}
transformed parameters {
  vector[D_y] alpha = mu_p / tau_p + rep_vector(1, D_y);
  
  // Fill the transposed design matrix row-wise, then transpose once
  matrix[D, N] t_eta;
  t_eta[1] = mo(c, delta_c);
  for (k in 1:K) {
    row_vector[N] mo_X = mo(X[k], delta_X[k]);
    t_eta[1 + k] = mo_X;
    t_eta[1 + K + k] = mo_X .* mo(W[k], delta_W[k]);
  }
  matrix[N, D] eta = t_eta';
}
model {
  // Prior model
  // Prior model on cutpoints (kappa)
  mu_p ~ dirichlet(rep_vector(alpha0_mu_p, D_y));
  tau_p ~ normal(0, 1);
  for (j in 1:J) 
    kappa[j] ~ induced_dirichlet(tau, alpha);
  
  // Prior model on regression coefficients (beta, lambda)
  for (d in 1:D) {
    beta[d] ~ normal(0, sqrt(phi[d] * tau));
    for (j in 1:J) {
      lambda[j, d] ~ normal(0, sqrt(phi[D + d] * tau));
    }
  }
  phi ~ dirichlet(rep_vector(alpha0_phi, D * 2));
  if (hyper[1] < -0.5) {
    tau ~ inv_gauss(sqrt(hyper[3] / (hyper[2] + 2 * xi)), hyper[3]);
    xi ~ gamma(-(hyper[1] + 0.5), tau);
  } else {
    tau ~ inv_gauss(sqrt((hyper[3] + 2 * xi) / (hyper[2])), hyper[3] + 2 * xi);
    xi ~ gamma(hyper[1] + 0.5, 1 / tau);
  }
  
  // Prior model on variance allocation parameters (delta_c, delta_X, delta_W)
  delta_c ~ dirichlet(rep_vector(alpha0_delta, D_c - 1));
  for (k in 1:K) {
    delta_X[k] ~ dirichlet(rep_vector(alpha0_delta, D_X - 1));
    delta_W[k] ~ dirichlet(rep_vector(alpha0_delta, D_W - 1));
  }
  
  // Observational model
  for (n in 1:N) {
    y[n] ~ ordered_probit(dot_product(beta, eta[n]) +
                            dot_product(lambda[jj[n]], eta[n]),
                          kappa[jj[n]]);
  }
}

generated quantities {
  vector[N] y_tilde;                      // Latent observations
  array[N] int<lower=1, upper=D_y> y_pred;  // Posterior predictive samples
  for (n in 1:N) {
    y_tilde[n] = dot_product(beta, eta[n]) +
                   dot_product(lambda[jj[n]], eta[n]);
    y_pred[n] = ordered_probit_rng(y_tilde[n], kappa[jj[n]]);
  }
}


I am aware that it is a very complex model. My goal is to start complex and progressively simplify until the diagnostics are satisfactory. Here’s a short bullet list of how the model works:

  • It’s an ordered-probit model combining the R2D2M2 prior (described in this paper by @javier.aguilar and @paul.buerkner) with the pR2D2ord prior (described in this article by @eyanchenko).
  • Neither of these priors accounts for multi-level cutpoints, so I used the hierarchical cut points described by @betanalpha in this chapter. Note that I had to slightly tweak this implementation to fit the pR2D2ord prior by using the probit link and adding the scaling parameter \tau to the induced Dirichlet prior.
  • The model also applies the monotonic transform to all predictors (described in this article by @paul.buerkner), since they are all ordinal.

What I tried

  • I tried increasing the adapt_delta parameter from 0.8 to 0.99, but the divergent transitions remain and the E-BFMI values stay the same.
  • I tried removing the monotonic transforms and standardizing the predictors, but the diagnostics were not better: still some divergent transitions (9/8000 total), 20% of transitions hitting the maximum treedepth limit, and similar E-BFMI values. Also, \hat{R} for \tau was still the highest, at 1.18.

This model takes a long time to sample so I’m limited in how much I can experiment with things like weakly informative priors. If I were to go down that path, what parameters should I target?

Any advice or ideas would be greatly appreciated!

I have found that using a non-centered parameterization on the global and variable regression coefficients (beta and lambda respectively) drastically improves the performance of the model, cutting sampling time to less than half of what it was with the centered parameterization. Moreover, the E-BFMI, rhat, and ess_bulk diagnostics are now in the acceptable range for all models, i.e. with or without monotonic transforms on the predictors. However, there are still some divergent transitions remaining.
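For anyone reading along, the non-centered version looks roughly like this (a sketch using the variable names from the model above; the rest of the model is unchanged):

```stan
parameters {
  // ... other parameters as before ...
  vector[D] beta_raw;             // standard-normal raw global coefficients
  array[J] vector[D] lambda_raw;  // standard-normal raw varying coefficients
}
transformed parameters {
  // Scale the raw draws instead of sampling beta and lambda directly,
  // so the sampler sees a geometry that does not depend on tau and phi
  vector[D] beta;
  array[J] vector[D] lambda;
  for (d in 1:D) {
    beta[d] = beta_raw[d] * sqrt(phi[d] * tau);
    for (j in 1:J)
      lambda[j][d] = lambda_raw[j][d] * sqrt(phi[D + d] * tau);
  }
}
model {
  // Replaces the explicit normal priors on beta and lambda
  beta_raw ~ std_normal();
  for (j in 1:J)
    lambda_raw[j] ~ std_normal();
}
```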

Divergent transitions seem to arise specifically for really small values of tau_p (like tau_p < 0.01). This makes sense, since low values of this scaling parameter tightly constrain the prior density of the cutpoints kappa[j]. I could constrain tau_p to values above a specific limit (e.g. by declaring real<lower=0.01> tau_p;), but that would bias inference. Am I missing something, or is this a normal consequence of the Dirichlet population model?


ping @eyanchenko

Hi Xavier,

Sorry for the slow response as I’ve been traveling. Glad to see you’re trying out the pR2D2 prior.

I did sometimes also get a lot of divergent transitions with the pR2D2 model. However, I usually had good R hat and ESS values. How are those for you now?

You said

In general, I highly recommend the opposite: start with a really simple model, get it to work properly, and then add complexity one at a time. In this case, that could include:

  • using (simulated) data sets with a lot of observations for each level
  • a simpler prior for the cut-points like an ordered normal
  • a simpler prior for tau like tau ~ BP(a,b).

(Yes, I know this won’t induce the correct R2 but I have found that posterior sampling always gets trickier when auxiliary variables are introduced like xi. So getting rid of this first could help identify the problem).
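A rough sketch of what those last two simplifications could look like in Stan (untested, and a and b are placeholder hyperparameters; beta_prime is a user-defined density since Stan has no built-in beta prime):

```stan
functions {
  // Beta prime (inverted beta) density, i.e. tau ~ BP(a, b)
  real beta_prime_lpdf(real x, real a, real b) {
    return (a - 1) * log(x) - (a + b) * log1p(x) - lbeta(a, b);
  }
}
model {
  // Simple independent priors on the ordered cut-points;
  // the ordered type already enforces monotonicity
  for (j in 1:J)
    kappa[j] ~ normal(0, 5);
  tau ~ beta_prime(a, b);
}
```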

Keep me posted on the progress!

Eric

Hi Eric,

No worries at all. To my surprise, I’ve managed to make good progress on my own!

I think I finally found a (mathematically satisfying) solution to make this model sample smoothly. Yesterday, I tried applying a lognormal prior to tau_p and it basically eliminated all remaining divergences and improved all diagnostics even further (some Rhat and ESS values were still close to the unacceptable range with the previous prior, which is no longer the case).

With the truncated normal(0,1) prior on tau_p, tau_p consistently had the worst Rhat and ESS values: Rhat was at 1.05 and ESS_bulk was less than 100 (running 4 chains with 2000 sampling iterations and 1000 warmup each). With the lognormal(0,1) prior, Rhat is at 1.01 and ESS_bulk is above 400.

I say this is a “mathematically satisfying” solution because…

  1. it doesn’t involve arbitrarily truncating the prior, but instead offers a more “continuous” way to keep the sampler away from extremely small values of tau_p, making the lognormal a reasonable boundary-avoiding prior for this kind of problem.
  2. it centers prior probability around 1, which is a reasonable baseline for a precision/scaling parameter.
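Concretely, the change is a single line in the model block of the code posted above:

```stan
// replaces tau_p ~ normal(0, 1);
// (which was implicitly truncated at zero by the <lower=0> constraint)
tau_p ~ lognormal(0, 1);
```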

Let me know if you have any counter-arguments. The biggest one I can think of is that a lognormal(0,1) might be too informative and thus bias inference. I’ve run some simple prior simulations and the prior looked decently flexible to me. What I haven’t done yet is a rigorous sensitivity analysis of how changing the scale of the lognormal prior on tau_p impacts inference. Not sure I’ll have time to get to that either, to be honest. For now, good convergence diagnostics will have to suffice – at least until someone who knows better comes along and explains why this would be a terrible prior choice.