Two-armed bandit hierarchical reinforcement learning model - simplify to random intercept only

Hi there,

I made some progress on answering the open questions myself, but I am still unsure whether my current random intercept model does what I think it does. Here are my updates:

I think the hierarchical part of the model, which I adapted from @Vanessa_Brown's code in my previous post, can be translated into the following lme4 formulas:

alpha ~ 1 + group*condition + (1 + condition | subject) + (1 | visit)
tau ~  1 + group*condition + (1 + condition | subject) + (1 | visit)
gamma ~ 1 + group*condition + (1 + condition | subject) + (1 | visit)
C ~ 1 + group*condition + (1 + condition | subject) + (1 | visit)
  • Since condition (0/1) coincides with visit number (0/1) in my design and I am interested in the effect of condition, a random intercept per visit would be confounded with condition and makes no sense for my purposes.

  • Since I only have one observation per subject and condition, estimating a by-subject random slope for condition also makes no sense.

The model I want is thus still:

alpha ~ 1 + group*condition + (1 | subject)
tau ~ 1 + group*condition + (1 | subject)
gamma ~ 1 + group*condition + (1 | subject)
C ~ 1 + group*condition + (1 | subject)
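As a rough illustration (not part of my actual code), the simplified formula for one parameter, e.g. alpha, corresponds to a probit-scale linear predictor per subject and condition, which the Stan code then maps to (0, 1) with `Phi_approx`. The sketch assumes 0/1 coding for `group` and `condition` and uses made-up coefficient values; `phi_approx` reimplements Stan's approximation `inv_logit(0.07056 x^3 + 1.5976 x)`, and `alpha_matrix` is a hypothetical helper.

```python
import numpy as np

def phi_approx(x):
    """Logistic approximation to the standard normal CDF, as used by Stan's Phi_approx."""
    return 1.0 / (1.0 + np.exp(-(0.07056 * x**3 + 1.5976 * x)))

def alpha_matrix(mu, b_group, b_cond, b_int, group, u_subj):
    """Return an (N subjects x 2 conditions) matrix of learning rates in (0, 1).

    group:  length-N array of 0/1 group codes (hypothetical coding)
    u_subj: length-N array of random intercepts on the probit scale
    """
    condition = np.array([0.0, 1.0])               # one column per condition/visit
    g = group[:, None]                             # shape (N, 1) for broadcasting
    eta = (mu + b_group * g + b_cond * condition   # fixed part: 1 + group*condition
           + b_int * g * condition
           + u_subj[:, None])                      # (1 | subject): same shift in both conditions
    return phi_approx(eta)

group = np.array([0, 0, 1, 1])
u_subj = np.array([-0.5, 0.2, 0.0, 0.4])
A = alpha_matrix(mu=-0.3, b_group=0.2, b_cond=0.1, b_int=-0.1, group=group, u_subj=u_subj)
print(A.round(3))
```

With these made-up values the condition effect cancels in group 1 (b_cond + b_int = 0), so both columns are equal there, while the random intercept shifts each subject's two conditions by the same amount on the probit scale.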

I corrected my previous Stan implementation of this model following the really helpful examples by Julian Faraway. Specifically, the raw random effect per subject (e.g. A_subj_raw) is now multiplied by the standard deviation of the intercepts across subjects (e.g. A_subj_s).
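A quick Monte Carlo check shows that scaling standard-normal raw effects by the SD gives the same distribution as drawing the intercepts directly. This is only an illustration with an arbitrary SD of 0.7 and a fixed seed, not output from my model.

```python
import numpy as np

rng = np.random.default_rng(1)
sigma = 0.7          # plays the role of A_subj_s (arbitrary value)
n = 200_000

# Centered: draw the subject intercepts directly from N(0, sigma)
u_centered = rng.normal(0.0, sigma, size=n)

# Non-centered: draw raw effects from N(0, 1), then scale by sigma
u_raw = rng.normal(0.0, 1.0, size=n)   # plays the role of A_subj_raw
u_noncentered = sigma * u_raw          # A_vars = A_subj_s * A_subj_raw

print(u_centered.std(), u_noncentered.std())   # both should be close to sigma
```

The two forms are equal in distribution; the difference only matters for the geometry the sampler sees, because in the non-centered form the raw effects are a priori independent of the SD.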

// input
data {
  int<lower=1> N; // number of subjects
  int<lower=1> C; // number of visits
  int<lower=1> T; // total number of trials (observations) across subjects
  int<lower=1> MT; // max number of trials / subject / condition
  
  int<lower=1, upper=MT> Tsubj[N, C]; // actual number of trials / subject / condition
  
  int<lower=-999, upper=1> choice[N, MT, C]; // choice of correct (1) or incorrect (0) card / trial / subject / condition
  int<lower=-999, upper=1> outcome[N, MT, C];  // outcome / trial / subject / condition
  
  int kS; // number of subj-level predictors (aud_group)
  real subj_vars[N,kS]; //subj-level variable matrix (centered) - aud_group per subject
  int kV; // number of visit-level predictors (reinforcer_type)
  real visit_vars[N,C,kV]; //visit-level variable matrix (centered) (reinforcer_type per subject and visit)
  
  int<lower = 0, upper = 1> run_estimation; // a switch to evaluate the likelihood
}

// transformed input 
transformed data {
  vector[2] initV = rep_vector(0.5, 2); // initial values for EV, both choices have a value of 0.5
  real initabsPE = 0.5; // initial absolute prediction error
}

// output - posterior distribution should be sought
parameters {
  
  // Declare all parameters as vectors for vectorizing
  // hyperparameters (group-level means)
  vector[4] mu; // fixed intercepts (probit scale) for the 4 parameters at predictor value 0; this equals HC and juice only if group and reinforcer type are 0/1-coded rather than centered

  // fixed slopes of the subject-level predictor (aud_group)
  vector[kS] A_sub_m;    // learning rate
  vector[kS] tau_sub_m;  // inverse temperature
  vector[kS] gamma_sub_m; // decay constant
  vector[kS] C_sub_m; // arbitrary constant
  
  // fixed slopes of the visit-level predictor (reinforcer type)
  vector[kV] A_sub_con_m;    // learning rate
  vector[kV] tau_sub_con_m;  // inverse temperature
  vector[kV] gamma_sub_con_m;  // decay constant
  vector[kV] C_sub_con_m;  // arbitrary constant
  
  //cross-level interaction effects (fixed interaction effects)
  matrix[kV,kS] A_int_m;
  matrix[kV,kS] tau_int_m;
  matrix[kV,kS] gamma_int_m;
  matrix[kV,kS] C_int_m;
  
  // between-subject SDs of the random intercepts (probit scale)
  real<lower=0> A_subj_s;
  real<lower=0> tau_subj_s;
  real<lower=0> gamma_subj_s;
  real<lower=0> C_subj_s;
  
  // standardized subject effects for the non-centered parameterization
  vector[N] A_subj_raw;
  vector[N] tau_subj_raw;
  vector[N] gamma_subj_raw;
  vector[N] C_subj_raw;
  
}

transformed parameters {
  
  // initialize condition-in-subject-level parameters
  matrix<lower=0, upper=1>[N,C] A; // bring alpha to range between 0 and 1
  matrix[N,C] A_normal; // raw alpha without range
  matrix<lower=0, upper=100>[N,C] tau; // bring tau to range between 0 and 100
  matrix[N,C] tau_normal; // tau without range
  matrix<lower=0, upper=1>[N,C] gamma; // bring gamma to range between 0 and 1
  matrix[N,C] gamma_normal; // gamma without range
  matrix<lower=0, upper=1>[N,C] C_const; // bring C to range between 0 and 1
  matrix[N,C] C_normal; // C without range

  //random intercepts per subject
  vector[N] A_vars = A_subj_s*A_subj_raw;
  vector[N] tau_vars = tau_subj_s*tau_subj_raw;
  vector[N] gamma_vars = gamma_subj_s*gamma_subj_raw;
  vector[N] C_vars = C_subj_s*C_subj_raw;
  
  // subject loop
  for (s in 1:N) {
    
    // condition loop
    for (v in 1:C) { // for every condition
      
      // fixed and random intercepts
      A_normal[s,v] = mu[1] + A_vars[s]; // fixed intercept + random intercept per subject
      tau_normal[s,v] = mu[2] + tau_vars[s];
      gamma_normal[s,v] = mu[3] + gamma_vars[s];
      C_normal[s,v] = mu[4] + C_vars[s];
      
      // fixed effects of subject-level variables
      // (kept outside the kv loop so they are added once, not kV times)
      for (ks in 1:kS) {
        A_normal[s,v] += subj_vars[s,ks]*A_sub_m[ks]; // predictor * fixed slope
        tau_normal[s,v] += subj_vars[s,ks]*tau_sub_m[ks];
        gamma_normal[s,v] += subj_vars[s,ks]*gamma_sub_m[ks];
        C_normal[s,v] += subj_vars[s,ks]*C_sub_m[ks];
      }
      
      for (kv in 1:kV) {
        //fixed effects of visit-level variables
        A_normal[s,v] += visit_vars[s,v,kv]*A_sub_con_m[kv]; // predictor * fixed slope
        tau_normal[s,v] += visit_vars[s,v,kv]*tau_sub_con_m[kv];
        gamma_normal[s,v] += visit_vars[s,v,kv]*gamma_sub_con_m[kv];
        C_normal[s,v] += visit_vars[s,v,kv]*C_sub_con_m[kv];
        
        for (ks in 1:kS) {
          //fixed cross-level interactions; the *_int_m matrices are declared [kV,kS], so index [kv,ks]
          A_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*A_int_m[kv,ks]; // predictor * predictor * fixed slope
          tau_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*tau_int_m[kv,ks];
          gamma_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*gamma_int_m[kv,ks];
          C_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*C_int_m[kv,ks];
        }
        
      }
    
    }
    
    //transform to range [0,1] or [0,100]
    A[s,] = Phi_approx(A_normal[s,]);
    tau[s,] = Phi_approx(tau_normal[s,])*100;
    gamma[s,] = Phi_approx(gamma_normal[s,]);
    C_const[s,] = Phi_approx(C_normal[s,]);

  }
  
}


model {
  
  // define prior distributions
  
  // hyperparameters (group-level means)
  mu ~ normal(0, 1);
  
  // fixed slopes of subject-level predictors
  A_sub_m ~ normal(0, 1);
  tau_sub_m ~ normal(0, 1);
  gamma_sub_m ~ normal(0, 1);
  C_sub_m ~ normal(0, 1);
  
  // fixed slopes of visit-level predictors
  A_sub_con_m ~ normal(0,1);
  tau_sub_con_m ~ normal(0,1);
  gamma_sub_con_m ~ normal(0,1);
  C_sub_con_m ~ normal(0,1);
  
  // cross-level interactions
  for (ks in 1:kS) {
    A_int_m[,ks] ~ normal(0,1);
    tau_int_m[,ks] ~ normal(0,1);
    gamma_int_m[,ks] ~ normal(0,1);
    C_int_m[,ks] ~ normal(0,1);
  }
  
  // half-Cauchy priors on the between-subject SDs of the random intercepts (half because of the lower=0 bound)
  A_subj_s ~ cauchy(0,2);
  tau_subj_s ~ cauchy(0,2);
  gamma_subj_s ~ cauchy(0,2);
  C_subj_s ~ cauchy(0,2);
    
  // standard-normal raw subject effects (non-centered parameterization)
  A_subj_raw ~ normal(0,1);
  tau_subj_raw ~ normal(0,1);
  gamma_subj_raw ~ normal(0,1);
  C_subj_raw ~ normal(0,1);
  
  // only execute this part if we want to evaluate likelihood (fit real data)
  if (run_estimation==1){

    // subject loop
    for (s in 1:N) {
      
      // define needed variables
      vector[2] ev; // expected value for both options
      real PE;      // prediction error
      real absPE; // absolute prediction error
      real k; // learning rate per trial
    
      // condition loop
      for (v in 1:C) {
        
        // set initial values
        ev = initV;
        absPE = initabsPE;
        k = A[s,v];
      
        // trial loop
        for (t in 1:Tsubj[s,v]) {
        
          // how does choice relate to inverse temperature and action value
          choice[s,t,v] ~ bernoulli_logit(tau[s,v] * (ev[2]-ev[1])); // inverse temp * Q
          
          // Pearce-Hall learning rate:
          // decay constant * arbitrary constant * absolute PE from last trial + (1 - decay constant) * learning rate from last trial.
          // If the decay constant is close to 1, the dynamic learning rate is driven mainly by the last trial's PE (high fluctuation);
          // if it is close to 0, it is driven mainly by the previous trial's learning rate (low fluctuation).
          k = gamma[s,v]*C_const[s,v]*absPE + (1-gamma[s,v])*k;

        
          // prediction error
          PE = outcome[s,t,v] - ev[choice[s,t,v]+1]; // outcome - Q of choice taken
          absPE = abs(PE);
          
          // value updating (learning)
          ev[choice[s,t,v]+1] += k * PE; // Q + dynamic alpha * PE                                                                                    
      
        }
        
      }
    
    }
  
  }
  
}
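One way to check whether the model block does what I intend is to simulate choices from known parameters with the same update rules and see whether fitting recovers them. The sketch below mirrors the trial loop above for a single subject and condition; `simulate_pearce_hall` is a hypothetical helper, and the reward probabilities (0.2 / 0.8), trial count and parameter values are made up.

```python
import numpy as np

def simulate_pearce_hall(alpha, tau, gamma, c_const, n_trials, p_reward, rng):
    """Simulate one subject/condition with the same updates as the Stan model block.

    Returns choices (0/1, where 1 = option 2), binary outcomes and the trial-wise
    learning rates. p_reward gives the reward probability of option 1 and option 2.
    """
    ev = np.array([0.5, 0.5])     # initV
    abs_pe = 0.5                  # initabsPE
    k = alpha                     # learning rate starts at A[s,v]
    choices = np.empty(n_trials, dtype=int)
    outcomes = np.empty(n_trials, dtype=int)
    ks = np.empty(n_trials)
    for t in range(n_trials):
        # choice ~ bernoulli_logit(tau * (ev[2] - ev[1]))
        p_choose_2 = 1.0 / (1.0 + np.exp(-tau * (ev[1] - ev[0])))
        choice = int(rng.random() < p_choose_2)
        # Pearce-Hall learning rate, using |PE| from the previous trial
        k = gamma * c_const * abs_pe + (1.0 - gamma) * k
        outcome = int(rng.random() < p_reward[choice])
        pe = outcome - ev[choice]
        abs_pe = abs(pe)
        ev[choice] += k * pe      # update only the chosen option
        choices[t], outcomes[t], ks[t] = choice, outcome, k
    return choices, outcomes, ks

rng = np.random.default_rng(2024)
choices, outcomes, ks = simulate_pearce_hall(
    alpha=0.3, tau=5.0, gamma=0.4, c_const=0.6,
    n_trials=120, p_reward=(0.2, 0.8), rng=rng)
print(choices.mean(), ks.min(), ks.max())
```

The simulated arrays could be padded with -999 up to MT and passed as `choice` and `outcome`. Since gamma and C enter the learning rate as a product in the first term, it seems worth checking whether they can be recovered separately at all.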

When I fit this model to my data, the chains do not mix (high R-hat) and I get warnings about low bulk and tail ESS. Before considering whether the model is simply a poor description of the data-generating process, I want to make sure of two things:

Does the Stan code do what I think it does, i.e., is it a valid implementation of the lme4 formulas posted above?

How could I reparameterize the model to make sampling more efficient, given that I am already using a non-centered parameterization?
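For the first question, one way to check the transformed parameters block against the lme4 formula is to compute the fixed-effect part of the linear predictor twice, once with nested loops in the order of the Stan code and once as a matrix product, and compare. The sketch below uses random values and hypothetical predictor counts kS = 2 and kV = 3, stores the interaction slopes as [kV, kS] as in the Stan declarations, and places the subject-level terms outside the kv loop. With kS = kV = 1 the loop placement and index order make no difference, but with more predictors they do.

```python
import numpy as np

rng = np.random.default_rng(7)
N, C, kS, kV = 5, 2, 2, 3                       # hypothetical sizes
subj_vars = rng.normal(size=(N, kS))
visit_vars = rng.normal(size=(N, C, kV))
mu = 0.1
b_subj = rng.normal(size=kS)                    # like A_sub_m
b_visit = rng.normal(size=kV)                   # like A_sub_con_m
b_int = rng.normal(size=(kV, kS))               # like A_int_m, declared matrix[kV, kS]

def eta_loops(subj_vars, visit_vars):
    """Fixed-effect linear predictor built with loops, as in the Stan code."""
    eta = np.full((N, C), mu)
    for s in range(N):
        for v in range(C):
            for ks in range(kS):                # subject-level terms: once per (s, v)
                eta[s, v] += subj_vars[s, ks] * b_subj[ks]
            for kv in range(kV):
                eta[s, v] += visit_vars[s, v, kv] * b_visit[kv]
                for ks in range(kS):            # interaction indexed [kv, ks] to match the declaration
                    eta[s, v] += subj_vars[s, ks] * visit_vars[s, v, kv] * b_int[kv, ks]
    return eta

def eta_matrix(subj_vars, visit_vars):
    """Same fixed effects (1 + subj + visit + subj:visit) written as array products."""
    return (mu
            + (subj_vars @ b_subj)[:, None]
            + visit_vars @ b_visit
            + np.einsum("svk,kj,sj->sv", visit_vars, b_int, subj_vars))

print(np.allclose(eta_loops(subj_vars, visit_vars), eta_matrix(subj_vars, visit_vars)))
```

A design-matrix formulation like this should not change the posterior, but it may make the correspondence to the lme4 formula easier to read.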

Any help would be appreciated.

Best,
Milena