Two-armed bandit hierarchical reinforcement learning model - simplify to random intercept only

Hi there,

Data

As mentioned in my previous post, I want to fit a hierarchical reinforcement learning (RL) model (a Pearce-Hall model following Diederen et al.) to two-armed bandit choice data (0/1) from 71 subjects.
Each subject belongs to one of two groups (group = 0/1) and performed the task twice (condition = 1/2). Each task run consisted of 50 trials.

Modeling goal

The code should perform two things at once:

  1. Run a trial-by-trial reinforcement learning model (in the Stan model block) to estimate the RL parameters (alpha, tau, C, gamma) for each subject and condition.
  2. Run a hierarchical linear model with the RL parameters (alpha, tau, C, gamma) per subject and condition as dependent variables.
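
Written out, the trial-level model in step 1 looks roughly as follows. This is a sketch read off the Stan code further down in this post; the symbols (Q for the expected value, delta for the prediction error, k for the dynamic learning rate, c for the choice, r for the outcome) are notation assumed here for illustration and are not taken from Diederen et al.

```latex
\begin{aligned}
P(c_t = 1) &= \operatorname{logit}^{-1}\!\big(\tau\,[Q_t(1) - Q_t(0)]\big) \\
k_t &= \gamma\, C\, |\delta_{t-1}| + (1-\gamma)\, k_{t-1} \\
\delta_t &= r_t - Q_t(c_t) \\
Q_{t+1}(c_t) &= Q_t(c_t) + k_t\, \delta_t
\end{aligned}
```

The starting values in the code are Q_1(0) = Q_1(1) = 0.5, |delta_0| = 0.5 and k_0 = alpha; the value of the unchosen option is left unchanged on each trial.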

Random-slope-and-intercept model

To translate the hierarchical linear model to Stan, I started from an example by @Vanessa_Brown.

Question 1: How can the hierarchical linear model part of @Vanessa_Brown's code be written in lme4? From what I understand, the model includes per-subject random slopes and possibly also a visit nesting variable (?).

I adapted the example model to my data (see my previous post), but it seems to be overly flexible.

Random-intercept-only model

Now, I want to reduce the model to a random-intercept-only model with group (0/1), condition (0/1), and their interaction as fixed effects, and with a random intercept per subject. In lme4, I would write the hierarchical linear model part as shown below. lme4 would give me estimates for the fixed intercept, the three fixed slopes, the variance of the random intercepts across subjects, and the residual variance.

alpha ~ 1 + group*condition + (1 | subject)
tau ~ 1 + group*condition + (1 | subject)
gamma ~ 1 + group*condition + (1 | subject)
C ~ 1 + group*condition + (1 | subject)
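
For reference, each of these formulas corresponds to the standard linear mixed model below, written for a generic RL parameter theta of subject s in condition v. The symbols are notation assumed here, not lme4 output names. Note that lme4 would need theta as observed data, whereas in the Stan model theta is itself estimated from the choices.

```latex
\begin{aligned}
\theta_{sv} &= \beta_0 + \beta_1\,\mathrm{group}_s + \beta_2\,\mathrm{condition}_{sv}
             + \beta_3\,\mathrm{group}_s \cdot \mathrm{condition}_{sv} + u_s + \varepsilon_{sv} \\
u_s &\sim \mathcal{N}(0, \sigma_u^2), \qquad \varepsilon_{sv} \sim \mathcal{N}(0, \sigma^2)
\end{aligned}
```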

This is my attempt at building the equivalent hierarchical model around the RL model part in Stan, following Sorensen & Vasishth (2016):

// MILENA MUSIAL 12/2023

// input
data {
  int<lower=1> N; // number of subjects
  int<lower=1> C; // number of conditions (reinforcer_type)
  int<lower=1> T; // total number of trials across subjects
  int<lower=1> MT; // max number of trials / subject / condition
  
  int<lower=1, upper=MT> Tsubj[N, C]; // actual number of trials / subject / condition
  
  int<lower=-999, upper=1> choice[N, MT, C]; // choice of correct (1) or incorrect (0) card / trial / subject / condition
  int<lower=-999, upper=1> outcome[N, MT, C];  // outcome / trial / subject / condition
  
  int kS; // number of subj-level variables (aud_group)
  real subj_vars[N,kS]; //subj-level variable matrix (centered)
  int kV; // number of visit-level variables (reinforcer_type)
  real visit_vars[N,C,kV]; //visit-level variable matrix (centered) (X)
  
  int<lower = 0, upper = 1> run_estimation; // a switch to evaluate the likelihood
}

// transformed input 
transformed data {
  vector[2] initV = rep_vector(0.5, 2);  // initial values for EV, both choices have a value of 0.5
  real initabsPE = 0.5;                  // initial absolute prediction error
}

// output - posterior distribution should be sought
parameters {
  
  // Declare all parameters as vectors for vectorizing
  // hyperparameters (group-level means)
  vector[4] mu; // group level (fixed) intercepts for the 4 parameters, for HC and juice (these are coded as 0)

  // Subject-level raw parameters (fixed slope of aud group)
  vector[kS] A_sub_m;    // learning rate
  vector[kS] tau_sub_m;  // inverse temperature
  vector[kS] gamma_sub_m; // decay constant
  vector[kS] C_sub_m; // arbitrary constant
  
  // Condition-level raw parameters (fixed slope of reinforcer type)
  vector[kV] A_sub_con_m;    // learning rate
  vector[kV] tau_sub_con_m;  // inverse temperature
  vector[kV] gamma_sub_con_m;  // decay constant
  vector[kV] C_sub_con_m;  // arbitrary constant
  
  //cross-level interaction effects (fixed interaction effects)
  matrix[kV,kS] A_int_m;
  matrix[kV,kS] tau_int_m;
  matrix[kV,kS] gamma_int_m;
  matrix[kV,kS] C_int_m;
  
  //subject random intercepts
  vector[N] A_vars;
  vector[N] tau_vars;
  vector[N] gamma_vars;
  vector[N] C_vars;
  
  //subject SDs used to calculate subject random intercepts
  real<lower=0> A_subj_s;
  real<lower=0> tau_subj_s;
  real<lower=0> gamma_subj_s;
  real<lower=0> C_subj_s;
  
}

transformed parameters {
  
  // initialize condition-in-subject-level parameters
  matrix<lower=0, upper=1>[N,C] A; // bring alpha to range between 0 and 1
  matrix[N,C] A_normal; // alpha without range
  matrix<lower=0, upper=100>[N,C] tau; // bring tau to range between 0 and 100
  matrix[N,C] tau_normal; // tau without range
  matrix<lower=0, upper=1>[N,C] gamma; // bring gamma to range between 0 and 1
  matrix[N,C] gamma_normal; // gamma without range
  matrix<lower=0, upper=1>[N,C] C_const; // bring C to range between 0 and 1
  matrix[N,C] C_normal; // C without range

  // build each parameter on the unconstrained scale from fixed effects
  // and subject random intercepts (centered parameterization here),
  // then map it to its range with Phi_approx: (0,1) for alpha, gamma, C and (0,100) for tau
  
  // subject loop
  for (s in 1:N) {
    
    // condition loop
    for (v in 1:C) { // for every condition
    
      // fixed and random intercepts
      A_normal[s,v] = mu[1] + A_vars[s]; // fixed intercept + random intercept per subject
      tau_normal[s,v] = mu[2] + tau_vars[s];
      gamma_normal[s,v] = mu[3] + gamma_vars[s];
      C_normal[s,v] = mu[4] + C_vars[s];
      
      for (ks in 1:kS) { 
        //fixed effects of subject-level variables (outside the kv loop so they are added only once)
        A_normal[s,v] += subj_vars[s,ks]*A_sub_m[ks]; // predictor * fixed slope
        tau_normal[s,v] += subj_vars[s,ks]*tau_sub_m[ks];
        gamma_normal[s,v] += subj_vars[s,ks]*gamma_sub_m[ks];
        C_normal[s,v] += subj_vars[s,ks]*C_sub_m[ks];
      }
      
      for (kv in 1:kV) { 
        //fixed effects of visit-level variables
        A_normal[s,v] += visit_vars[s,v,kv]*A_sub_con_m[kv]; // predictor * fixed slope
        tau_normal[s,v] += visit_vars[s,v,kv]*tau_sub_con_m[kv];
        gamma_normal[s,v] += visit_vars[s,v,kv]*gamma_sub_con_m[kv];
        C_normal[s,v] += visit_vars[s,v,kv]*C_sub_con_m[kv];
        
        for (ks in 1:kS) { 
          //fixed cross-level interactions (matrices are declared as [kV,kS], so index [kv,ks])
          A_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*A_int_m[kv,ks]; // predictor * predictor * fixed slope
          tau_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*tau_int_m[kv,ks];
          gamma_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*gamma_int_m[kv,ks];
          C_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*C_int_m[kv,ks];
        }
        
      }
    
    }
    
    //transform to range [0,1] or [0,100]
    A[s,] = Phi_approx(A_normal[s,]);
    tau[s,] = Phi_approx(tau_normal[s,])*100;
    gamma[s,] = Phi_approx(gamma_normal[s,]);
    C_const[s,] = Phi_approx(C_normal[s,]);
    
  }
  
}


model {
  
  // define prior distributions
  
  // hyperparameters (group-level means)
  mu ~ normal(0, 1);
  
  // Subject-level raw parameters
  A_sub_m ~ normal(0, 1);
  tau_sub_m ~ normal(0, 1);
  gamma_sub_m ~ normal(0, 1);
  C_sub_m ~ normal(0, 1);
  
  // Condition-level raw parameters
  A_sub_con_m ~ normal(0,1);
  tau_sub_con_m ~ normal(0,1);
  gamma_sub_con_m ~ normal(0,1);
  C_sub_con_m ~ normal(0,1);
  
  // cross-level interactions
  for (ks in 1:kS) {
    A_int_m[,ks] ~ normal(0,1);
    tau_int_m[,ks] ~ normal(0,1);
    gamma_int_m[,ks] ~ normal(0,1);
    C_int_m[,ks] ~ normal(0,1);
  }
  
  //Subject random intercepts
  A_vars ~ normal(0,A_subj_s);
  tau_vars ~ normal(0,tau_subj_s);
  gamma_vars ~ normal(0,gamma_subj_s);
  C_vars ~ normal(0,C_subj_s);

  // only execute this part if we want to evaluate likelihood (fit real data)
  if (run_estimation==1){

    // subject loop
    for (s in 1:N) {
      
      // define needed variables
      vector[2] ev; // expected value for both options
      real PE;      // prediction error
      real absPE; // absolute prediction error
      real k; // learning rate per trial
    
      // condition loop
      for (v in 1:C) {
        
        // set initial values
        ev = initV;
        absPE = initabsPE;
        k = A[s,v];
      
        // trial loop
        for (t in 1:Tsubj[s,v]) {
        
          // how does choice relate to inverse temperature and action value
          choice[s,t,v] ~ bernoulli_logit(tau[s,v] * (ev[2]-ev[1])); // inverse temp * Q
          
          // Pearce Hall learning rate
          k = gamma[s,v]*C_const[s,v]*absPE + (1-gamma[s,v])*k; // decay constant * arbitrary constant * absolute PE from last trial + (1-decay constant) * learning rate from last trial
                                                    // if decay constant close to 1: dynamic learning rate will be strongly affected by PEs from last trial and only weakly affected by learning rate from previous trial (high fluctuation)
                                                    // if decay constant close to 0: dynamic learning rate will be weakly affected by PEs from last trial and strongly affected by learning rate from previous trial (low fluctuation)

        
          // prediction error
          PE = outcome[s,t,v] - ev[choice[s,t,v]+1]; // outcome - Q of choice taken
          absPE = abs(PE);
          
          // value updating (learning)
          ev[choice[s,t,v]+1] += k * PE; // Q + dynamic alpha * PE                                                                                    
      
        }
        
      }
    
    }
  
  }
  
}

generated quantities {
  
  // Define mean group-level parameter values
  real<lower=0, upper=1> mu_A; // initialize mean of posterior
  real<lower=0, upper=100> mu_tau;
  real<lower=0, upper=1> mu_gamma;
  real<lower=0, upper=1> mu_C;

  // For log likelihood calculation
  real log_lik[N,MT,C];
  
  // for choice probability calculation (of chosen option)
  real softmax_ev_chosen[N,MT,C];

  // For posterior predictive check
  int y_pred[N,MT,C];
  
  // extracting PEs per subject and trial
  real PE_pred[N,MT,C];
  
  // extracting q values per subject and trial
  real ev_pred[N,MT,C,2];
  real ev_chosen_pred[N,MT,C];
  
  // extracting dynamic learning rate per subject and trial
  real k_pred[N,MT,C];

  // Set all PE and ev predictions to -999 (avoids NULL values)
  for (s in 1:N) {
    for (v in 1:C) {
      for (t in 1:MT) {
        y_pred[s,t,v] = -999;
        PE_pred[s,t,v] = -999;
        ev_chosen_pred[s,t,v] = -999;
        k_pred[s,t,v] = -999;
        softmax_ev_chosen[s,t,v] = -999;
        log_lik[s,t,v] = -999;
        for (c in 1:2) {
          ev_pred[s,t,v,c] = -999;
        }
      }
    }
  }
  
  // calculate overall means of parameters
  mu_A   = Phi_approx(mu[1]); 
  mu_tau = Phi_approx(mu[2]) * 100;
  mu_gamma   = Phi_approx(mu[3]);
  mu_C   = Phi_approx(mu[4]);

  { // local section, this saves time and space
    for (s in 1:N) {
      
      vector[2] ev; // expected value
      real PE;      // prediction error
      real absPE; // absolute prediction error
      real k; // learning rate
      vector[2] softmax_ev; // softmax per ev

      for (v in 1:C) {
        
        // initialize values
        ev = initV;
        absPE = initabsPE;
        k = A[s,v];
      
        // quantities of interest
        for (t in 1:Tsubj[s,v]) {
          
          // generate prediction for current trial
          // if estimation = 1, we draw from the posterior
          // if estimation = 0, we equally draw from the posterior, but the posterior is equal to the prior as likelihood is not evaluated
          y_pred[s,t,v] = bernoulli_logit_rng(tau[s,v] * (ev[2]-ev[1])); // following the recommendation to use the same function as in model block but with rng ending
          
          // if estimation = 1, compute quantities of interest based on actual choices
          if (run_estimation==1){
            
            // compute log likelihood of current trial
            log_lik[s,t,v] = bernoulli_logit_lpmf(choice[s,t,v] | tau[s,v] * (ev[2]-ev[1]));
            
            // compute choice probability
            softmax_ev = softmax(tau[s,v]*ev);
            
            softmax_ev_chosen[s,t,v] = softmax_ev[choice[s,t,v]+1];
            
            // Pearce Hall learning rate
            k = gamma[s,v]*C_const[s,v]*absPE + (1-gamma[s,v])*k;
            k_pred[s,t,v] = k;
            
            // prediction error
            PE = outcome[s,t,v] - ev[choice[s,t,v]+1];
            PE_pred[s,t,v] = PE;
          
            // value updating (learning)
            ev[choice[s,t,v]+1] += k * PE;
            
            ev_pred[s,t,v,1] = ev[1]; // copy both evs into pred
            ev_pred[s,t,v,2] = ev[2]; // copy both evs into pred
            
            ev_chosen_pred[s,t,v] = ev[choice[s,t,v]+1];
          
          }
        
          // if estimation = 0, compute quantities of interest based on simulated choices
          if (run_estimation==0){
          
            // compute log likelihood of current trial
            log_lik[s,t,v] = bernoulli_logit_lpmf(y_pred[s,t,v] | tau[s,v] * (ev[2]-ev[1]));
          
            // Pearce Hall learning rate
            k = gamma[s,v]*C_const[s,v]*absPE + (1-gamma[s,v])*k;
            k_pred[s,t,v] = k;
            
            // prediction error
            PE = outcome[s,t,v] - ev[y_pred[s,t,v]+1];
            PE_pred[s,t,v] = PE;
            
            // value updating (learning)
            ev[y_pred[s,t,v]+1] += k * PE;
          
            ev_pred[s,t,v,1] = ev[1]; // copy both evs into pred
            ev_pred[s,t,v,2] = ev[2]; // copy both evs into pred
            
            ev_chosen_pred[s,t,v] = ev[y_pred[s,t,v]+1];
          
          }
        
        } // trial loop
    
      } // condition loop
  
    } // subject loop

  } // local section
  
} // generated quantities

Question 2: Is the hierarchical linear model part equivalent to the lme4 code above? Does the residual variance have to be defined as a separate parameter?
Taking alpha (A) as an example, the components should map as follows:

  • mu[1] = fixed intercept
  • A_sub_m = fixed effect of group
  • A_sub_con_m = fixed effect of condition
  • A_int_m = fixed interaction effect
  • A_vars = random intercept per subject
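
Under this mapping, and assuming kS = kV = 1 (one group predictor g and one condition predictor c), the quantity that the posted transformed parameters block computes for alpha can be written as below. The symbols are shorthand assumed here: mu_1 = mu[1], beta_g = A_sub_m, beta_c = A_sub_con_m, beta_gc = A_int_m, u_s = A_vars[s], sigma_u = A_subj_s, and Phi stands for Phi_approx.

```latex
\alpha_{sv} = \Phi\big(\mu_1 + \beta_g\, g_s + \beta_c\, c_{sv} + \beta_{gc}\, g_s\, c_{sv} + u_s\big),
\qquad u_s \sim \mathcal{N}(0, \sigma_u^2)
```

As posted, the code has no separate residual term per subject and condition: alpha is a deterministic function of the fixed effects and the subject intercept, and the trial-level noise enters only through the Bernoulli choice likelihood.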

This post is related to some of @carinaufer's and @simondesch's posts. Any feedback from someone familiar with frequentist hierarchical linear models would be appreciated.

Best,
Milena

Hi there,

I made some progress on answering the open questions myself, but I am still unsure whether my current random-intercept model does what I think it does. Here are my updates:

I think the hierarchical model part I adapted from @Vanessa_Brown's code in my previous post can be translated into the following lme4 code:

alpha ~ 1 + group*condition + (1 + condition | subject) + (1 | visit)
tau ~  1 + group*condition + (1 + condition | subject) + (1 | visit)
gamma ~ 1 + group*condition + (1 + condition | subject) + (1 | visit)
C ~ 1 + group*condition + (1 + condition | subject) + (1 | visit)
  • As condition (0,1) is equivalent to the visit number (0,1) in my case and I am interested in the effect of condition, including a random intercept per condition/visit makes no sense for my purposes.

  • As I only have one observation per condition and subject, estimating a random slope for condition per subject also makes no sense.

The model I want is thus still:

alpha ~ 1 + group*condition + (1 | subject)
tau ~ 1 + group*condition + (1 | subject)
gamma ~ 1 + group*condition + (1 | subject)
C ~ 1 + group*condition + (1 | subject)

I corrected my previous Stan implementation of this model according to the really helpful examples by Julian Faraway. Specifically, the standardized random effect per subject (e.g. A_subj_raw) is now multiplied by the standard deviation of the intercepts across subjects (e.g. A_subj_s), i.e. the random intercepts use a non-centered parameterization.
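
In isolation, this non-centered parameterization can be sketched as follows. This is a minimal stand-alone illustration, not the RL model itself: it assumes a hypothetical observed outcome y with one value per subject and condition, and the names y, sigma_u, sigma, u_raw and u are chosen only for this example.

```stan
data {
  int<lower=1> N;          // number of subjects
  int<lower=1> C;          // number of conditions
  matrix[N, C] y;          // hypothetical observed outcome per subject and condition
}
parameters {
  real mu;                 // fixed intercept
  real<lower=0> sigma_u;   // SD of random intercepts across subjects
  real<lower=0> sigma;     // residual SD
  vector[N] u_raw;         // standardized subject effects
}
transformed parameters {
  // scaling the standardized effects by the SD implies u ~ normal(0, sigma_u)
  vector[N] u = sigma_u * u_raw;
}
model {
  mu ~ normal(0, 1);
  sigma_u ~ cauchy(0, 2);  // half-Cauchy because of the lower=0 bound
  sigma ~ cauchy(0, 2);
  u_raw ~ std_normal();
  for (v in 1:C) {
    col(y, v) ~ normal(mu + u, sigma);  // same subject intercept in every condition
  }
}
```

The sampler then works on u_raw, whose prior scale does not depend on sigma_u, which is the usual motivation for this form when the data carry little information about the individual subject effects.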

// input
data {
  int<lower=1> N; // number of subjects
  int<lower=1> C; // number of visits
  int<lower=1> T; // total number of trials (observations) across subjects
  int<lower=1> MT; // max number of trials / subject / condition
  
  int<lower=1, upper=MT> Tsubj[N, C]; // actual number of trials / subject / condition
  
  int<lower=-999, upper=1> choice[N, MT, C]; // choice of correct (1) or incorrect (0) card / trial / subject / condition
  int<lower=-999, upper=1> outcome[N, MT, C];  // outcome / trial / subject / condition
  
  int kS; // number of subj-level predictors (aud_group)
  real subj_vars[N,kS]; //subj-level variable matrix (centered) - aud_group per subject
  int kV; // number of visit-level predictors (reinforcer_type)
  real visit_vars[N,C,kV]; //visit-level variable matrix (centered) (reinforcer_type per subject and visit)
  
  int<lower = 0, upper = 1> run_estimation; // a switch to evaluate the likelihood
}

// transformed input 
transformed data {
  vector[2] initV = rep_vector(0.5, 2);  // initial values for EV, both choices have a value of 0.5
  real initabsPE = 0.5;                  // initial absolute prediction error
}

// output - posterior distribution should be sought
parameters {
  
  // Declare all parameters as vectors for vectorizing
  // hyperparameters (group-level means)
  vector[4] mu; // fixed intercepts for the 4 parameters, for HC and juice (these are coded as 0)

  // Subject-level raw parameters (fixed slope of aud group)
  vector[kS] A_sub_m;    // learning rate
  vector[kS] tau_sub_m;  // inverse temperature
  vector[kS] gamma_sub_m; // decay constant
  vector[kS] C_sub_m; // arbitrary constant
  
  // Condition-level raw parameters (fixed slope of reinforcer type)
  vector[kV] A_sub_con_m;    // learning rate
  vector[kV] tau_sub_con_m;  // inverse temperature
  vector[kV] gamma_sub_con_m;  // decay constant
  vector[kV] C_sub_con_m;  // arbitrary constant
  
  //cross-level interaction effects (fixed interaction effects)
  matrix[kV,kS] A_int_m;
  matrix[kV,kS] tau_int_m;
  matrix[kV,kS] gamma_int_m;
  matrix[kV,kS] C_int_m;
  
  //between-subject SD (standard deviation of random intercepts across subjects)
  real<lower=0> A_subj_s;
  real<lower=0> tau_subj_s;
  real<lower=0> gamma_subj_s;
  real<lower=0> C_subj_s;
  
  //standardized (raw) subject random intercepts for the non-centered parameterization
  vector[N] A_subj_raw;
  vector[N] tau_subj_raw;
  vector[N] gamma_subj_raw;
  vector[N] C_subj_raw;
  
}

transformed parameters {
  
  // initialize condition-in-subject-level parameters
  matrix<lower=0, upper=1>[N,C] A; // bring alpha to range between 0 and 1
  matrix[N,C] A_normal; // raw alpha without range
  matrix<lower=0, upper=100>[N,C] tau; // bring tau to range between 0 and 100
  matrix[N,C] tau_normal; // tau without range
  matrix<lower=0, upper=1>[N,C] gamma; // bring gamma to range between 0 and 1
  matrix[N,C] gamma_normal; // gamma without range
  matrix<lower=0, upper=1>[N,C] C_const; // bring C to range between 0 and 1
  matrix[N,C] C_normal; // C without range

  //random intercepts per subject
  vector[N] A_vars = A_subj_s*A_subj_raw;
  vector[N] tau_vars = tau_subj_s*tau_subj_raw;
  vector[N] gamma_vars = gamma_subj_s*gamma_subj_raw;
  vector[N] C_vars = C_subj_s*C_subj_raw;
  
  // subject loop
  for (s in 1:N) {
    
    // condition loop
    for (v in 1:C) { // for every condition
      
      // fixed and random intercepts
      A_normal[s,v] = mu[1] + A_vars[s]; // fixed intercept + random intercept per subject
      tau_normal[s,v] = mu[2] + tau_vars[s];
      gamma_normal[s,v] = mu[3] + gamma_vars[s];
      C_normal[s,v] = mu[4] + C_vars[s];
      
      for (ks in 1:kS) { 
        //fixed effects of subject-level variables (outside the kv loop so they are added only once)
        A_normal[s,v] += subj_vars[s,ks]*A_sub_m[ks]; // predictor * fixed slope
        tau_normal[s,v] += subj_vars[s,ks]*tau_sub_m[ks];
        gamma_normal[s,v] += subj_vars[s,ks]*gamma_sub_m[ks];
        C_normal[s,v] += subj_vars[s,ks]*C_sub_m[ks];
      }
      
      for (kv in 1:kV) { 
        //fixed effects of visit-level variables
        A_normal[s,v] += visit_vars[s,v,kv]*A_sub_con_m[kv]; // predictor * fixed slope
        tau_normal[s,v] += visit_vars[s,v,kv]*tau_sub_con_m[kv];
        gamma_normal[s,v] += visit_vars[s,v,kv]*gamma_sub_con_m[kv];
        C_normal[s,v] += visit_vars[s,v,kv]*C_sub_con_m[kv];
        
        for (ks in 1:kS) { 
          //fixed cross-level interactions (matrices are declared as [kV,kS], so index [kv,ks])
          A_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*A_int_m[kv,ks]; // predictor * predictor * fixed slope
          tau_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*tau_int_m[kv,ks];
          gamma_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*gamma_int_m[kv,ks];
          C_normal[s,v] += subj_vars[s,ks]*visit_vars[s,v,kv]*C_int_m[kv,ks];
        }
        
      }
    
    }
    
    //transform to range [0,1] or [0,100]
    A[s,] = Phi_approx(A_normal[s,]);
    tau[s,] = Phi_approx(tau_normal[s,])*100;
    gamma[s,] = Phi_approx(gamma_normal[s,]);
    C_const[s,] = Phi_approx(C_normal[s,]);

  }
  
}


model {
  
  // define prior distributions
  
  // hyperparameters (group-level means)
  mu ~ normal(0, 1);
  
  // Subject-level raw parameters
  A_sub_m ~ normal(0, 1);
  tau_sub_m ~ normal(0, 1);
  gamma_sub_m ~ normal(0, 1);
  C_sub_m ~ normal(0, 1);
  
  // Condition-level raw parameters
  A_sub_con_m ~ normal(0,1);
  tau_sub_con_m ~ normal(0,1);
  gamma_sub_con_m ~ normal(0,1);
  C_sub_con_m ~ normal(0,1);
  
  // cross-level interactions
  for (ks in 1:kS) {
    A_int_m[,ks] ~ normal(0,1);
    tau_int_m[,ks] ~ normal(0,1);
    gamma_int_m[,ks] ~ normal(0,1);
    C_int_m[,ks] ~ normal(0,1);
  }
  
  //SDs of the random intercepts across subjects
  A_subj_s ~ cauchy(0,2);
  tau_subj_s ~ cauchy(0,2);
  gamma_subj_s ~ cauchy(0,2);
  C_subj_s ~ cauchy(0,2);
    
  //standardized subject effects (non-centered parameterization)
  A_subj_raw ~ normal(0,1);
  tau_subj_raw ~ normal(0,1);
  gamma_subj_raw ~ normal(0,1);
  C_subj_raw ~ normal(0,1);
  
  // only execute this part if we want to evaluate likelihood (fit real data)
  if (run_estimation==1){

    // subject loop
    for (s in 1:N) {
      
      // define needed variables
      vector[2] ev; // expected value for both options
      real PE;      // prediction error
      real absPE; // absolute prediction error
      real k; // learning rate per trial
    
      // condition loop
      for (v in 1:C) {
        
        // set initial values
        ev = initV;
        absPE = initabsPE;
        k = A[s,v];
      
        // trial loop
        for (t in 1:Tsubj[s,v]) {
        
          // how does choice relate to inverse temperature and action value
          choice[s,t,v] ~ bernoulli_logit(tau[s,v] * (ev[2]-ev[1])); // inverse temp * Q
          
          // Pearce Hall learning rate
          k = gamma[s,v]*C_const[s,v]*absPE + (1-gamma[s,v])*k; // decay constant * arbitrary constant * absolute PE from last trial + (1-decay constant) * learning rate from last trial
                                                    // if decay constant close to 1: dynamic learning rate will be strongly affected by PEs from last trial and only weakly affected by learning rate from previous trial (high fluctuation)
                                                    // if decay constant close to 0: dynamic learning rate will be weakly affected by PEs from last trial and strongly affected by learning rate from previous trial (low fluctuation)

        
          // prediction error
          PE = outcome[s,t,v] - ev[choice[s,t,v]+1]; // outcome - Q of choice taken
          absPE = abs(PE);
          
          // value updating (learning)
          ev[choice[s,t,v]+1] += k * PE; // Q + dynamic alpha * PE                                                                                    
      
        }
        
      }
    
    }
  
  }
  
}

When fitting this model to my data, the chains do not mix (high R-hat) and I get warnings about low bulk and tail ESS. Before considering whether the model is simply a poor description of the data-generating process, I want to make sure of two things:

Does the Stan code do what I think it does, i.e. is the Stan code a valid implementation of the lme4 code posted above?

How could I reparameterize the model to make sampling more efficient, especially since I am already using non-centered parameterization?

Any help would be appreciated.

Best,
Milena

Have you considered using brms, since it supports lme4-style formulas?

Hi Milena,

I think your Stan code does correspond to your lme4 syntax. It may not fix the low effective sample sizes, but I think replacing the probit-type transformation (Phi_approx) with the inv_logit function could reduce numerical errors during sampling a little.
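
As a sketch of what this suggestion might look like, the toy program below uses inv_logit in place of Phi_approx for two bounded parameters. It is a stand-alone illustration that only samples from the priors, not the full RL model; the names A_normal, A, tau_normal and tau mirror the posted code, and everything else is assumed for the example.

```stan
data {
  int<lower=1> N;                 // number of subjects
}
parameters {
  vector[N] A_normal;             // unconstrained learning rate
  vector[N] tau_normal;           // unconstrained inverse temperature
}
transformed parameters {
  // logistic transform instead of Phi_approx
  vector<lower=0, upper=1>[N] A = inv_logit(A_normal);
  vector<lower=0, upper=100>[N] tau = 100 * inv_logit(tau_normal);
}
model {
  A_normal ~ normal(0, 1);
  tau_normal ~ normal(0, 1);
}
```

One thing to keep in mind with this swap: a normal(0, 1) prior on the unconstrained scale implies a roughly uniform prior on (0, 1) under Phi_approx, but a prior more concentrated around 0.5 under inv_logit, so the prior scales may need to be revisited.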
