Remove or estimate the correlation between parameters (or find out whether the correlation is real)

Ah, got that now, thanks!

I reckon it's a bit more stable, as `Phi_approx` is just a scaled inverse logit (see "Normal_lcdf underflows much faster than pnorm(log=TRUE)", post #6 by Bob_Carpenter).


Here’s my go at it:

This model works fine for me (tested with cmdstanr):

// Trial-level data for N subjects. Per-subject arrays are padded to T columns;
// the model block only reads trials 1..Tsubj[i] of subject i.
// NOTE(review): this uses the pre-2.26 array syntax (`real x[N, T];`), which
// was removed in Stan 2.33; newer Stan needs `array[N, T] real x;`.
data {
  int<lower=1> N;                        // number of subjects (each gets its own rho, lambda, tau)
  int<lower=1> T;                        // padded (maximum) number of trials per subject
  int<lower=1, upper=T> Tsubj[N];        // number of trials actually used for each subject
  real<lower=0> gain[N, T];              // gain amount of the gamble option
  real<lower=0> loss[N, T];              // absolute loss amount of the gamble option
  real cert[N, T];                       // certain (safe) amount; presumably >= 0, since pow(cert, rho) is NaN for negative cert — confirm
  int<lower=-1, upper=1> gamble[N, T];   // choice per trial; must be 0/1 for t <= Tsubj[i] (bernoulli_logit);
                                         // -1 presumably pads trials t > Tsubj[i] — confirm
}
// Hierarchical, non-centred parameterisation of the three subject-level
// parameters (rho, lambda, tau), with correlated group-level deviations.
parameters {
  vector[3] mu_pr;                       // group locations of (rho, lambda, tau) before inv_logit and scaling
  vector<lower=0>[3] sigma;              // group-level scales of the three deviations
  cholesky_factor_corr[3] cors;          // Cholesky factor of the 3x3 correlation matrix (not the matrix itself);
                                         // the correlations are multiply_lower_tri_self_transpose(cors)
  vector[3] rho_lambda_tau_tilde[N];     // per-subject standard-normal innovations (given std_normal in model)
}
transformed parameters {
  // Subject-level parameters. Each is a scaled inverse logit, so it lies in
  // [0, 2], [0, 5] and [0, 30] respectively.
  vector<lower=0, upper=2>[N]  rho;
  vector<lower=0, upper=5>[N]  lambda;
  vector<lower=0, upper=30>[N] tau;
  // Cholesky factor of the group-level covariance: diag(sigma) * cors.
  cholesky_factor_cov[3] L_Sigma = diag_pre_multiply(sigma, cors);

  for (i in 1:N) {
    // Correlated deviation for subject i, mapped through inv_logit. These
    // values are on the unit scale (0..1), not the logit scale.
    vector[3] unit_rho_lambda_tau = inv_logit(mu_pr + L_Sigma * rho_lambda_tau_tilde[i]);
    rho[i]    = unit_rho_lambda_tau[1] * 2;
    lambda[i] = unit_rho_lambda_tau[2] * 5;
    tau[i]    = unit_rho_lambda_tau[3] * 30;
  }
}
model {
  // Priors on the group-level parameters.
  mu_pr ~ std_normal();
  sigma ~ normal(0, 0.2);          // half-normal, since sigma is declared <lower=0>
  cors ~ lkj_corr_cholesky(3);

  // Standard-normal innovations for the non-centred subject-level parameters.
  for( i in 1:N)
    rho_lambda_tau_tilde[i] ~ std_normal();

  // Likelihood: one logistic choice per used trial of each subject.
  for (i in 1:N) {
    for (t in 1:Tsubj[i]) {
      // Per-trial utilities. Each is used once per iteration, so scalars suffice.
      real evSafe;
      real evGamble;

      // loss[i, t]=absolute amount of loss (pre-converted in R)
      evSafe   = pow(cert[i, t], rho[i]);
      // Gain and loss are each weighted by 0.5 (presumably a 50/50 gamble — confirm).
      evGamble = 0.5 * (pow(gain[i, t], rho[i]) - lambda[i] * pow(loss[i, t], rho[i]));
      gamble[i, t] ~ bernoulli_logit(tau[i] * (evGamble - evSafe));
    }
  }
}

This version is a bit faster, because the bernoulli_logit likelihood is vectorized over each subject's trials instead of being called once per trial:

