Two-armed bandit hierarchical reinforcement learning model - interpreting conflicting loo and posterior predictive check results

Hi everyone,

Model and data

I fitted a hierarchical reinforcement learning model (a Pearce-Hall model following Diederen et al.) to two-armed bandit choice data (choice coded 0 or 1) from 71 subjects. Each subject performed the task twice (2 conditions per subject), and each task run consisted of 50 trials.

Here’s my Stan model code which I wrote following the example by @Vanessa_Brown:

// input
data {
  int<lower=1> N; // number of subjects
  int<lower=1> C; // number of conditions (reinforcer_type)
  int<lower=1> T; // total number of trials across subjects
  int<lower=1> MT; // max number of trials / subject / condition
  
  int<lower=1, upper=MT> Tsubj[N, C]; // actual number of trials / subject / condition
  
  int<lower=-999, upper=1> choice[N, MT, C]; // choice of correct (1) or incorrect (0) card / trial / subject / condition
  int<lower=-999, upper=1> outcome[N, MT, C];  // outcome / trial / subject / condition
  
  int kS; // number of subj-level variables (aud_group)
  real subj_vars[N,kS]; //subj-level variable matrix (centered)
  int kV; // number of visit-level variables (reinforcer_type)
  real visit_vars[N,C,kV]; //visit-level variable matrix (centered) (X)
  
  int<lower = 0, upper = 1> run_estimation; // a switch to evaluate the likelihood
}

// transformed input 
transformed data {
  vector[2] initV = rep_vector(0.5, 2); // initial expected values; both choices start at 0.5
  real initabsPE = 0.5;                 // initial absolute prediction error
}

// parameters whose posterior distributions are sought
parameters {
  
  // Declare all parameters as vectors for vectorizing
  // hyperparameters (group-level means)
  vector[4] mu; // group level means for the 4 parameters

  // Subject-level raw parameters (fixed effect of aud group)
  vector[kS] A_sub_m;    // learning rate
  vector[kS] tau_sub_m;  // inverse temperature
  vector[kS] gamma_sub_m; // decay constant
  vector[kS] C_sub_m; // arbitrary constant
  
  // Condition-level raw parameters (fixed effect of reinforcer type)
  vector[kV] A_sub_con_m;    // learning rate
  vector[kV] tau_sub_con_m;  // inverse temperature
  vector[kV] gamma_sub_con_m;  // decay constant
  vector[kV] C_sub_con_m;  // arbitrary constant
  
  //cross-level interaction effects (fixed interaction effects)
  matrix[kV,kS] A_int_m;
  matrix[kV,kS] tau_int_m;
  matrix[kV,kS] gamma_int_m;
  matrix[kV,kS] C_int_m;
  
  //visit-level (within subject) SDs
  real<lower=0> A_visit_s;
  real<lower=0> tau_visit_s;
  real<lower=0> gamma_visit_s; 
  real<lower=0> C_visit_s; 
  
  //SDs of visit-level effects across subjects
  vector<lower=0>[kV+1] A_subj_s;
  vector<lower=0>[kV+1] tau_subj_s;
  vector<lower=0>[kV+1] gamma_subj_s;
  vector<lower=0>[kV+1] C_subj_s;
  
  //non-centered parameterization (ncp) variance effect per visit & subject
  matrix[N,C] A_visit_raw;
  matrix[N,C] tau_visit_raw;
  matrix[N,C] gamma_visit_raw;
  matrix[N,C] C_visit_raw;
  
  //NCP variance effect on subj-level effects
  matrix[kV+1,N] A_subj_raw;
  matrix[kV+1,N] tau_subj_raw;
  matrix[kV+1,N] gamma_subj_raw;
  matrix[kV+1,N] C_subj_raw;
  
  //Cholesky factors of correlation matrices for subj-level variances
  cholesky_factor_corr[kV+1] A_subj_L;
  cholesky_factor_corr[kV+1] tau_subj_L;
  cholesky_factor_corr[kV+1] gamma_subj_L;
  cholesky_factor_corr[kV+1] C_subj_L;
}

transformed parameters {
  
  // initialize condition-in-subject-level parameters
  matrix<lower=0, upper=1>[N,C] A; // bring alpha to range between 0 and 1
  matrix[N,C] A_normal; // alpha without range
  matrix<lower=0, upper=100>[N,C] tau; // bring tau to range between 0 and 100
  matrix[N,C] tau_normal; // tau without range
  matrix<lower=0, upper=1>[N,C] gamma; // bring gamma to range between 0 and 1
  matrix[N,C] gamma_normal; // gamma without range
  matrix<lower=0, upper=1>[N,C] C_const; // bring C to range between 0 and 1
  matrix[N,C] C_normal; // C without range
  
  //combine SDs, Cholesky-factorized correlation matrices, and raw effects into correlated random intercepts and slopes per subject
  matrix[N,kV+1] A_vars = (diag_pre_multiply(A_subj_s,A_subj_L)*A_subj_raw)';
  matrix[N,kV+1] tau_vars = (diag_pre_multiply(tau_subj_s,tau_subj_L)*tau_subj_raw)';
  matrix[N,kV+1] gamma_vars = (diag_pre_multiply(gamma_subj_s,gamma_subj_L)*gamma_subj_raw)';
  matrix[N,kV+1] C_vars = (diag_pre_multiply(C_subj_s,C_subj_L)*C_subj_raw)';
  
  //create transformed parameters using non-centered parameterization for all
  // and logistic transformation for alpha (range: 0 to 1),
  // add in subject and visit-level effects as shifts in means
  
  // compute subject-level parameters
  for (s in 1:N) {
    A_normal[s,]   = mu[1] + A_visit_s*A_visit_raw[s,] + A_vars[s,1]; // overall mean + visit-level variance effect + random intercept per subject
    tau_normal[s,] = mu[2] + tau_visit_s*tau_visit_raw[s,] + tau_vars[s,1];
    gamma_normal[s,] = mu[3] + gamma_visit_s*gamma_visit_raw[s,] + gamma_vars[s,1];
    C_normal[s,] = mu[4] + C_visit_s*C_visit_raw[s,] + C_vars[s,1];
    
    //add subj- and visit-level effects
    for (v in 1:C) { // for every condition
      
      for (kv in 1:kV) { 
        //main effects of visit-level variables
        A_normal[s,v] += visit_vars[s,v,kv]*(A_sub_con_m[kv]+A_vars[s,kv+1]); // predictor * fixed and random slope
        tau_normal[s,v] += visit_vars[s,v,kv]*(tau_sub_con_m[kv]+tau_vars[s,kv+1]);
        gamma_normal[s,v] += visit_vars[s,v,kv]*(gamma_sub_con_m[kv]+gamma_vars[s,kv+1]);
        C_normal[s,v] += visit_vars[s,v,kv]*(C_sub_con_m[kv]+C_vars[s,kv+1]);
        
        for (ks in 1:kS) { 
          //main effects of subject-level variables
          A_normal[s,v] += subj_vars[s,ks]*A_sub_m[ks]; // predictor * fixed slope
          tau_normal[s,v] += subj_vars[s,ks]*tau_sub_m[ks];
          gamma_normal[s,v] += subj_vars[s,ks]*gamma_sub_m[ks];
          C_normal[s,v] += subj_vars[s,ks]*C_sub_m[ks];
          
          //cross-level interactions
          A_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*A_int_m[ks,kv];
          tau_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*tau_int_m[ks,kv];
          gamma_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*gamma_int_m[ks,kv];
          C_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*C_int_m[ks,kv];
        }
        
      }
    
    }
    
    //transform to range [0,1] or [0,100]
    A[s,] = Phi_approx(A_normal[s,]);
    tau[s,] = Phi_approx(tau_normal[s,])*100;
    gamma[s,] = Phi_approx(gamma_normal[s,]);
    C_const[s,] = Phi_approx(C_normal[s,]);
    
  }
  
}

model {
  
  // define prior distributions
  
  // hyperparameters (group-level means)
  mu ~ normal(0, 1);
  
  // Subject-level raw parameters
  A_sub_m ~ normal(0, 1);
  tau_sub_m ~ normal(0, 1);
  gamma_sub_m ~ normal(0, 1);
  C_sub_m ~ normal(0, 1);
  
  // Condition-level raw parameters
  A_sub_con_m ~ normal(0,1);
  tau_sub_con_m ~ normal(0,1);
  gamma_sub_con_m ~ normal(0,1);
  C_sub_con_m ~ normal(0,1);
  
  // cross-level interactions
  for (ks in 1:kS) {
    A_int_m[,ks] ~ normal(0,1);
    tau_int_m[,ks] ~ normal(0,1);
    gamma_int_m[,ks] ~ normal(0,1);
    C_int_m[,ks] ~ normal(0,1);
  }
  
  //visit-level (within subject) SDs
  A_visit_s ~ cauchy(0,2);
  tau_visit_s ~ cauchy(0,2); 
  gamma_visit_s ~ cauchy(0,2);
  C_visit_s ~ cauchy(0,2); 
  
  //SDs of visit-level effects across subjects
  A_subj_s ~ student_t(3,0,2);
  tau_subj_s ~ student_t(3,0,3);
  gamma_subj_s ~ student_t(3,0,2);
  C_subj_s ~ student_t(3,0,2);
  
  for (s in 1:N) {
    //non-centered parameterization (ncp) variance effect per visit & subject
    A_visit_raw[s,] ~ normal(0,1);
    tau_visit_raw[s,] ~ normal(0,1);
    gamma_visit_raw[s,] ~ normal(0,1);
    C_visit_raw[s,] ~ normal(0,1);
    
    //NCP variance effect on subj-level effects
    to_vector(A_subj_raw[,s]) ~ normal(0,1);
    to_vector(tau_subj_raw[,s]) ~ normal(0,1);
    to_vector(gamma_subj_raw[,s]) ~ normal(0,1);
    to_vector(C_subj_raw[,s]) ~ normal(0,1);
  }
  
  //Cholesky factors of correlation matrices for subj-level variances
  // lkj prior with shape parameter eta = 1.0 is uniform over correlation matrices;
  // eta > 1 (e.g. 2) concentrates mass toward weaker correlations between
  // random intercepts and slopes (Sorensen & Vasishth, 2016)
  A_subj_L ~ lkj_corr_cholesky(1);
  tau_subj_L ~ lkj_corr_cholesky(1);
  gamma_subj_L ~ lkj_corr_cholesky(1);
  C_subj_L ~ lkj_corr_cholesky(1);
  
  // only execute this part if we want to evaluate likelihood (fit real data)
  if (run_estimation==1){

    // subject loop
    for (s in 1:N) {
      
      // define needed variables
      vector[2] ev; // expected value for both options
      real PE;      // prediction error
      real absPE; // absolute prediction error
      real k; // learning rate per trial
    
      // condition loop
      for (v in 1:C) {
        
        // set initial values
        ev = initV;
        absPE = initabsPE;
        k = A[s,v];
      
        // trial loop
        for (t in 1:Tsubj[s,v]) {
        
          // how does choice relate to inverse temperature and action value
          choice[s,t,v] ~ bernoulli_logit(tau[s,v] * (ev[2]-ev[1])); // inverse temp * Q
          
          // Pearce Hall learning rate
          k = gamma[s,v]*C_const[s,v]*absPE + (1-gamma[s,v])*k; // decay constant * arbitrary constant * absolute PE from last trial + (1-decay constant) * learning rate from last trial
                                                    // if decay constant close to 1: dynamic learning rate will be strongly affected by PEs from last trial and only weakly affected by learning rate from previous trial (high fluctuation)
                                                    // if decay constant close to 0: dynamic learning rate will be weakly affected by PEs from last trial and strongly affected by learning rate from previous trial (low fluctuation)

        
          // prediction error
          PE = outcome[s,t,v] - ev[choice[s,t,v]+1]; // outcome - Q of choice taken
          absPE = abs(PE);
          
          // value updating (learning)
          ev[choice[s,t,v]+1] += k * PE; // Q + dynamic alpha * PE                                                                                    
      
        }
        
      }
    
    }
  
  }
  
}

generated quantities {
  
  // Define mean group-level parameter values
  real<lower=0, upper=1> mu_A; // initialize mean of posterior
  real<lower=0, upper=100> mu_tau;
  real<lower=0, upper=1> mu_gamma;
  real<lower=0, upper=1> mu_C;

  // For log likelihood calculation
  real log_lik[N,MT,C];
  
  // for choice probability calculation (of chosen option)
  real softmax_ev_chosen[N,MT,C];

  // For posterior predictive check
  int y_pred[N,MT,C];
  
  // extracting PEs per subject and trial
  real PE_pred[N,MT,C];
  
  // extracting q values per subject and trial
  real ev_pred[N,MT,C,2];
  real ev_chosen_pred[N,MT,C];
  
  // extracting dynamic learning rate per subject and trial
  real k_pred[N,MT,C];
  
  // correlation matrix
  corr_matrix[kV+1] A_cor = multiply_lower_tri_self_transpose(A_subj_L);
  corr_matrix[kV+1] tau_cor = multiply_lower_tri_self_transpose(tau_subj_L);
  corr_matrix[kV+1] gamma_cor = multiply_lower_tri_self_transpose(gamma_subj_L);
  corr_matrix[kV+1] C_cor = multiply_lower_tri_self_transpose(C_subj_L);

  // Set all PE, ev, and other predictions to -999 (avoids NULL values for missing trials)
  for (s in 1:N) {
    for (v in 1:C) {
      for (t in 1:MT) {
        y_pred[s,t,v] = -999;
        PE_pred[s,t,v] = -999;
        ev_chosen_pred[s,t,v] = -999;
        k_pred[s,t,v] = -999;
        softmax_ev_chosen[s,t,v] = -999;
        log_lik[s,t,v] = -999;
        for (c in 1:2) {
          ev_pred[s,t,v,c] = -999;
        }
      }
    }
  }
  
  
  // calculate overall means of parameters
  mu_A   = Phi_approx(mu[1]); 
  mu_tau = Phi_approx(mu[2]) * 100;
  mu_gamma   = Phi_approx(mu[3]);
  mu_C   = Phi_approx(mu[4]);
  

  { // local section, this saves time and space
    for (s in 1:N) {
      
      vector[2] ev; // expected value
      real PE;      // prediction error
      real absPE; // absolute prediction error
      real k; // learning rate
      vector[2] softmax_ev; // softmax per ev

      for (v in 1:C) {
        
        // initialize values
        ev = initV;
        absPE = initabsPE;
        k = A[s,v];
      
        // quantities of interest
        for (t in 1:Tsubj[s,v]) {
          
          // generate prediction for current trial
          // if estimation = 1, we draw from the posterior
          // if estimation = 0, we also draw from the posterior, but the posterior equals the prior because the likelihood is not evaluated
          y_pred[s,t,v] = bernoulli_logit_rng(tau[s,v] * (ev[2]-ev[1])); // following the recommendation to use the same function as in model block but with rng ending
          
          // if estimation = 1, compute quantities of interest based on actual choices
          if (run_estimation==1){
            
            // compute log likelihood of current trial
            log_lik[s,t,v] = bernoulli_logit_lpmf(choice[s,t,v] | tau[s,v] * (ev[2]-ev[1]));
            
            // compute choice probability
            softmax_ev = softmax(tau[s,v]*ev);
            
            softmax_ev_chosen[s,t,v] = softmax_ev[choice[s,t,v]+1];
            
            // Pearce Hall learning rate
            k = gamma[s,v]*C_const[s,v]*absPE + (1-gamma[s,v])*k;
            k_pred[s,t,v] = k;
            
            // prediction error
            PE = outcome[s,t,v] - ev[choice[s,t,v]+1];
            PE_pred[s,t,v] = PE;
          
            // value updating (learning)
            ev[choice[s,t,v]+1] += k * PE;
            
            ev_pred[s,t,v,1] = ev[1]; // copy both evs into pred
            ev_pred[s,t,v,2] = ev[2]; // copy both evs into pred
            
            ev_chosen_pred[s,t,v] = ev[choice[s,t,v]+1];
          
          }
        
          // if estimation = 0, compute quantities of interest based on simulated choices
          if (run_estimation==0){
          
            // compute log likelihood of current trial
            log_lik[s,t,v] = bernoulli_logit_lpmf(y_pred[s,t,v] | tau[s,v] * (ev[2]-ev[1]));
          
            // Pearce Hall learning rate
            k = gamma[s,v]*C_const[s,v]*absPE + (1-gamma[s,v])*k;
            k_pred[s,t,v] = k;
            
            // prediction error
            PE = outcome[s,t,v] - ev[y_pred[s,t,v]+1];
            PE_pred[s,t,v] = PE;
            
            // value updating (learning)
            ev[y_pred[s,t,v]+1] += k * PE;
          
            ev_pred[s,t,v,1] = ev[1]; // copy both evs into pred
            ev_pred[s,t,v,2] = ev[2]; // copy both evs into pred
            
            ev_chosen_pred[s,t,v] = ev[y_pred[s,t,v]+1];
          
          }
        
        } // trial loop
    
      } // condition loop
  
    } // subject loop

  } // local section
  
} // generated quantities

Model fit evaluation

I am now trying to evaluate the model fit:

Posterior predictive checks show pretty good results, with the model being slightly too optimistic (predicting more correct choices than subjects actually made):

[posterior predictive check plots]

In the last plot, choice_p_correct indicates the actual percentage of correct choices per trial across all subjects and conditions, while mean_p_correct indicates the corresponding model-predicted percentage, averaged across all posterior draws.
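
For reference, this is roughly how such per-trial quantities can be computed from y_pred; the object names (fit, stan_data) are illustrative, N, MT, and C are as in the data block, and missing trials (coded -999) are excluded:

# sketch: observed vs. posterior-predicted proportion of correct choices per trial
library(rstan)

y_pred <- rstan::extract(fit, pars = "y_pred")$y_pred  # draws x N x MT x C
choice <- stan_data$choice                             # N x MT x C, -999 = missing trial

choice_p_correct <- sapply(1:dim(choice)[2], function(t) {
  obs <- choice[, t, ]
  mean(obs[obs != -999])                               # observed proportion correct on trial t
})

mean_p_correct <- sapply(1:dim(choice)[2], function(t) {
  obs  <- choice[, t, ]
  pred <- apply(y_pred[, , t, ], c(2, 3), mean)        # average over posterior draws
  mean(pred[obs != -999])                              # predicted proportion correct on trial t
})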

Leave-one-out cross-validation using loo by @avehtari , however, indicates that there are too many bad and very bad Pareto k values for elpd_loo to be trusted. I’m now trying to interpret this according to the loo vignette: as p_loo (= 197.8) is larger than the number of parameters (which I counted as 33), this seems to indicate severe model misspecification. However, the PPC did not indicate strong misspecification, even though it should in that case, since the number of parameters (n = 33) is much smaller than the number of observations (71 subjects * 2 conditions * 50 trials, 6978 after excluding missing trials).

Warning: Can't fit generalized Pareto distribution because all tail values are
the same.

Computed from 36000 by 6978 log-likelihood matrix

         Estimate   SE
elpd_loo  -2220.8 42.9
p_loo       197.8 14.2
looic      4441.6 85.8
------
Monte Carlo SE of elpd_loo is NA.

Pareto k diagnostic values:
                         Count Pct.    Min. n_eff
(-Inf, 0.5]   (good)     6528  93.6%   2816
 (0.5, 0.7]   (ok)        133   1.9%   439
   (0.7, 1]   (bad)       113   1.6%   42
   (1, Inf)   (very bad)  204   2.9%   12
See help('pareto-k-diagnostic') for details.

Questions

My questions are:

  1. Can the PPC results be interpreted as positively as I did?
  2. If so, how come they stand in such contrast to loo? Does it make sense to use leave-one-out cross-validation on single trials in my hierarchical model, or should I rather leave out single conditions per subject, or single subjects?

Any help would be greatly appreciated!

Best,
Milena

I don’t understand how you came up with 33, as these

  matrix[N,C] A_visit_raw;
  matrix[N,C] tau_visit_raw;
  matrix[N,C] gamma_visit_raw;
  matrix[N,C] C_visit_raw;
  
  //NCP variance effect on subj-level effects
  matrix[kV+1,N] A_subj_raw;
  matrix[kV+1,N] tau_subj_raw;
  matrix[kV+1,N] gamma_subj_raw;
  matrix[kV+1,N] C_subj_raw;

already have at least 8 * 2 * 71 = 1136 parameters assuming kV=1

As you did not mention the values of kS and kV, I’m not able to do a full count, but I assume that the large number of high k-hats is due to a very flexible model.

No. The PPCs you used are not that useful for a binary target, as a model with just one intercept parameter is already sufficient to get the proportions of the two classes right. It would be better to use calibration or reliability plots, as illustrated in Bayesian Logistic Regression with rstanarm.

LOO is fine, but of course leave-one-group-out (LOGO) may match your goals better. LOGO will be computationally even more difficult with PSIS, but you could use K-fold-CV. It would be good to first understand why PSIS-LOO is failing.

Why did you get NULL values? If there are -999 values in log_lik, then LOO computation will be garbage.

It appears that you’re calculating leave-one-out cross-validation (LOO) over individual trials/decisions. An important consideration is that the LOO algorithm typically assumes independent observations. In your data/analysis, what are conditionally independent are not the individual trials, but rather the 2*71 learning tasks. This suggests that LOO should be calculated from a log-likelihood matrix where each entry represents the product of trial-wise likelihoods (or the sum of log likelihoods), which could also help in reducing the influence of individual data points.

Influential trials may arise when decision-makers choose an option with an extremely low likelihood, or when predictions are highly accurate because participants consistently choose the same option. Averaging over all trials could mitigate the impact of the former scenario.

These are preliminary thoughts, but the key point is the importance of carefully considering which observations are independent and how this impacts the log_likelihood matrix generated for LOO.
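
For instance, the trial-wise log_lik columns could be summed within each subject-by-condition task before calling loo. A sketch in R: this assumes rstan’s column ordering for log_lik[N,MT,C] (subject index varying fastest), that fit, N, MT, and C are available, and that missing trials are coded -999:

# sum trial-wise log-likelihoods within each subject x condition
library(loo)

ll_trial <- extract_log_lik(fit, parameter_name = "log_lik", merge_chains = TRUE)

idx  <- expand.grid(s = 1:N, t = 1:MT, v = 1:C)  # column indices of the flattened array
task <- paste(idx$s, idx$v, sep = "_")           # one group per subject x condition
keep <- ll_trial[1, ] != -999                    # drop missing trials

ll_task <- sapply(split(which(keep), task[keep]),
                  function(cols) rowSums(ll_trial[, cols, drop = FALSE]))

loo_task <- loo(ll_task)                         # draws x (N*C) log-likelihood matrix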

Actually no. See CV-FAQ: When is cross-validation valid? and related answers. Whether to leave out individual trials, learning tasks, or persons is a choice you can make depending on the modeling goals, and sometimes it’s useful to do all of those to focus on different parts.

But then leaving out more observations changes the posterior more, and importance-sampling-based PSIS-LOO is more likely to have problems. If the choice is to leave out groups of data, it may be better to switch to K-fold-CV (where K can be big if you can parallelize efficiently; see, e.g., Bayesian cross-validation by parallel Markov chain Monte Carlo).
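
For example, fold assignments that keep all of a subject's trials together can be created with the kfold_split_grouped helper in the loo package; a sketch, where subject_id is a hypothetical vector with one entry per observation, after which the model is re-fit K times with the held-out group's likelihood contributions switched off:

library(loo)

K     <- 10
folds <- kfold_split_grouped(K = K, x = subject_id)  # same fold for all rows of a subject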

Thanks for the correction!
I should have looked up the docs before posting!

Hi @avehtari,

thanks for your reply! As you recommended, I am currently trying to understand why PSIS is failing before using k-fold CV as an alternative strategy.

As you did not mention values for kS and kV, I’m not able to do full count, but I assume that the large number of high khats is due to very flexible model.

You are right, I do have a lot more parameters than just 33 (kS and kV are both 1). This was an error in my reasoning and indeed leads to a highly flexible model. I will try to simplify the model over the next couple of days by leaving out the random slopes and keeping only a random intercept.

PPC’s you used are not useful for binary target as it is sufficient to have just one intercept parameter to get the proportions of two classes right. It would be better to use calibration or reliability plots as illustrated in Bayesian Logistic Regression with rstanarm

Thanks for the reference! I used the CORP approach recommended by Dimitriadis, Gneiting, and Jordan (2021) to create a calibration plot:

[CORP reliability diagram: rd_PH_withC]

# A tibble: 1 × 5
  forecast mean_score miscalibration discrimination uncertainty
  <chr>         <dbl>          <dbl>          <dbl>       <dbl>
1 EMOS         0.0996        0.00203         0.0537       0.151
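
For reference, a diagram and decomposition like this can be produced with the reliabilitydiag package by the same authors; a sketch, where p_hat and y are illustrative names for the model-predicted probability of choice = 1 per trial and the observed choices, and EMOS is just the label given to the forecast:

library(reliabilitydiag)

rd <- reliabilitydiag(EMOS = p_hat, y = y)  # CORP reliability diagram object
autoplot(rd)                                # calibration plot
summary(rd)                                 # mean_score, miscalibration, discrimination, uncertainty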

I’m still having some issues interpreting the plot.

  • From what I understand, we see the binned model-predicted choice probability of choice = 1 per trial on the x-axis.
  • The y-axis then shows the observed choice probability of choice = 1 in trials that had a predicted choice probability contained in the respective bin on the x-axis.
  • In case that’s correct, the plot shows that for trials in which the predicted choice probability is around 55% or lower, the observed choice probability is lower than the predicted choice probability.
  • In less technical terms, in trials in which the model predicts that choice = 1 is unlikely (<50% choice probability), it actually is even more unlikely?
    Any feedback on my interpretation is welcome as I could be totally off.

Why did you get NULL values? If there are -999 values in log_lik , then LOO computation will be garbage.

My log-likelihood matrix initially includes -999 columns, as some participants did not make a choice on some of the 50 trials per condition. I manually exclude the columns containing -999 before calculating loo, using the code pasted below. Does that make sense?

# extract log likelihood for each choice
library(loo)
log_likelihood <- extract_log_lik(fit, parameter_name = "log_lik", merge_chains = TRUE)

# exclude missing trials (columns coded -999)
log_likelihood <- log_likelihood[, log_likelihood[1, ] != -999]

# print and plot loo
loo1 <- loo(log_likelihood)
print(loo1)
plot(loo1)
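
One thing I was unsure about: loo() warns when relative effective sample sizes are not supplied, so I am considering adding them roughly like this (a sketch, assuming my 36000 draws come from 4 chains of 9000 post-warmup iterations each and that the merged matrix stacks the chains in order):

# relative effective sample sizes for PSIS (chain layout assumed: 4 chains x 9000 draws)
r_eff <- relative_eff(exp(log_likelihood),
                      chain_id = rep(1:4, each = 9000))
loo1  <- loo(log_likelihood, r_eff = r_eff)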

Great thanks and best,
Milena

You interpreted the plot correctly.

Yes, you can do that. I assume you are also excluding them when computing the likelihood in the model?

Hi @avehtari ,

thanks for your feedback!

Yes, I also exclude them when computing the likelihood in the model, as the trial loop

for (t in 1:Tsubj[s,v])

only runs up to the individual number of valid (non -999) trials per subject and condition. Trials coded -999 are always at the end of a trial sequence (e.g., trials 48, 49, 50).

I did this and opened a new topic here. Any insights would be appreciated.

Best,
Milena