Hierarchical MPT, divergent transitions

Hi!

This is a follow-up to an earlier discussion on Google Groups about divergent transitions in hierarchical MPT models in rstan:

https://groups.google.com/g/stan-users/c/ZTiu_ij_bC8

It concerns the divergent transitions that this type of model produces in rstan.

Please note: the model I use is for a different data set and has different MPT parameters, but it follows the structure of the case study.

The solution suggested in that thread was to set the control parameters (adapt_delta, stepsize, max_treedepth) to extreme values, which makes the sampling really slow and does not always remove the divergent transitions in complex models.
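
For concreteness, that approach in rstan amounts to something like the following (the file and data names, as well as the exact values, are placeholders):

library(rstan)

# Placeholder file and data names; the point is the control argument
# with the "extreme" adaptation settings mentioned above.
fit <- stan(
  file    = "mpt_latent_trait.stan",
  data    = stan_data,
  control = list(adapt_delta = 0.999, stepsize = 0.01, max_treedepth = 15)
)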

I have therefore been wondering whether there are other solutions, such as changing 1) the priors, 2) the initial values, or 3) the parameterization.

Concerning 1):

The MPT_5_Stan.R example model (https://github.com/stan-dev/example-models/tree/master/Bayesian_Cognitive_Modeling/CaseStudies/MPT) uses a sigma ~ cauchy(0, 2.5) prior. Since sigma is declared with a lower bound, vector<lower=0>[nparams] sigma, this is really a half-Cauchy prior.

When sampling from the prior predictive distribution of a Stan model that follows MPT_5_Stan.R, I noticed that the sigma ~ cauchy(0, 2.5) prior included the value 54 in its 95% credible interval, and that even the prior predictive fit (without the likelihood and thus without any constraint from the data) produced divergent transitions.

I eventually got rid of the divergent transitions for the prior predictive distribution by using sigma ~ exponential(1) instead, but this alone did not completely remove the divergent transitions when sampling from the posterior.
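
For reference, a quick quantile comparison in R shows how much heavier the half-Cauchy tail is than the exponential(1) tail (the half-Cauchy quantile is obtained from the full-Cauchy quantile by symmetry):

# Upper 97.5% quantile of a half-Cauchy(0, 2.5) versus an exponential(1)
p <- 0.975
qcauchy(1 - (1 - p) / 2, location = 0, scale = 2.5)  # about 63.6
qexp(p, rate = 1)                                    # about 3.69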

Concerning 2):

I replaced the initial values in the case study with the following to ensure that the initial value of L_Omega satisfies the constraints of the Cholesky factor (lower triangular with a positive diagonal). Again, this did not completely remove the divergent transitions, but it did lead to an improvement.


library(Matrix)

# Generate the Cholesky factor of a random correlation matrix, so that the
# initial value of L_Omega is lower triangular with a positive diagonal.
generate_valid_L_Omega <- function(nparams) {
  random_matrix <- matrix(runif(nparams^2), nparams, nparams)
  sym_matrix <- (random_matrix + t(random_matrix)) / 2
  diag(sym_matrix) <- 1

  # Project onto the nearest positive-definite correlation matrix
  pd_matrix <- nearPD(sym_matrix, corr = TRUE)$mat

  # Lower-triangular Cholesky factor, converted back to a base matrix
  L_Omega <- as.matrix(t(chol(pd_matrix)))
  return(L_Omega)
}

# One set of initial values for a single chain
init_set <- function(nsubjs, nparams) {
  L_Omega_init <- generate_valid_L_Omega(nparams)

  L_init <- list(
    deltahat_tilde = matrix(rnorm(nsubjs * nparams), nsubjs, nparams),
    L_Omega = L_Omega_init,
    sigma = runif(nparams, 0.1, 2),
    mu_c_hat = rnorm(1), mu_n_hat = rnorm(1), mu_S_hat = rnorm(1),
    mu_I_hat = rnorm(1), mu_ca_hat = rnorm(1), mu_na_hat = rnorm(1)
  )

  return(L_init)
}
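
For completeness, this is how the per-chain initial values are passed to rstan (a minimal sketch; the dimensions, file name, and data list are placeholders):

library(rstan)

# Placeholder dimensions, file name, and data list; the point is passing
# one independently generated initial-value list per chain to stan().
nsubjs   <- 30
nparams  <- 6
n_chains <- 4

inits <- lapply(seq_len(n_chains), function(chain) init_set(nsubjs, nparams))

fit <- stan(
  file   = "mpt_latent_trait.stan",
  data   = stan_data,
  chains = n_chains,
  init   = inits
)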

Concerning 3):

I am considering returning to the centered parameterization instead of the non-centered parameterization used in https://github.com/stan-dev/example-models/tree/master/Bayesian_Cognitive_Modeling/CaseStudies/MPT to see whether this makes a difference for a large data set.

I attach a plot of the divergent transitions in one of the models. They do not appear to take the form of a funnel, so I was wondering how Michael Betancourt @betanalpha would classify them, and whether there are further suggestions for how to remove them?

Finally, here is the tally of leapfrog steps per iteration (a sketch of how it can be extracted from the fit follows the table):

Leapfrog steps (samples_1)   Count
7 1
19 1
21 1
22 1
23 1
28 1
31 1
34 1
43 2
47 1
48 1
51 1
56 1
58 1
62 1
65 1
67 1
73 1
78 1
83 1
85 1
87 1
91 1
98 1
99 1
104 1
107 1
108 1
109 1
110 1
112 1
113 1
115 1
116 1
121 1
122 2
123 1
124 1
127 150804
133 1
138 1
142 1
151 1
172 1
179 1
185 1
188 1
191 1
224 1
225 1
226 1
235 2
241 2
253 1
255 149139
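
For reference, a tally like the one above can be produced from the sampler diagnostics in rstan (a minimal sketch; samples_1 is assumed to be the stanfit object):

library(rstan)

# Tabulate post-warmup leapfrog steps per iteration and count divergences
# (samples_1 is assumed to be the fitted stanfit object).
sp <- get_sampler_params(samples_1, inc_warmup = FALSE)
n_leapfrog <- unlist(lapply(sp, function(chain) chain[, "n_leapfrog__"]))
table(n_leapfrog)
sum(unlist(lapply(sp, function(chain) chain[, "divergent__"])))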

That makes sense. We’ve gone away from recommending the Cauchy priors for just this reason.

Were there problems if you just let Stan automatically initialize each dimension uniform(-2, 2) on the unconstrained scale? It should guarantee you at least get a valid Cholesky factor. This can be trickier in high dimensions due to arithmetic error. Or was it just to get closer to a reasonable value so warmup struggled less to converge?

The general rule of thumb is to use centering when the posterior is well identified and non-centering when it’s not. That usually means using the centered parameterization when there is a lot of data and/or a very strong prior, and the non-centered parameterization otherwise. It’s easy enough to check, though with lots of hierarchical components it gets combinatorial to try all combinations.

I’m not sure what that’s plotting. Are the red points the divergent transitions? If so, that looks really bad!

You’re probably not getting much out of the cases with 100K+ leapfrog steps and that would have required a max tree depth of 14+.


A lot of the code in that example can be sped up with vectorization. And I think you can remove a lot of the redundancy, which is expensive in terms of copies and memory pressure. For example,

transformed parameters {
  simplex[4] theta[nsubjs];
  vector<lower=0,upper=1>[nsubjs] c;
  vector<lower=0,upper=1>[nsubjs] r;
  vector<lower=0,upper=1>[nsubjs] u;
  
  matrix[nsubjs,nparams] deltahat; 
  vector[nsubjs] deltachat;
  vector[nsubjs] deltarhat;
  vector[nsubjs] deltauhat;

  deltahat <- (diag_pre_multiply(sigma, L_Omega) * deltahat_tilde)'; 

  for (i in 1:nsubjs) {
    
    deltachat[i] <- deltahat[i,1];
    deltarhat[i] <- deltahat[i,2];
    deltauhat[i] <- deltahat[i,3];
    
    // Probitize Parameters c, r, and u 
    c[i] <- Phi(muchat + deltachat[i]);
    r[i] <- Phi(murhat + deltarhat[i]);
    u[i] <- Phi(muuhat + deltauhat[i]);
    
    // MPT Category Probabilities for Word Pairs
    theta[i,1] <- c[i] * r[i];
    theta[i,2] <- (1 - c[i]) * (u[i]) ^ 2;
    theta[i,3] <- (1 - c[i]) * 2 * u[i] * (1 - u[i]);
    theta[i,4] <- c[i] * (1 - r[i]) + (1 - c[i]) * (1 - u[i]) ^ 2;
  }
}

can be simplified and made more efficient by vectorizing it all as follows.

transformed parameters {
  matrix[nsubjs,nparams] deltahat
    = (diag_pre_multiply(sigma, L_Omega) * deltahat_tilde)'; 
  vector<lower=0,upper=1>[nsubjs] c = Phi(muchat + deltahat[ , 1]);
  vector<lower=0,upper=1>[nsubjs] r = Phi(murhat + deltahat[ , 2]);
  vector<lower=0,upper=1>[nsubjs] u = Phi(muuhat + deltahat[ , 3]);

  simplex[4] theta[nsubjs];
  theta[ , 1] = c .* r;
  theta[ , 2] = (1 - c) .* u^2;
  theta[ , 3] = (1 - c) * 2 .* u * (1 - u);
  theta[ , 4] = c .* (1 - r) + (1 - c) .* (1 - u)^2;
}

If you don’t need these variables in the output, I’d suggest just making them local variables in the model block. I think they wind up all getting saved here because all these models were translated to Stan code from BUGS.

Also, I’d recommend moving from probit to logit because it’s much more efficient to apply inv_logit over a vector than Phi.
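
The two links differ mainly by a scale factor, so the switch mostly rescales the latent parameters; a quick numerical check of the usual 1.702 approximation (in R):

# The logistic CDF with a 1.702 scale factor approximates the standard
# normal CDF to within about 0.01 over the whole real line.
x <- seq(-5, 5, by = 0.01)
max(abs(pnorm(x) - plogis(1.702 * x)))  # about 0.0095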

Are the standard normal priors on muhat and so on appropriate? They don’t get scaled or offset anywhere. What values do you get in the posterior?

Hi @Bob_Carpenter,

thanks a lot for these super informative answers!

Initial values
The initial values provided above for L_Omega were just meant to be an improvement on those in the case study, to ensure that the initial value of L_Omega is a lower triangular matrix with a positive diagonal. The default initial values might work as well. In rjags, I would have used different initial values for the MCMC chains to make sure that they start at different points, but perhaps the default initial values in rstan already do this.

(I cannot edit the first post anymore, but just noticed that given deltahat_tilde has the form [nparams,nsubjs], it should also be defined in this way in the initial values:
deltahat_tilde = matrix(rnorm(nsubjs * nparams), nparams, nsubjs). )

Divergent transitions plot
The plots were inspired by these very nice tutorials by Michael Betancourt:
https://betanalpha.github.io/assets/case_studies/identifiability.html
https://betanalpha.github.io/assets/case_studies/hierarchical_modeling.html

The green dots are the divergent transitions. I thought that perhaps they would not be as problematic as they would be if they were located toward the neck of a funnel extending toward minus infinity.

Leapfrog Steps
A general problem with these hierarchical MPT models seems to be high autocorrelation in the MCMC chains, which makes the sampling inefficient. For this reason I have tried thinning the chains, but I eventually refrained from doing so: some point out that the autocorrelated chains may contain more information than the thinned chains (Doing Bayesian Data Analysis: Thinning to reduce autocorrelation: Rarely useful!), and unless the chains are saved for continued sampling, running chains long enough to thin them runs up against the time window on the HPC servers I am using. (I have read that CmdStan permits continuing chains over multiple jobs, but so far I have stuck with rstan.)
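
As a rough check of how much this autocorrelation costs, I look at the effective sample sizes (a minimal sketch; samples_1 is again assumed to be the stanfit object):

library(rstan)

# Effective sample sizes quantify the cost of autocorrelation;
# samples_1 is assumed to be the fitted stanfit object.
fit_summary <- summary(samples_1)$summary
head(sort(fit_summary[, "n_eff"]))  # parameters with the lowest ESS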

Vectorization and logit
Thanks a lot! I would have to try out these options.

When attempting to implement the vectorized version, I encounter the following error which did not occur in the loop version of the model:


Error in stanc(file = file, model_code = model_code, model_name = model_name,  :
  0

Semantic error in 'string', line 39, column 2 to column 45:

Ill-typed arguments supplied to assignment operator =: lhs has type real[] and rhs has type vector

model = "
// Multinomial Processing Tree with Latent Traits
data { 
  int<lower=1> nsubjs; 
  int<lower=1> nparams; 
  int<lower=0,upper=3> k_1[nsubjs,3];
  int<lower=0,upper=3> k_2[nsubjs,3];
  int<lower=0,upper=3> k_3[nsubjs,3];
  int<lower=0,upper=3> k_4[nsubjs,3];
  int<lower=0,upper=3> k_5[nsubjs,3];
  int<lower=0,upper=3> k_6[nsubjs,3];
  int<lower=0,upper=3> k_7[nsubjs,3];
  int<lower=0,upper=3> k_8[nsubjs,3];
}

parameters {
  matrix[nparams, nsubjs] deltahat_tilde;
  
  cholesky_factor_corr[nparams] L_Omega; 
  vector<lower=0>[nparams] sigma; 

  real mu_c_hat;
  real mu_n_hat;
  real mu_S_hat;
  real mu_I_hat;
  real mu_ca_hat;
  real mu_na_hat; 
} 
transformed parameters {
  matrix[nsubjs,nparams] deltahat;
  
  vector<lower=0,upper=1>[nsubjs] C  = Phi(mu_c_hat  + deltahat[,1]);
  vector<lower=0,upper=1>[nsubjs] N  = Phi(mu_n_hat  + deltahat[,2]);
  vector<lower=0,upper=1>[nsubjs] S  = Phi(mu_S_hat  + deltahat[,3]);
  vector<lower=0,upper=1>[nsubjs] I  = Phi(mu_I_hat  + deltahat[,4]);
  vector<lower=0,upper=1>[nsubjs] Ca = Phi(mu_ca_hat + deltahat[,5]);
  vector<lower=0,upper=1>[nsubjs] Na = Phi(mu_na_hat + deltahat[,6]);
  
  deltahat = (diag_pre_multiply(sigma, L_Omega) * deltahat_tilde)'; 
  
  simplex[3] theta_1[nsubjs];
  theta_1[,1] = C+(1-C).*(1-N).*(1-S).*(1-I);
  theta_1[,2] = (1-C).*N+(1-C).*(1-N).*(1-S).*I;
  theta_1[,3] = (1-C).*(1-N).*S;
  simplex[3] theta_2[nsubjs];
  theta_2[,1] = (1-C).*(1-N).*(1-S).*(1-I);
  theta_2[,2] = C+(1-C).*N+(1-C).*(1-N).*(1-S).*I;
  theta_2[,3] = (1-C).*(1-N).*S;
  simplex[3] theta_3[nsubjs];
  theta_3[,1] = C+(1-C).*N+(1-C).*(1-N).*(1-S).*(1-I);
  theta_3[,2] = (1-C).*(1-N).*(1-S).*I;
  theta_3[,3] = (1-C).*(1-N).*S;
  simplex[3] theta_4[nsubjs];
  theta_4[,1] = (1-C).*N+(1-C).*(1-N).*(1-S).*(1-I);
  theta_4[,2] = C+(1-C).*(1-N).*(1-S).*I;
  theta_4[,3] = (1-C).*(1-N).*S; 
  simplex[3] theta_5[nsubjs];
  theta_5[,1] = Ca+(1-Ca).*(1-Na).*(1-S).*(1-I);
  theta_5[,2] = (1-Ca).*Na+(1-Ca).*(1-Na).*(1-S).*I;
  theta_5[,3] = (1-Ca).*(1-Na).*S;
  simplex[3] theta_6[nsubjs];
  theta_6[,1] = (1-Ca).*(1-Na).*(1-S).*(1-I);
  theta_6[,2] = Ca+(1-Ca).*Na+(1-Ca).*(1-Na).*(1-S).*I;
  theta_6[,3] = (1-Ca).*(1-Na).*S;
  simplex[3] theta_7[nsubjs];
  theta_7[,1] = Ca+(1-Ca).*Na+(1-Ca).*(1-Na).*(1-S).*(1-I);
  theta_7[,2] = (1-Ca).*(1-Na).*(1-S).*I;
  theta_7[,3] = (1-Ca).*(1-Na).*S;
  simplex[3] theta_8[nsubjs];
  theta_8[,1] = (1-Ca).*Na+(1-Ca).*(1-Na).*(1-S).*(1-I);
  theta_8[,2] = Ca+(1-Ca).*(1-Na).*(1-S).*I;
  theta_8[,3] = (1-Ca).*(1-Na).*S;
}
model {
  // Priors
  mu_c_hat   ~ normal(0, 1);
  mu_n_hat   ~ normal(0, 1);
  mu_S_hat   ~ normal(0, 1);
  mu_I_hat   ~ normal(0, 1);
  mu_ca_hat  ~ normal(0, 1);
  mu_na_hat  ~ normal(0, 1);
  
  L_Omega ~ lkj_corr_cholesky(4); 
  sigma ~ exponential(1); 
  to_vector(deltahat_tilde) ~ std_normal(); 

  // Data
  for (i in 1:nsubjs){
      k_1[i] ~ multinomial(theta_1[i]);
      k_2[i] ~ multinomial(theta_2[i]);
      k_3[i] ~ multinomial(theta_3[i]);
      k_4[i] ~ multinomial(theta_4[i]);
      k_5[i] ~ multinomial(theta_5[i]);
      k_6[i] ~ multinomial(theta_6[i]);
      k_7[i] ~ multinomial(theta_7[i]);
      k_8[i] ~ multinomial(theta_8[i]);
  }
}
generated quantities {

  real<lower=0,upper=1> mu_C;
  real<lower=0,upper=1> mu_N;
  real<lower=0,upper=1> mu_S;
  real<lower=0,upper=1> mu_I;
  real<lower=0,upper=1> mu_Ca;
  real<lower=0,upper=1> mu_Na;
  
  corr_matrix[nparams] Omega;

  // Post-Processing Means, Standard Deviations, Correlations
  mu_C     = Phi(mu_c_hat);
  mu_N     = Phi(mu_n_hat);
  mu_S     = Phi(mu_S_hat);
  mu_I     = Phi(mu_I_hat);
  mu_Ca    = Phi(mu_ca_hat);
  mu_Na    = Phi(mu_na_hat);
  
  Omega = L_Omega * L_Omega';
  
  int<lower = 0, upper = 3> y_rep[nsubjs,24];
  real log_lik[nsubjs, 8];
  
	for(n in 1:nsubjs){
		y_rep[n,1:3]   = multinomial_rng(theta_1[n],3);
		y_rep[n,4:6]   = multinomial_rng(theta_2[n],3);
		y_rep[n,7:9]   = multinomial_rng(theta_3[n],3);
		y_rep[n,10:12] = multinomial_rng(theta_4[n],3);
		y_rep[n,13:15] = multinomial_rng(theta_5[n],3);
		y_rep[n,16:18] = multinomial_rng(theta_6[n],3);
		y_rep[n,19:21] = multinomial_rng(theta_7[n],3);
		y_rep[n,22:24] = multinomial_rng(theta_8[n],3);
		
		log_lik[n,1] =  multinomial_lpmf(k_1[n]| theta_1[n]);
		log_lik[n,2] =  multinomial_lpmf(k_2[n]| theta_2[n]);
		log_lik[n,3] =  multinomial_lpmf(k_3[n]| theta_3[n]);
		log_lik[n,4] =  multinomial_lpmf(k_4[n]| theta_4[n]);
		log_lik[n,5] =  multinomial_lpmf(k_5[n]| theta_5[n]);
		log_lik[n,6] =  multinomial_lpmf(k_6[n]| theta_6[n]);
		log_lik[n,7] =  multinomial_lpmf(k_7[n]| theta_7[n]);
		log_lik[n,8] =  multinomial_lpmf(k_8[n]| theta_8[n]);	
	}
}

"

Priors
For the particular data sets and MPT parameters that I am applying the case-study model to, one could restrict some of the priors for the mu parameters more than the standard normal. For instance, there are good reasons to restrict the S parameter.
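
As an illustration, here is a quick check in R of what a purely hypothetical, tighter normal(-0.5, 0.5) prior on mu_S_hat would imply for mu_S = Phi(mu_S_hat) on the probability scale (a standard normal on the probit scale implies a uniform prior on the probability scale):

# Hypothetical tighter prior on the probit scale for mu_S_hat
pnorm(qnorm(c(0.025, 0.975), mean = -0.5, sd = 0.5))
# roughly 0.07 to 0.68 for mu_S, versus 0.025 to 0.975 under a standard normal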

Here are some posterior estimates from one data set (please note that these posterior parameter estimates have been transformed from the latent probit scale to the probability scale of the MPT parameters):

      HDI_lower Median HDI_upper
mu_C       0.09   0.11      0.14
mu_N       0.38   0.43      0.49
mu_S       0.21   0.25      0.29
mu_I       0.64   0.68      0.73
mu_Ca      0.23   0.26      0.29
mu_Na      0.30   0.35      0.40

Best,
Niels