// Trial-level data for N subjects. Per-subject arrays are padded to T columns;
// the model block only reads trials 1..Tsubj[i] of subject i.
// NOTE(review): this uses the pre-2.26 array syntax (`real x[N, T];`), which
// was removed in Stan 2.33; newer Stan needs `array[N, T] real x;`.
data {
  int<lower=1> N;                        // number of subjects (each gets its own rho, lambda, tau)
  int<lower=1> T;                        // padded (maximum) number of trials per subject
  int<lower=1, upper=T> Tsubj[N];        // number of trials actually used for each subject
  real<lower=0> gain[N, T];              // gain amount of the gamble option
  real<lower=0> loss[N, T];              // absolute loss amount of the gamble option
  real cert[N, T];                       // certain (safe) amount; presumably >= 0, since pow(cert, rho) is NaN for negative cert — confirm
  int<lower=-1, upper=1> gamble[N, T];   // choice per trial; must be 0/1 for t <= Tsubj[i] (bernoulli_logit);
                                         // -1 presumably pads trials t > Tsubj[i] — confirm
}
// Hierarchical, non-centred parameterisation of the three subject-level
// parameters (rho, lambda, tau), with correlated group-level deviations.
parameters {
  vector[3] mu_pr;                       // group locations of (rho, lambda, tau) before inv_logit and scaling
  vector<lower=0>[3] sigma;              // group-level scales of the three deviations
  cholesky_factor_corr[3] cors;          // Cholesky factor of the 3x3 correlation matrix (not the matrix itself);
                                         // the correlations are multiply_lower_tri_self_transpose(cors)
  vector[3] rho_lambda_tau_tilde[N];     // per-subject standard-normal innovations (given std_normal in model)
}
transformed parameters {
  // Subject-level parameters. Each is a scaled inverse logit, so it lies in
  // [0, 2], [0, 5] and [0, 30] respectively.
  vector<lower=0, upper=2>[N]  rho;
  vector<lower=0, upper=5>[N]  lambda;
  vector<lower=0, upper=30>[N] tau;
  // Cholesky factor of the group-level covariance: diag(sigma) * cors.
  cholesky_factor_cov[3] L_Sigma = diag_pre_multiply(sigma, cors);

  for (i in 1:N) {
    // Correlated deviation for subject i, mapped through inv_logit. These
    // values are on the unit scale (0..1), not the logit scale.
    vector[3] unit_rho_lambda_tau = inv_logit(mu_pr + L_Sigma * rho_lambda_tau_tilde[i]);
    rho[i]    = unit_rho_lambda_tau[1] * 2;
    lambda[i] = unit_rho_lambda_tau[2] * 5;
    tau[i]    = unit_rho_lambda_tau[3] * 30;
  }

}
model {
  // Group-level priors.
  mu_pr ~ std_normal();
  sigma ~ normal(0, 0.2);
  cors ~ lkj_corr_cholesky(3);

  for (i in 1:N) {
    int n_obs = Tsubj[i];      // number of trials used for subject i
    vector[n_obs] ev_diff;     // utility of gambling minus utility of the safe option

    for (t in 1:n_obs) {
      real ev_safe = pow(cert[i, t], rho[i]);
      real ev_gamble = 0.5 * (pow(gain[i, t], rho[i]) - lambda[i] * pow(loss[i, t], rho[i]));
      ev_diff[t] = ev_gamble - ev_safe;
    }

    // One vectorised likelihood statement over all of subject i's used trials,
    // followed by that subject's non-centred innovations.
    gamble[i, 1:n_obs] ~ bernoulli_logit(tau[i] * ev_diff);
    rho_lambda_tau_tilde[i] ~ std_normal();
  }
}

You may need to adapt the generated quantities block, which I have omitted here.

Results

3 chains, each with 600 warm-up and 600 post-warm-up iterations

   variable    mean median     sd    mad      q5     q95   rhat ess_bulk ess_tail
   <chr>      <dbl>  <dbl>  <dbl>  <dbl>   <dbl>   <dbl>  <dbl>    <dbl>    <dbl>
 1 mu_pr[1]  -0.150 -0.158 0.113  0.109  -0.328   0.0412  1.00      690.     904.
 2 mu_pr[2]  -1.66  -1.65  0.0716 0.0698 -1.77   -1.54    0.999     991.    1079.
 3 mu_pr[3]  -3.60  -3.60  0.237  0.243  -4.00   -3.22    1.00      846.    1057.
 4 sigma[1]   0.444  0.441 0.0581 0.0592  0.353   0.545   1.00     1236.    1397.
 5 sigma[2]   0.221  0.218 0.0542 0.0518  0.136   0.311   1.00      906.     702.
 6 sigma[3]   0.902  0.901 0.0977 0.0996  0.742   1.06    1.01     1035.    1318.
 7 cors[1,1]  1      1     0      0       1       1      NA          NA       NA 
 8 cors[2,1]  0.283  0.298 0.177  0.177  -0.0285  0.550   1.00     1317.    1249.
 9 cors[3,1] -0.703 -0.714 0.0943 0.0903 -0.834  -0.537   1.00     1177.    1317.
10 cors[1,2]  0      0     0      0       0       0      NA          NA       NA 

Cheers!
Max

2 Likes