# Two-armed bandit hierarchical reinforcement learning model - interpreting conflicting loo and posterior predictive check results

Hi everyone,

Model and data

I fitted a hierarchical reinforcement learning model (a Pearce-Hall model following Diederen et al.) to two-armed bandit choice data (0 or 1) from 71 subjects. Each subject performed the task twice (2 conditions per subject), and each task run consisted of 50 trials.

Here’s my Stan model code, which I wrote following the example by @Vanessa_Brown:

``````
// input
data {
int<lower=1> N; // number of subjects
int<lower=1> C; // number of conditions (reinforcer_type)
int<lower=1> T; // total number of trials across subjects
int<lower=1> MT; // max number of trials / subject / condition

int<lower=1, upper=MT> Tsubj[N, C]; // actual number of trials / subject / condition

int<lower=-999, upper=1> choice[N, MT, C]; // choice of correct (1) or incorrect (0) card / trial / subject / condition
int<lower=-999, upper=1> outcome[N, MT, C];  // outcome / trial / subject / condition

int kS; // number of subj-level variables (aud_group)
real subj_vars[N,kS]; //subj-level variable matrix (centered)
int kV; // number of visit-level variables (reinforcer_type)
real visit_vars[N,C,kV]; //visit-level variable matrix (centered) (X)

int<lower = 0, upper = 1> run_estimation; // a switch to evaluate the likelihood
}

// transformed input
transformed data {
vector[2] initV = rep_vector(0.5, 2); // initial values for EV, both choices have a value of 0.5
real initabsPE = 0.5; // initial absolute PE, used in the learning-rate update on the first trial
}

// output - posterior distribution should be sought
parameters {

// Declare all parameters as vectors for vectorizing
// hyperparameters (group-level means)
vector[4] mu; // group level means for the 4 parameters

// Subject-level raw parameters (fixed effect of aud group)
vector[kS] A_sub_m;    // learning rate
vector[kS] tau_sub_m;  // inverse temperature
vector[kS] gamma_sub_m; // decay constant
vector[kS] C_sub_m; // arbitrary constant

// Condition-level raw parameters (fixed effect of reinforcer type)
vector[kV] A_sub_con_m;    // learning rate
vector[kV] tau_sub_con_m;  // inverse temperature
vector[kV] gamma_sub_con_m;  // decay constant
vector[kV] C_sub_con_m;  // arbitrary constant

//cross-level interaction effects (fixed interaction effects)
matrix[kV,kS] A_int_m;
matrix[kV,kS] tau_int_m;
matrix[kV,kS] gamma_int_m;
matrix[kV,kS] C_int_m;

//visit-level (within subject) SDs
real<lower=0> A_visit_s;
real<lower=0> tau_visit_s;
real<lower=0> gamma_visit_s;
real<lower=0> C_visit_s;

//SDs of visit-level effects across subjects
vector<lower=0>[kV+1] A_subj_s;
vector<lower=0>[kV+1] tau_subj_s;
vector<lower=0>[kV+1] gamma_subj_s;
vector<lower=0>[kV+1] C_subj_s;

//non-centered parameterization (ncp) variance effect per visit & subject
matrix[N,C] A_visit_raw;
matrix[N,C] tau_visit_raw;
matrix[N,C] gamma_visit_raw;
matrix[N,C] C_visit_raw;

//NCP variance effect on subj-level effects
matrix[kV+1,N] A_subj_raw;
matrix[kV+1,N] tau_subj_raw;
matrix[kV+1,N] gamma_subj_raw;
matrix[kV+1,N] C_subj_raw;

//Cholesky factors of correlation matrices for subj-level variances
cholesky_factor_corr[kV+1] A_subj_L;
cholesky_factor_corr[kV+1] tau_subj_L;
cholesky_factor_corr[kV+1] gamma_subj_L;
cholesky_factor_corr[kV+1] C_subj_L;
}

transformed parameters {

// initialize condition-in-subject-level parameters
matrix<lower=0, upper=1>[N,C] A; // bring alpha to range between 0 and 1
matrix[N,C] A_normal; // alpha without range
matrix<lower=0, upper=100>[N,C] tau; // bring tau to range between 0 and 100
matrix[N,C] tau_normal; // tau without range
matrix<lower=0, upper=1>[N,C] gamma; // bring gamma to range between 0 and 1
matrix[N,C] gamma_normal; // gamma without range
matrix<lower=0, upper=1>[N,C] C_const; // bring C to range between 0 and 1
matrix[N,C] C_normal; // C without range

//scale and correlate the NCP raw effects: (diag(SDs) * L * raw)' gives correlated subject-level random intercepts and slopes (N x (kV+1))
matrix[N,kV+1] A_vars = (diag_pre_multiply(A_subj_s,A_subj_L)*A_subj_raw)';
matrix[N,kV+1] tau_vars = (diag_pre_multiply(tau_subj_s,tau_subj_L)*tau_subj_raw)';
matrix[N,kV+1] gamma_vars = (diag_pre_multiply(gamma_subj_s,gamma_subj_L)*gamma_subj_raw)';
matrix[N,kV+1] C_vars = (diag_pre_multiply(C_subj_s,C_subj_L)*C_subj_raw)';

//create transformed parameters using non-centered parameterization for all,
// and a Phi_approx (approximate probit) transformation to bound A, gamma and C to (0, 1) and tau to (0, 100);
// add in subject and visit-level effects as shifts in means

// compute subject-level parameters
for (s in 1:N) {
A_normal[s,]   = mu[1] + A_visit_s*A_visit_raw[s,] + A_vars[s,1]; // overall mean + visit-level variance effect + random intercept per subject
tau_normal[s,] = mu[2] + tau_visit_s*tau_visit_raw[s,] + tau_vars[s,1];
gamma_normal[s,] = mu[3] + gamma_visit_s*gamma_visit_raw[s,] + gamma_vars[s,1];
C_normal[s,] = mu[4] + C_visit_s*C_visit_raw[s,] + C_vars[s,1];

for (v in 1:C) { // for every condition

for (ks in 1:kS) {
//main effects of subject-level variables (added once per condition, not once per visit-level variable)
A_normal[s,v] += subj_vars[s,ks]*A_sub_m[ks]; // predictor * fixed slope
tau_normal[s,v] += subj_vars[s,ks]*tau_sub_m[ks];
gamma_normal[s,v] += subj_vars[s,ks]*gamma_sub_m[ks];
C_normal[s,v] += subj_vars[s,ks]*C_sub_m[ks];
}

for (kv in 1:kV) {
//main effects of visit-level variables
A_normal[s,v] += visit_vars[s,v,kv]*(A_sub_con_m[kv]+A_vars[s,kv+1]); // predictor * (fixed + random slope)
tau_normal[s,v] += visit_vars[s,v,kv]*(tau_sub_con_m[kv]+tau_vars[s,kv+1]);
gamma_normal[s,v] += visit_vars[s,v,kv]*(gamma_sub_con_m[kv]+gamma_vars[s,kv+1]);
C_normal[s,v] += visit_vars[s,v,kv]*(C_sub_con_m[kv]+C_vars[s,kv+1]);

for (ks in 1:kS) {
//cross-level interactions (the *_int_m matrices are declared [kV,kS], so index as [kv,ks])
A_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*A_int_m[kv,ks];
tau_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*tau_int_m[kv,ks];
gamma_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*gamma_int_m[kv,ks];
C_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*C_int_m[kv,ks];
}

}

}

//transform to range [0,1] or [0,100]
A[s,] = Phi_approx(A_normal[s,]);
tau[s,] = Phi_approx(tau_normal[s,])*100;
gamma[s,] = Phi_approx(gamma_normal[s,]);
C_const[s,] = Phi_approx(C_normal[s,]);

}

}

model {

// define prior distributions

// hyperparameters (group-level means)
mu ~ normal(0, 1);

// Subject-level raw parameters
A_sub_m ~ normal(0, 1);
tau_sub_m ~ normal(0, 1);
gamma_sub_m ~ normal(0, 1);
C_sub_m ~ normal(0, 1);

// Condition-level raw parameters
A_sub_con_m ~ normal(0,1);
tau_sub_con_m ~ normal(0,1);
gamma_sub_con_m ~ normal(0,1);
C_sub_con_m ~ normal(0,1);

// cross-level interactions
for (ks in 1:kS) {
A_int_m[,ks] ~ normal(0,1);
tau_int_m[,ks] ~ normal(0,1);
gamma_int_m[,ks] ~ normal(0,1);
C_int_m[,ks] ~ normal(0,1);
}

//visit-level (within subject) SDs
A_visit_s ~ cauchy(0,2);
tau_visit_s ~ cauchy(0,2);
gamma_visit_s ~ cauchy(0,2);
C_visit_s ~ cauchy(0,2);

//SDs of visit-level effects across subjects
A_subj_s ~ student_t(3,0,2);
tau_subj_s ~ student_t(3,0,3);
gamma_subj_s ~ student_t(3,0,2);
C_subj_s ~ student_t(3,0,2);

for (s in 1:N) {
//non-centered parameterization (ncp) variance effect per visit & subject
A_visit_raw[s,] ~ normal(0,1);
tau_visit_raw[s,] ~ normal(0,1);
gamma_visit_raw[s,] ~ normal(0,1);
C_visit_raw[s,] ~ normal(0,1);

//NCP variance effect on subj-level effects
to_vector(A_subj_raw[,s]) ~ normal(0,1);
to_vector(tau_subj_raw[,s]) ~ normal(0,1);
to_vector(gamma_subj_raw[,s]) ~ normal(0,1);
to_vector(C_subj_raw[,s]) ~ normal(0,1);
}

//Cholesky factors of correlation matrices for subj-level variances
// lkj distribution with shape parameter η = 1.0 is a uniform prior over correlation matrices;
// η = 2 would down-weight extreme correlations between random intercepts and slopes (Sorensen & Vasishth, 2016)
A_subj_L ~ lkj_corr_cholesky(1);
tau_subj_L ~ lkj_corr_cholesky(1);
gamma_subj_L ~ lkj_corr_cholesky(1);
C_subj_L ~ lkj_corr_cholesky(1);

// only execute this part if we want to evaluate likelihood (fit real data)
if (run_estimation==1){

// subject loop
for (s in 1:N) {

// define needed variables
vector[2] ev; // expected value for both options
real PE;      // prediction error
real absPE; // absolute prediction error
real k; // learning rate per trial

// condition loop
for (v in 1:C) {

// set initial values
ev = initV;
absPE = initabsPE;
k = A[s,v];

// trial loop
for (t in 1:Tsubj[s,v]) {

// how does choice relate to inverse temperature and action values
choice[s,t,v] ~ bernoulli_logit(tau[s,v] * (ev[2]-ev[1])); // logit P(choice = 1) = inverse temp * (Q[2] - Q[1])

// Pearce Hall learning rate
k = gamma[s,v]*C_const[s,v]*absPE + (1-gamma[s,v])*k; // decay constant * arbitrary constant * absolute PE from last trial + (1-decay constant) * learning rate from last trial
// if decay constant close to 1: dynamic learning rate will be strongly affected by PEs from last trial and only weakly affected by learning rate from previous trial (high fluctuation)
// if decay constant close to 0: dynamic learning rate will be weakly affected by PEs from last trial and strongly affected by learning rate from previous trial (low fluctuation)

// prediction error
PE = outcome[s,t,v] - ev[choice[s,t,v]+1]; // outcome - Q of choice taken
absPE = abs(PE);

// value updating (learning)
ev[choice[s,t,v]+1] += k * PE; // Q + dynamic alpha * PE

}

}

}

}

}

generated quantities {

// Group-level means mapped to the parameter scale
// (parameter values of a "typical" subject, not the mean of the subject-level parameters)
real<lower=0, upper=1> mu_A;
real<lower=0, upper=100> mu_tau;
real<lower=0, upper=1> mu_gamma;
real<lower=0, upper=1> mu_C;

// For log likelihood calculation
real log_lik[N,MT,C];

// for choice probability calculation (of chosen option)
real softmax_ev_chosen[N,MT,C];

// For posterior predictive check
int y_pred[N,MT,C];

// extracting PEs per subject and trial
real PE_pred[N,MT,C];

// extracting q values per subject and trial
real ev_pred[N,MT,C,2];
real ev_chosen_pred[N,MT,C];

// extracting dynamic learning rate per subject and trial
real k_pred[N,MT,C];

// correlation matrix
corr_matrix[kV+1] A_cor = multiply_lower_tri_self_transpose(A_subj_L);
corr_matrix[kV+1] tau_cor = multiply_lower_tri_self_transpose(tau_subj_L);
corr_matrix[kV+1] gamma_cor = multiply_lower_tri_self_transpose(gamma_subj_L);
corr_matrix[kV+1] C_cor = multiply_lower_tri_self_transpose(C_subj_L);

// Set all outputs to -999 as placeholders for missing trials (these must be removed before computing LOO)
for (s in 1:N) {
for (v in 1:C) {
for (t in 1:MT) {
y_pred[s,t,v] = -999;
PE_pred[s,t,v] = -999;
ev_chosen_pred[s,t,v] = -999;
k_pred[s,t,v] = -999;
softmax_ev_chosen[s,t,v] = -999;
log_lik[s,t,v] = -999;
for (c in 1:2) {
ev_pred[s,t,v,c] = -999;
}
}
}
}

// calculate overall means of parameters
mu_A   = Phi_approx(mu[1]);
mu_tau = Phi_approx(mu[2]) * 100;
mu_gamma   = Phi_approx(mu[3]);
mu_C   = Phi_approx(mu[4]);

{ // local section, this saves time and space
for (s in 1:N) {

vector[2] ev; // expected value
real PE;      // prediction error
real absPE; // absolute prediction error
real k; // learning rate
vector[2] softmax_ev; // softmax per ev

for (v in 1:C) {

// initialize values
ev = initV;
absPE = initabsPE;
k = A[s,v];

// quantities of interest
for (t in 1:Tsubj[s,v]) {

// generate prediction for current trial
// if estimation = 1, we draw from the posterior
// if estimation = 0, we equally draw from the posterior, but the posterior is equal to the prior as likelihood is not evaluated
y_pred[s,t,v] = bernoulli_logit_rng(tau[s,v] * (ev[2]-ev[1])); // following the recommendation to use the same function as in model block but with rng ending

// if estimation = 1, compute quantities of interest based on actual choices
if (run_estimation==1){

// compute log likelihood of current trial
log_lik[s,t,v] = bernoulli_logit_lpmf(choice[s,t,v] | tau[s,v] * (ev[2]-ev[1]));

// compute choice probability
softmax_ev = softmax(tau[s,v]*ev);

softmax_ev_chosen[s,t,v] = softmax_ev[choice[s,t,v]+1];

// Pearce Hall learning rate
k = gamma[s,v]*C_const[s,v]*absPE + (1-gamma[s,v])*k;
k_pred[s,t,v] = k;

// prediction error
PE = outcome[s,t,v] - ev[choice[s,t,v]+1];
PE_pred[s,t,v] = PE;

// value updating (learning)
ev[choice[s,t,v]+1] += k * PE;

ev_pred[s,t,v,1] = ev[1]; // copy both evs into pred
ev_pred[s,t,v,2] = ev[2]; // copy both evs into pred

ev_chosen_pred[s,t,v] = ev[choice[s,t,v]+1];

}

// if estimation = 0, compute quantities of interest based on simulated choices
if (run_estimation==0){

// compute log likelihood of current trial
log_lik[s,t,v] = bernoulli_logit_lpmf(y_pred[s,t,v] | tau[s,v] * (ev[2]-ev[1]));

// Pearce Hall learning rate
k = gamma[s,v]*C_const[s,v]*absPE + (1-gamma[s,v])*k;
k_pred[s,t,v] = k;

// prediction error
// note: this uses the observed outcome, which may not match the simulated choice
PE = outcome[s,t,v] - ev[y_pred[s,t,v]+1];
PE_pred[s,t,v] = PE;

// value updating (learning)
ev[y_pred[s,t,v]+1] += k * PE;

ev_pred[s,t,v,1] = ev[1]; // copy both evs into pred
ev_pred[s,t,v,2] = ev[2]; // copy both evs into pred

ev_chosen_pred[s,t,v] = ev[y_pred[s,t,v]+1];

}

} // trial loop

} // condition loop

} // subject loop

} // local section

} // generated quantities
``````
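To make the order of operations in the trial loop explicit, here is a minimal Python re-implementation of the likelihood for one subject and one condition. It is a sketch for checking the logic, not part of the original workflow: the function name `pearce_hall_loglik` is hypothetical, choices are assumed to be coded 0/1 with no missing trials, and the initial values (0.5) mirror `initV` and `initabsPE` in `transformed data`. Note that the learning rate used on trial t is computed from the absolute PE of trial t - 1, before the current PE is observed.

```python
import numpy as np


def pearce_hall_loglik(choices, outcomes, A, tau, gamma, C, init_v=0.5, init_abs_pe=0.5):
    """Log-likelihood of one subject/condition under the Pearce-Hall model,
    mirroring the trial loop in the Stan model block (hypothetical helper)."""
    ev = np.array([init_v, init_v], dtype=float)  # expected values of options 0 and 1
    abs_pe = init_abs_pe                           # |PE| carried over from the previous trial
    k = A                                          # learning rate starts at A
    total = 0.0
    ks = []
    for c, r in zip(choices, outcomes):
        eta = tau * (ev[1] - ev[0])                # logit P(choice = 1)
        total += c * eta - np.logaddexp(0.0, eta)  # bernoulli_logit log pmf
        k = gamma * C * abs_pe + (1 - gamma) * k   # update k from the previous |PE|
        pe = r - ev[c]                             # PE of the chosen option
        abs_pe = abs(pe)
        ev[c] += k * pe                            # only the chosen option is updated
        ks.append(k)
    return total, np.array(ks), ev


total, ks, ev = pearce_hall_loglik([1, 0], [1, 0], A=0.3, tau=2.0, gamma=0.5, C=0.8)
print(total, ks, ev)
```

Summing this over subjects and conditions should reproduce the sum of the non-missing `log_lik` entries for the same parameter values, which can serve as a sanity check of the Stan code.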

Model fit evaluation

I am now trying to evaluate the model fit:

Posterior predictive checks show pretty good results, with the model being slightly too optimistic (predicting more correct choices than subjects actually made):

In this last plot, choice_p_correct indicates the actual percentage of correct choices per trial across all subjects and conditions, while mean_p_correct indicates its mean across all posterior draws.

Leave-one-out cross-validation using loo by @avehtari, however, indicates that there are too many bad and very bad Pareto k values for elpd_loo to be trusted. I’m now trying to interpret this according to the loo vignette: as p_loo (= 197.8) is greater than the number of parameters (= 33), this seems to indicate severe model misspecification. However, the PPC did not indicate strong model misspecification, even though it likely should, given that the number of parameters (33) is much smaller than the number of observations (71 subjects × 2 conditions × 50 trials = 7100).

``````
Warning: Can't fit generalized Pareto distribution because all tail values are
the same.

Computed from 36000 by 6978 log-likelihood matrix

         Estimate   SE
elpd_loo  -2220.8 42.9
p_loo       197.8 14.2
looic      4441.6 85.8
------
Monte Carlo SE of elpd_loo is NA.

Pareto k diagnostic values:
                         Count Pct.    Min. n_eff
(-Inf, 0.5]   (good)     6528  93.6%   2816
 (0.5, 0.7]   (ok)        133   1.9%    439
   (0.7, 1]   (bad)       113   1.6%     42
   (1, Inf)   (very bad)  204   2.9%     12
See help('pareto-k-diagnostic') for details.

``````

Questions

My questions are:

1. Can the PPC results be interpreted as positively as I did?
2. If so, how come they contrast so strongly with loo? Does it make sense to use leave-one-out cross-validation over single trials in my hierarchical model, or should I instead leave out single conditions per subject, or entire subjects?

Any help would be greatly appreciated!

Best,
Milena

I don’t understand how you came up with 33, as these

``````
matrix[N,C] A_visit_raw;
matrix[N,C] tau_visit_raw;
matrix[N,C] gamma_visit_raw;
matrix[N,C] C_visit_raw;

//NCP variance effect on subj-level effects
matrix[kV+1,N] A_subj_raw;
matrix[kV+1,N] tau_subj_raw;
matrix[kV+1,N] gamma_subj_raw;
matrix[kV+1,N] C_subj_raw;
``````

already have at least 8 × 2 × 71 = 1136 parameters, assuming kV = 1.

As you did not mention values for kS and kV, I’m not able to do a full count, but I assume that the large number of high khats is due to a very flexible model.

No. The PPCs you used are not useful for a binary target, as a single intercept parameter is sufficient to get the proportions of the two classes right. It would be better to use calibration or reliability plots, as illustrated in Bayesian Logistic Regression with rstanarm.
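As a hedged illustration of this point, the Python sketch below (hypothetical helper `binned_reliability`, equal-width bins assumed) computes the numbers behind a simple binned reliability diagram. The demo uses simulated data: an intercept-only "model" that predicts the overall rate on every trial reproduces the observed proportion of 1s exactly, yet its reliability table collapses to a single bin, so it shows no ability to discriminate between trials.

```python
import numpy as np


def binned_reliability(p_pred, y, n_bins=10):
    """Return (mean predicted prob, observed frequency, count) for each non-empty bin."""
    p_pred = np.asarray(p_pred, dtype=float)
    y = np.asarray(y, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # bin index 0..n_bins-1; values in [last inner edge, 1.0] fall into the last bin
    idx = np.digitize(p_pred, edges[1:-1])
    rows = []
    for b in range(n_bins):
        mask = idx == b
        if mask.any():
            rows.append((p_pred[mask].mean(), y[mask].mean(), int(mask.sum())))
    return rows


# Simulated demo: an intercept-only "model" predicts the overall rate on every trial
rng = np.random.default_rng(1)
y = rng.binomial(1, 0.7, size=1000)
p_const = np.full(y.shape, y.mean())
print(p_const.mean(), y.mean())           # predicted and observed proportions agree
print(binned_reliability(p_const, y))     # but everything lands in a single bin
```

For the actual model, `p_pred` would be the posterior mean of P(choice = 1) per trial, and the rows can be plotted as observed frequency against mean predicted probability with the diagonal as reference.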

LOO is fine, but of course leave-one-group-out (LOGO) may better match your goals. LOGO will be computationally even more difficult with PSIS, but you could use K-fold-CV. It would be good to first understand why PSIS-LOO is failing.
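A minimal sketch of grouped K-fold assignment, assuming whole subjects (both conditions) are the groups to leave out. The helper name `grouped_kfold_ids` is hypothetical; the loo R package provides `kfold_split_grouped()` for the same purpose. Each fold would then be refit with its held-out observations excluded from the likelihood, and the log predictive density evaluated on those held-out observations.

```python
import numpy as np


def grouped_kfold_ids(groups, K, seed=0):
    """Assign every observation of a group (e.g. a subject) to the same fold 1..K."""
    groups = np.asarray(groups)
    unique = np.unique(groups)
    rng = np.random.default_rng(seed)
    perm = rng.permutation(unique)
    # deal the shuffled groups round-robin into folds, so fold sizes differ by at most one group
    fold_of_group = {g: (i % K) + 1 for i, g in enumerate(perm)}
    return np.array([fold_of_group[g] for g in groups])


# 71 subjects, each contributing 2 conditions x 50 trials
subj = np.repeat(np.arange(71), 2 * 50)
folds = grouped_kfold_ids(subj, K=10, seed=1)
print(np.bincount(folds)[1:] // 100)  # number of subjects per fold
```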

Why did you get NULL values? If there are -999 values in `log_lik`, then LOO computation will be garbage.


It appears that you’re calculating leave-one-out cross-validation (LOO) over individual trials/decisions. An important consideration is that the LOO algorithm typically assumes independent observations. In your data, the conditionally independent units are not the individual trials but the 2 × 71 learning tasks. This suggests that LOO should be calculated from a log-likelihood matrix in which each entry is the product of the trial-wise likelihoods of one task (i.e., the sum of its log-likelihoods), which could also help reduce the influence of individual data points.
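A sketch of that aggregation in Python, assuming the draws of `log_lik[N,MT,C]` have been exported as a NumPy array of shape (draws, N, MT, C) with the -999 placeholders intact (the array layout and the helper name `task_level_loglik` are assumptions). Summing over trials gives one joint log-likelihood per learning task, so PSIS-LOO on the resulting (draws, N*C) matrix corresponds to leaving out one whole task at a time.

```python
import numpy as np


def task_level_loglik(log_lik, missing=-999.0):
    """log_lik: (draws, N, MT, C) trial-wise log-likelihoods, padded with `missing`.
    Returns (draws, N*C): one joint log-likelihood per subject x condition task."""
    ll = np.where(log_lik == missing, 0.0, log_lik)  # missing trials contribute log(1) = 0
    per_task = ll.sum(axis=2)                        # sum over trials -> (draws, N, C)
    return per_task.reshape(per_task.shape[0], -1)   # columns ordered (s1,c1), (s1,c2), ...


# Toy example: 2 draws, 1 subject, up to 3 trials, 2 conditions
ll = np.full((2, 1, 3, 2), -999.0)
ll[:, 0, :2, 0] = -0.5   # condition 1: two observed trials
ll[:, 0, :3, 1] = -0.1   # condition 2: three observed trials
print(task_level_loglik(ll))
```

Because leaving out a whole task changes the posterior more than leaving out a single trial, PSIS may report even more high Pareto k values for this matrix.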

Influential trials may arise when decision-makers choose an option with an extremely low likelihood, or when predictions are highly accurate because participants consistently choose the same option. Averaging over all trials could mitigate the impact of the former scenario.

These are preliminary thoughts, but the key point is the importance of carefully considering which observations are independent and how this impacts the log_likelihood matrix generated for LOO.


Actually no. See CV-FAQ: When is cross-validation valid? and related answers. Whether to leave out individual trials, learning tasks, or persons is a choice you can make depending on the modeling goals, and sometimes it’s useful to do all of those to focus on different parts.

But leaving out more observations changes the posterior more, so importance-sampling-based PSIS-LOO is more likely to have problems. If you choose to leave out groups of data, it may be better to switch to K-fold-CV (where K can be large if you can parallelize efficiently; see, e.g., Bayesian cross-validation by parallel Markov chain Monte Carlo).


Thanks for the correction!
I should have looked up the docs before posting!

Hi @avehtari,

thanks for your reply! As you recommended, I am currently trying to understand why PSIS is failing before using k-fold CV as an alternative strategy.

> As you did not mention values for kS and kV, I’m not able to do a full count, but I assume that the large number of high khats is due to a very flexible model.

You are right, I do have a lot more parameters than just 33 (kS and kV are both 1). This was a mistake in my reasoning, and it indeed leads to a highly flexible model. Over the next couple of days I will try to simplify the model by leaving out the random slope and keeping only a random intercept.
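For reference, a quick Python sketch counting the free (unconstrained) parameters declared in the `parameters` block above, assuming N = 71, C = 2 and kS = kV = 1; the helper name `count_parameters` is hypothetical. Each `cholesky_factor_corr[d]` contributes d(d - 1)/2 free parameters.

```python
def count_parameters(N, C, kS, kV):
    """Count free parameters in the posted Stan `parameters` block (hypothetical helper)."""
    d = kV + 1
    n_par = 4                      # mu
    n_par += 4 * kS                # A/tau/gamma/C _sub_m
    n_par += 4 * kV                # *_sub_con_m
    n_par += 4 * kV * kS           # *_int_m
    n_par += 4                     # *_visit_s
    n_par += 4 * d                 # *_subj_s
    n_par += 4 * N * C             # *_visit_raw
    n_par += 4 * d * N             # *_subj_raw
    n_par += 4 * d * (d - 1) // 2  # *_subj_L (Cholesky factors of correlation matrices)
    return n_par


print(count_parameters(N=71, C=2, kS=1, kV=1))  # 1168
```

With about 1168 parameters, p_loo ≈ 198 is well below the parameter count, so the "p_loo greater than the number of parameters" misspecification heuristic from the loo documentation does not apply here.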

> The PPCs you used are not useful for a binary target, as a single intercept parameter is sufficient to get the proportions of the two classes right. It would be better to use calibration or reliability plots, as illustrated in Bayesian Logistic Regression with rstanarm.

Thanks for the reference! I used the recommended CORP approach by Dimitriadis, Gneiting and Jordan (2021) to create a calibration plot:

``````
# A tibble: 1 × 5
forecast mean_score miscalibration discrimination uncertainty
<chr>         <dbl>          <dbl>          <dbl>       <dbl>
1 EMOS         0.0996        0.00203         0.0537       0.151
``````
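The summary above appears to follow the CORP score decomposition (as implemented, e.g., in the reliabilitydiag R package), in which mean score = MCB - DSC + UNC; with the posted values, 0.00203 - 0.0537 + 0.151 ≈ 0.0993, which matches the reported 0.0996 up to rounding. The Python sketch below computes this decomposition for the Brier score (assumed to be the score used; helper names are hypothetical). The forecasts are recalibrated with isotonic (PAV) regression: MCB is the score improvement from recalibration, DSC is how much the recalibrated forecasts beat a constant forecast, and UNC is the score of that constant forecast.

```python
import numpy as np


def pav_recalibrate(x, y):
    """Isotonic (PAV) regression of binary outcomes y on forecasts x.
    Tied forecasts are pooled first, so equal x get equal fitted values."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    _, inv, counts = np.unique(x, return_inverse=True, return_counts=True)
    sums = np.bincount(inv, weights=y)
    blocks = []  # each block: [sum of y, number of observations, number of unique x values]
    for s, n in zip(sums, counts):
        blocks.append([s, n, 1])
        # pool adjacent blocks while their means violate monotonicity
        while len(blocks) > 1 and blocks[-2][0] / blocks[-2][1] > blocks[-1][0] / blocks[-1][1]:
            s2, n2, m2 = blocks.pop()
            blocks[-1][0] += s2
            blocks[-1][1] += n2
            blocks[-1][2] += m2
    fitted_unique = np.concatenate([np.full(m, s / n) for s, n, m in blocks])
    return fitted_unique[inv]


def brier(p, y):
    return float(np.mean((np.asarray(p, dtype=float) - np.asarray(y, dtype=float)) ** 2))


def corp_brier_decomposition(x, y):
    """Return (mean_score, MCB, DSC, UNC), where mean_score = MCB - DSC + UNC."""
    y = np.asarray(y, dtype=float)
    s_x = brier(x, y)                              # original forecasts
    s_c = brier(pav_recalibrate(x, y), y)          # isotonically recalibrated forecasts
    s_r = brier(np.full_like(y, y.mean()), y)      # constant forecast (overall rate)
    return s_x, s_x - s_c, s_r - s_c, s_r


print(corp_brier_decomposition([0.1, 0.2, 0.3, 0.4], [0, 1, 0, 1]))
```

MCB is never negative, because the identity map is itself a non-decreasing function, so the PAV fit can only score as well or better than the original forecasts.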

I’m still having some issues interpreting the plot.

• From what I understand, we see the binned model-predicted choice probability of choice = 1 per trial on the x-axis.
• The y-axis then shows the observed choice probability of choice = 1 in trials that had a predicted choice probability contained in the respective bin on the x-axis.
• If that’s correct, the plot shows that for trials in which the predicted choice probability is around 55% or lower, the observed choice probability is lower than the predicted one.
• In less technical terms: in trials in which the model predicts that choice = 1 is unlikely (<50% choice probability), it is actually even more unlikely?
Any feedback on my interpretation is welcome, as I could be totally off.

> Why did you get NULL values? If there are -999 values in `log_lik`, then LOO computation will be garbage.

My log-likelihood matrix initially contains -999 columns, because some participants did not make a choice on some of the 50 trials per condition. I manually exclude these columns before calculating loo, using the code below. Does that make sense?

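``````
library(loo)

# extract log likelihood for each choice; keep chains separate
# (iterations x chains x observations) so relative efficiency can be computed
log_likelihood <- extract_log_lik(fit, parameter_name = "log_lik", merge_chains = FALSE)

# exclude missing trials (set to -999 in generated quantities)
keep <- log_likelihood[1, 1, ] != -999
log_likelihood <- log_likelihood[, , keep, drop = FALSE]

# relative efficiency of exp(log_lik), used by PSIS for n_eff and MCSE estimates
r_eff <- relative_eff(exp(log_likelihood))

# print and plot loo (Pareto k diagnostics)
loo1 <- loo(log_likelihood, r_eff = r_eff)
print(loo1)
plot(loo1)
``````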

Great thanks and best,
Milena

You interpreted the plot correctly.

Yes, you can do that. I assume you also exclude them when computing the likelihood in the model?

Hi @avehtari,

``````
for (t in 1:Tsubj[s,v])