Hi there,
Data
As mentioned in my previous post, I want to fit a hierarchical reinforcement learning (RL) model (a Pearce-Hall model following Diederen et al.) to two-armed bandit choice (0/1) data from 71 subjects.
Each subject belongs to one of two groups (group = 0/1) and performed the task twice (condition = 1/2). Each task run consisted of 50 trials.
Modeling goal
The code should perform two things at once:
- Run a trial-by-trial reinforcement learning model (in the Stan model block) to estimate the RL parameters (alpha, tau, C, gamma) for each subject and condition.
- Run a hierarchical linear model with the RL parameters (alpha, tau, C, gamma) per subject and condition as dependent variables.
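To make the first point concrete, here is a minimal Python sketch of the trial-by-trial Pearce-Hall update for a single subject and condition. It follows the update order used in the Stan model block further down (choice likelihood, then learning rate, then prediction error, then value update). The function names and the example inputs are illustrative assumptions, not part of the actual model code.

```python
import math

def log_sigmoid(x):
    # Numerically stable log(1 / (1 + exp(-x)))
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))

def pearce_hall_run(choices, outcomes, alpha, tau, C, gamma,
                    init_ev=0.5, init_abs_pe=0.5):
    """Return (log-likelihood, final values, per-trial learning rates)."""
    ev = [init_ev, init_ev]   # index 0 = option coded 0, index 1 = option coded 1
    abs_pe = init_abs_pe      # absolute prediction error from the previous trial
    k = alpha                 # dynamic learning rate starts at alpha
    log_lik = 0.0
    ks = []
    for choice, outcome in zip(choices, outcomes):
        # Bernoulli-logit choice rule: P(choice = 1) = inv_logit(tau * (ev[1] - ev[0]))
        x = tau * (ev[1] - ev[0])
        log_lik += log_sigmoid(x) if choice == 1 else log_sigmoid(-x)
        # Pearce-Hall learning rate: mix of last |PE| and last learning rate
        k = gamma * C * abs_pe + (1.0 - gamma) * k
        ks.append(k)
        # Prediction error for the chosen option, then value update
        pe = outcome - ev[choice]
        abs_pe = abs(pe)
        ev[choice] += k * pe
    return log_lik, ev, ks

# Hypothetical four-trial example
log_lik, ev, ks = pearce_hall_run(
    choices=[1, 1, 0, 1], outcomes=[1, 0, 0, 1],
    alpha=0.3, tau=5.0, C=0.5, gamma=0.4,
)
```

In the Stan model, `alpha`, `tau`, `C`, and `gamma` would each be the entry for one subject and condition, so this loop runs once per subject-by-condition cell.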
Random-slope-and-intercept model
To translate the hierarchical linear model to Stan, I started from an example by @Vanessa_Brown.
Question 1: How would the hierarchical linear model part of @Vanessa_Brown's code be written in lme4? As far as I understand, the model includes per-subject random slopes and possibly also a visit nesting variable (?).
I adapted the example model to my data (see my previous post), but it seems to be overly flexible.
Random-intercept-only model
Now I want to reduce the model to a random-intercept-only model with group (0/1), condition (0/1), and their interaction as fixed effects, plus a random intercept per subject. In lme4, I would write the hierarchical linear model part as shown below. lme4 would give me estimates for the fixed intercept, the three fixed slopes, the random intercept per subject, and the residual variance.
alpha ~ 1 + group*condition + (1 | subject)
tau ~ 1 + group*condition + (1 | subject)
gamma ~ 1 + group*condition + (1 | subject)
C ~ 1 + group*condition + (1 | subject)
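For reference, the fixed-effects part `1 + group*condition` expands to four columns: intercept, group, condition, and their product. The Python sketch below lists these rows for the four group-by-condition cells. It assumes 0/1 dummy coding for both predictors, and `design_row` is a hypothetical helper; with centered predictors the columns take different values and the intercept is no longer the reference-cell mean.

```python
def design_row(group, condition):
    # Columns: intercept, group, condition, group:condition
    return [1, group, condition, group * condition]

# All four cells of the 2 x 2 design, with group varying slowest
cells = [(g, c) for g in (0, 1) for c in (0, 1)]
X = [design_row(g, c) for g, c in cells]

for (g, c), row in zip(cells, X):
    print(f"group={g} condition={c} -> {row}")
```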
This is my attempt at building the equivalent hierarchical model around the RL model part in Stan, following Sorensen & Vasishth (2016):
// MILENA MUSIAL 12/2023
// input
data {
int<lower=1> N; // number of subjects
int<lower=1> C; // number of conditions (reinforcer_type)
int<lower=1> T; // total number of trials across subjects
int<lower=1> MT; // max number of trials / subject / condition
int<lower=1, upper=MT> Tsubj[N, C]; // actual number of trials / subject / condition
int<lower=-999, upper=1> choice[N, MT, C]; // choice of correct (1) or incorrect (0) card / trial / subject / condition
int<lower=-999, upper=1> outcome[N, MT, C]; // outcome / trial / subject / condition
int kS; // number of subj-level variables (aud_group)
real subj_vars[N,kS]; //subj-level variable matrix (centered)
int kV; // number of visit-level variables (reinforcer_type)
real visit_vars[N,C,kV]; //visit-level variable matrix (centered) (X)
int<lower = 0, upper = 1> run_estimation; // a switch to evaluate the likelihood
}
// transformed input
transformed data {
vector[2] initV; // initial values for EV, both choices have a value of 0.5
initV = rep_vector(0.5, 2);
real initabsPE;
initabsPE = 0.5; ///
}
// output - posterior distribution should be sought
parameters {
// Declare all parameters as vectors for vectorizing
// hyperparameters (group-level means)
vector[4] mu; // group level (fixed) intercepts for the 4 parameters, for HC and juice (these are coded as 0)
// Subject-level raw parameters (fixed slope of aud group)
vector[kS] A_sub_m; // learning rate
vector[kS] tau_sub_m; // inverse temperature
vector[kS] gamma_sub_m; // decay constant
vector[kS] C_sub_m; // arbitrary constant
// Condition-level raw parameters (fixed slope of reinforcer type)
vector[kV] A_sub_con_m; // learning rate
vector[kV] tau_sub_con_m; // inverse temperature
vector[kV] gamma_sub_con_m; // decay constant
vector[kV] C_sub_con_m; // arbitrary constant
//cross-level interaction effects (fixed interaction effects)
matrix[kV,kS] A_int_m;
matrix[kV,kS] tau_int_m;
matrix[kV,kS] gamma_int_m;
matrix[kV,kS] C_int_m;
//subject random intercepts
vector[N] A_vars;
vector[N] tau_vars;
vector[N] gamma_vars;
vector[N] C_vars;
//subject SDs used to calculate subject random intercepts
real<lower=0> A_subj_s;
real<lower=0> tau_subj_s;
real<lower=0> gamma_subj_s;
real<lower=0> C_subj_s;
}
transformed parameters {
// initialize condition-in-subject-level parameters
matrix<lower=0, upper=1>[N,C] A; // bring alpha to range between 0 and 1
matrix[N,C] A_normal; // alpha without range
matrix<lower=0, upper=100>[N,C] tau; // bring tau to range between 0 and 100
matrix[N,C] tau_normal; // tau without range
matrix<lower=0, upper=1>[N,C] gamma; // bring gamma to range between 0 and 1
matrix[N,C] gamma_normal; // gamma without range
matrix<lower=0, upper=1>[N,C] C_const; // bring C to range between 0 and 1
matrix[N,C] C_normal; // C without range
// build unconstrained parameters as fixed effects plus a subject random intercept
// (the random intercepts are sampled directly as normal(0, sd), i.e. a centered parameterization),
// then map them to the constrained range with Phi_approx (approximate probit transformation)
// subject loop
for (s in 1:N) {
// condition loop
for (v in 1:C) { // for every condition
// fixed and random intercepts
A_normal[s,v] = mu[1] + A_vars[s]; // fixed intercept + random intercept per subject
tau_normal[s,v] = mu[2] + tau_vars[s];
gamma_normal[s,v] = mu[3] + gamma_vars[s];
C_normal[s,v] = mu[4] + C_vars[s];
// fixed effects of visit-level variables
for (kv in 1:kV) {
A_normal[s,v] += visit_vars[s,v,kv]*A_sub_con_m[kv]; // predictor * fixed slope
tau_normal[s,v] += visit_vars[s,v,kv]*tau_sub_con_m[kv];
gamma_normal[s,v] += visit_vars[s,v,kv]*gamma_sub_con_m[kv];
C_normal[s,v] += visit_vars[s,v,kv]*C_sub_con_m[kv];
}
// fixed effects of subject-level variables (outside the visit loop, so they are added only once even if kV > 1)
for (ks in 1:kS) {
A_normal[s,v] += subj_vars[s,ks]*A_sub_m[ks]; // predictor * fixed slope
tau_normal[s,v] += subj_vars[s,ks]*tau_sub_m[ks];
gamma_normal[s,v] += subj_vars[s,ks]*gamma_sub_m[ks];
C_normal[s,v] += subj_vars[s,ks]*C_sub_m[ks];
}
// fixed cross-level interactions; the *_int_m matrices are declared as matrix[kV,kS], so index them as [kv,ks]
for (kv in 1:kV) {
for (ks in 1:kS) {
A_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*A_int_m[kv,ks]; // predictor * predictor * fixed slope
tau_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*tau_int_m[kv,ks];
gamma_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*gamma_int_m[kv,ks];
C_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*C_int_m[kv,ks];
}
}
}
//transform to range [0,1] or [0,100]
A[s,] = Phi_approx(A_normal[s,]);
tau[s,] = Phi_approx(tau_normal[s,])*100;
gamma[s,] = Phi_approx(gamma_normal[s,]);
C_const[s,] = Phi_approx(C_normal[s,]);
}
}
model {
// define prior distributions
// hyperparameters (group-level means)
mu ~ normal(0, 1);
// Subject-level raw parameters
A_sub_m ~ normal(0, 1);
tau_sub_m ~ normal(0, 1);
gamma_sub_m ~ normal(0, 1);
C_sub_m ~ normal(0, 1);
// Condition-level raw parameters
A_sub_con_m ~ normal(0,1);
tau_sub_con_m ~ normal(0,1);
gamma_sub_con_m ~ normal(0,1);
C_sub_con_m ~ normal(0,1);
// cross-level interactions
for (ks in 1:kS) {
A_int_m[,ks] ~ normal(0,1);
tau_int_m[,ks] ~ normal(0,1);
gamma_int_m[,ks] ~ normal(0,1);
C_int_m[,ks] ~ normal(0,1);
}
//Subject random intercepts
A_vars ~ normal(0,A_subj_s);
tau_vars ~ normal(0,tau_subj_s);
gamma_vars ~ normal(0,gamma_subj_s);
C_vars ~ normal(0,C_subj_s);
// only execute this part if we want to evaluate likelihood (fit real data)
if (run_estimation==1){
// subject loop
for (s in 1:N) {
// define needed variables
vector[2] ev; // expected value for both options
real PE; // prediction error
real absPE; // absolute prediction error
real k; // learning rate per trial
// condition loop
for (v in 1:C) {
// set initial values
ev = initV;
absPE = initabsPE;
k = A[s,v];
// trial loop
for (t in 1:Tsubj[s,v]) {
// how does choice relate to inverse temperature and action value
choice[s,t,v] ~ bernoulli_logit(tau[s,v] * (ev[2]-ev[1])); // inverse temp * Q
// Pearce Hall learning rate
k = gamma[s,v]*C_const[s,v]*absPE + (1-gamma[s,v])*k; // decay constant * arbitrary constant * absolute PE from last trial + (1-decay constant) * learning rate from last trial
// if decay constant close to 1: dynamic learning rate will be strongly affected by PEs from last trial and only weakly affected by learning rate from previous trial (high fluctuation)
// if decay constant close to 0: dynamic learning rate will be weakly affected by PEs from last trial and strongly affected by learning rate from previous trial (low fluctuation)
// prediction error
PE = outcome[s,t,v] - ev[choice[s,t,v]+1]; // outcome - Q of choice taken
absPE = abs(PE);
// value updating (learning)
ev[choice[s,t,v]+1] += k * PE; // Q + dynamic alpha * PE
}
}
}
}
}
generated quantities {
// Define mean group-level parameter values
real<lower=0, upper=1> mu_A; // initialize mean of posterior
real<lower=0, upper=100> mu_tau;
real<lower=0, upper=1> mu_gamma;
real<lower=0, upper=1> mu_C;
// For log likelihood calculation
real log_lik[N,MT,C];
// for choice probability calculation (of chosen option)
real softmax_ev_chosen[N,MT,C];
// For posterior predictive check
int y_pred[N,MT,C];
// extracting PEs per subject and trial
real PE_pred[N,MT,C];
// extracting q values per subject and trial
real ev_pred[N,MT,C,2];
real ev_chosen_pred[N,MT,C];
// extracting dynamic learning rate per subject and trial
real k_pred[N,MT,C];
// Set all PE and ev predictions to -999 (avoids NULL values)
for (s in 1:N) {
for (v in 1:C) {
for (t in 1:MT) {
y_pred[s,t,v] = -999;
PE_pred[s,t,v] = -999;
ev_chosen_pred[s,t,v] = -999;
k_pred[s,t,v] = -999;
softmax_ev_chosen[s,t,v] = -999;
log_lik[s,t,v] = -999;
for (c in 1:2) {
ev_pred[s,t,v,c] = -999;
}
}
}
}
// calculate overall means of parameters
mu_A = Phi_approx(mu[1]);
mu_tau = Phi_approx(mu[2]) * 100;
mu_gamma = Phi_approx(mu[3]);
mu_C = Phi_approx(mu[4]);
{ // local section, this saves time and space
for (s in 1:N) {
vector[2] ev; // expected value
real PE; // prediction error
real absPE; // absolute prediction error
real k; // learning rate
vector[2] softmax_ev; // softmax per ev
for (v in 1:C) {
// initialize values
ev = initV;
absPE = initabsPE;
k = A[s,v];
// quantities of interest
for (t in 1:Tsubj[s,v]) {
// generate prediction for current trial
// if estimation = 1, we draw from the posterior
// if estimation = 0, we equally draw from the posterior, but the posterior is equal to the prior as likelihood is not evaluated
y_pred[s,t,v] = bernoulli_logit_rng(tau[s,v] * (ev[2]-ev[1])); // following the recommendation to use the same function as in model block but with rng ending
// if estimation = 1, compute quantities of interest based on actual choices
if (run_estimation==1){
// compute log likelihood of current trial
log_lik[s,t,v] = bernoulli_logit_lpmf(choice[s,t,v] | tau[s,v] * (ev[2]-ev[1]));
// compute choice probability
softmax_ev = softmax(tau[s,v]*ev);
softmax_ev_chosen[s,t,v] = softmax_ev[choice[s,t,v]+1];
// Pearce Hall learning rate
k = gamma[s,v]*C_const[s,v]*absPE + (1-gamma[s,v])*k;
k_pred[s,t,v] = k;
// prediction error
PE = outcome[s,t,v] - ev[choice[s,t,v]+1];
absPE = abs(PE); // update absolute PE as in the model block, so that k matches the fitted model
PE_pred[s,t,v] = PE;
// value updating (learning)
ev[choice[s,t,v]+1] += k * PE;
ev_pred[s,t,v,1] = ev[1]; // copy both evs into pred
ev_pred[s,t,v,2] = ev[2]; // copy both evs into pred
ev_chosen_pred[s,t,v] = ev[choice[s,t,v]+1];
}
// if estimation = 0, compute quantities of interest based on simulated choices
if (run_estimation==0){
// compute log likelihood of current trial
log_lik[s,t,v] = bernoulli_logit_lpmf(y_pred[s,t,v] | tau[s,v] * (ev[2]-ev[1]));
// Pearce Hall learning rate
k = gamma[s,v]*C_const[s,v]*absPE + (1-gamma[s,v])*k;
k_pred[s,t,v] = k;
// prediction error
PE = outcome[s,t,v] - ev[y_pred[s,t,v]+1];
absPE = abs(PE); // update absolute PE as in the model block, so that k matches the fitted model
PE_pred[s,t,v] = PE;
// value updating (learning)
ev[y_pred[s,t,v]+1] += k * PE;
ev_pred[s,t,v,1] = ev[1]; // copy both evs into pred
ev_pred[s,t,v,2] = ev[2]; // copy both evs into pred
ev_chosen_pred[s,t,v] = ev[y_pred[s,t,v]+1];
}
} // trial loop
} // condition loop
} // subject loop
} // local section
} // generated quantities
Question 2: Is the hierarchical linear model part equivalent to the lme4 code above? Does the residual variance have to be defined as a separate parameter?
Taking alpha (A) as an example, the components should map as follows:
- mu[1] = fixed intercept
- A_sub_m = fixed effect of group
- A_sub_con_m = fixed effect of condition
- A_int_m = fixed interaction effect
- A_vars = random intercept per subject
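As a sanity check on this mapping, the Python sketch below recomputes one entry of `A_normal` and `A` for the case kS = kV = 1. `phi_approx` uses the definition given in the Stan functions reference, inv_logit(0.07056 x^3 + 1.5976 x). The function `alpha_for` and all numeric inputs are hypothetical and only illustrate how the five components combine.

```python
import math

def phi_approx(x):
    # Stan's Phi_approx: logistic approximation to the standard normal CDF
    return 1.0 / (1.0 + math.exp(-(0.07056 * x**3 + 1.5976 * x)))

def alpha_for(mu1, A_sub_m, A_sub_con_m, A_int_m, A_var_s, group, condition):
    """Unconstrained and constrained alpha for one subject and condition."""
    a_normal = (
        mu1                               # fixed intercept
        + A_var_s                         # random intercept of this subject
        + A_sub_m * group                 # fixed effect of group
        + A_sub_con_m * condition         # fixed effect of condition
        + A_int_m * group * condition     # fixed interaction effect
    )
    return a_normal, phi_approx(a_normal)

# Hypothetical values for a subject in group 1, condition 1
a_normal, alpha = alpha_for(
    mu1=0.2, A_sub_m=-0.1, A_sub_con_m=0.3, A_int_m=0.05,
    A_var_s=0.15, group=1, condition=1,
)
print(a_normal, alpha)
```

Note that the linear model acts on the unconstrained scale (`A_normal`), so the coefficients are not directly comparable to slopes from an lme4 model fitted to alpha on its 0 to 1 scale.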
This post is related to some of @carinaufer's and @simondesch's posts. Any feedback from someone familiar with frequentist hierarchical linear models would be much appreciated.
Best,
Milena