Repeated measure hierarchical reinforcement learning with groups (2 x 2 x 2 design)

Hi Carina,

For the poor model fits (high Rhat): it sounds like you have three blocks per condition, two conditions per session, and repeated sessions per person (where people are observed/not observed and given drug/placebo), correct? Are you modeling all three blocks in a condition together? If I understand your design correctly, that would be the right way to do it, and it would give you 48 trials per condition, which is probably enough.

  1. I would recommend moving away from Phi_approx (I think hBayesDM carried that over from JAGS), especially for your inverse temperature parameter. There's no reason to transform or upper-bound that parameter, as long as you set a lower bound (using <lower=0>). I use inv_logit to constrain learning rates, but Phi_approx may work OK for that parameter. I agree with the recommendation to use prior predictive checks to check the range of your priors.
  2. In your code, the choices are coded from -1 to 2, and you use categorical_logit to predict them. If you only have two choices per block, you can use bernoulli_logit with choices coded as 0 or 1. What does the -1 to 2 coding of your choices mean now?
  3. It sounds like participants are learning well overall. Is that the case for most participants? In my experience, hierarchical models can handle a few poor learners fine, but if there are many of them, that may start causing problems.
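
To make the prior predictive check concrete, here is a minimal Python/NumPy sketch (not part of the original model) that draws learning rates from a hypothetical hierarchical prior and pushes them through either inv_logit or Phi_approx. Phi_approx is implemented from Stan's documented definition, inv_logit(0.07056 x^3 + 1.5976 x). The group-level priors used here (normal(0, 5) vs normal(0, 1.5) on the mean, half-normal(0, 1) on the SD) are illustrative assumptions, not recommendations. If most of the simulated learning rates pile up near 0 or 1, the priors on the unconstrained scale are probably too wide.

```python
import numpy as np


def inv_logit(x):
    """Logistic function (Stan's inv_logit), written via tanh to avoid overflow."""
    return 0.5 * (1.0 + np.tanh(0.5 * np.asarray(x, dtype=float)))


def phi_approx(x):
    """Stan's documented Phi_approx: inv_logit(0.07056 * x^3 + 1.5976 * x)."""
    x = np.asarray(x, dtype=float)
    return inv_logit(0.07056 * x**3 + 1.5976 * x)


def prior_predictive_alpha(mu_sd, sigma_sd, transform, n_draws=20_000, seed=1):
    """Simulate subject-level learning rates implied by a hypothetical prior:
    group mean ~ normal(0, mu_sd), group SD ~ half-normal(0, sigma_sd),
    subject offset ~ normal(0, 1) (non-centered), alpha = transform(mean + SD * offset)."""
    rng = np.random.default_rng(seed)
    mu = rng.normal(0.0, mu_sd, n_draws)
    sigma = np.abs(rng.normal(0.0, sigma_sd, n_draws))
    raw = rng.normal(0.0, 1.0, n_draws)
    return transform(mu + sigma * raw)


def share_near_bounds(alpha, eps=0.01):
    """Fraction of simulated learning rates within eps of 0 or 1."""
    alpha = np.asarray(alpha)
    return float(np.mean((alpha < eps) | (alpha > 1.0 - eps)))


for label, mu_sd in [("normal(0, 5) mean", 5.0), ("normal(0, 1.5) mean", 1.5)]:
    for name, f in [("inv_logit", inv_logit), ("Phi_approx", phi_approx)]:
        a = prior_predictive_alpha(mu_sd, 1.0, f)
        print(f"{label:20s} {name:10s} share near 0/1: {share_near_bounds(a):.2f}")
```

Because Phi_approx is steeper than inv_logit around 0, the same unconstrained draws land closer to the bounds under Phi_approx. This is one reason the two transforms need different prior scales.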

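With two options, a softmax over the two values is the same as a logistic function of their difference. That is why bernoulli_logit with 0/1-coded choices can replace categorical_logit. The Python sketch below checks this numerically; the function names are hypothetical helpers, not Stan functions. Also note that Stan's categorical distributions expect outcomes coded 1..K, so choices coded -1 or 0 cannot be passed to categorical_logit directly.

```python
import numpy as np


def p_option2_softmax(beta, q1, q2):
    """P(option 2) from a two-option softmax, as categorical_logit(beta * [q1, q2]') gives."""
    logits = beta * np.array([q1, q2], dtype=float)
    w = np.exp(logits - logits.max())  # subtract the max for numerical stability
    return float(w[1] / w.sum())


def p_choice1_bernoulli_logit(beta, q1, q2):
    """P(choice == 1) under bernoulli_logit(beta * (q2 - q1)), with choices coded 0/1."""
    return float(1.0 / (1.0 + np.exp(-beta * (q2 - q1))))


def recode_1_2_to_0_1(choices):
    """Map choices coded 1/2 (categorical) to 0/1 (bernoulli); option 2 becomes 1."""
    choices = np.asarray(choices)
    if not np.all(np.isin(choices, [1, 2])):
        raise ValueError("expected choices coded 1 or 2")
    return choices - 1


for beta, q1, q2 in [(0.5, 0.2, 0.7), (3.0, 0.9, 0.1), (8.0, 0.4, 0.45)]:
    print(beta, q1, q2, p_option2_softmax(beta, q1, q2), p_choice1_bernoulli_logit(beta, q1, q2))
```
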
Comparison to null model: I'm not sure what the paper you linked to was doing, but a null (chance) model just assumes that each choice has a 50% chance, right? If you want to do model comparison with BIC (though other methods can be used with Stan), you just need the number of trials (which is constant), the number of parameters, and the likelihood per trial. For a null model, the number of parameters is 0 and the likelihood per trial is 0.5. Once you calculate that, you can compare it to the BIC you get from a model with parameters. Note that the fitted models give you a log-likelihood (LL) per post-warmup iteration, so you will want to average over those.
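
Here is a minimal Python sketch of that BIC comparison. It assumes you have extracted the pointwise log-likelihood (for example, the log_lik generated quantity in the code below) as a draws x observations array. BIC = k ln(n) - 2 ln L, and lower is better. For the chance model, ln L = n ln(0.5), so with 48 two-option choices its BIC is 96 ln 2, about 66.5. Averaging the LL over draws follows the suggestion above. Strictly, BIC is defined with the maximized likelihood, and the average LL over posterior draws can never exceed it, so treat the result as an approximation. In the Stan code below, missing trials are stored as log_lik = 0, so drop them before counting n.

```python
import numpy as np


def bic(total_log_lik, n_params, n_obs):
    """Bayesian information criterion: k * ln(n) - 2 * lnL (lower is better)."""
    return n_params * np.log(n_obs) - 2.0 * total_log_lik


def null_model_bic(n_obs, p_choice=0.5):
    """Chance model: no free parameters; every observed choice has probability p_choice."""
    return bic(n_obs * np.log(p_choice), 0, n_obs)


def fitted_model_bic(log_lik_draws, n_params, observed=None):
    """log_lik_draws: (n_draws, n_obs) pointwise log-likelihood from post-warmup draws.
    observed: optional boolean mask (n_obs,) that drops missing trials stored as 0.
    Sums LL over observations within each draw, then averages over draws."""
    ll = np.asarray(log_lik_draws, dtype=float)
    if observed is not None:
        ll = ll[:, np.asarray(observed, dtype=bool)]
    mean_total_ll = ll.sum(axis=1).mean()
    return bic(mean_total_ll, n_params, ll.shape[1])


# usage with simulated stand-in draws (not real model output)
rng = np.random.default_rng(0)
fake_ll = np.log(rng.uniform(0.6, 0.9, size=(4000, 48)))
print(null_model_bic(48), fitted_model_bic(fake_ll, n_params=2))
```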

Estimating effects within and across sessions: this is going to be tricky. I agree with the recommendations above that you get the models for each condition working first, since they'll need to be in good shape before you can start looking at group differences. Below I've pasted code for a model I'm working with (a different task: Daw's two-step model-based/model-free task) that incorporates repeated measures and within- and across-subject effects. You'll still need to add the within-condition effects, but hopefully this gives you a starting point. I've checked this code pretty thoroughly, but no one else really has, so use at your own risk :)

data {
  int<lower=1> nT; //trials per visit
  int<lower=1> nS; //# of subjects
  int<lower=2> nV; //# of visits per subject
  int<lower=0,upper=1> choice[nS,nT,2,nV]; 
  int<lower=0,upper=1> reward[nS,nT,nV];
  int<lower=1,upper=2> state_2[nS,nT,nV]; 
  int missing_choice[nS,nT,2,nV];
  int missing_reward[nS,nT,nV];
  int kS; //# of subj-level variables (L)
  real subj_vars[nS,kS]; //subj-level variable matrix (centered)
  int kV; //# of visit-level variables (K)
  real visit_vars[nS,nV,kV]; //visit-level variable matrix (centered) (X)
  real missing_visit[nS,nV];
}

parameters {
  //group-level means (y00)
  real alpha_m;
  real<lower=0> beta_1_MF_m;
  real<lower=0> beta_1_MB_m;
  real<lower=0> beta_2_m;
  real pers_m;
  
  //subj-level effects (y01:yks)
  vector[kS] alpha_grp_m;
  vector[kS] beta_1_MF_grp_m;
  vector[kS] beta_1_MB_grp_m;
  vector[kS] beta_2_grp_m;
  vector[kS] pers_grp_m;
  
  //visit-level effects - mean (y10:ykv0)
  vector[kV] alpha_visit_grp_m;
  vector[kV] beta_1_MF_visit_grp_m;
  vector[kV] beta_1_MB_visit_grp_m;
  vector[kV] beta_2_visit_grp_m;
  vector[kV] pers_visit_grp_m;
  
  //cross-level interaction effects (y11:ykvks)
  matrix[kV,kS] alpha_int_m;
  matrix[kV,kS] beta_1_MF_int_m;
  matrix[kV,kS] beta_1_MB_int_m;
  matrix[kV,kS] beta_2_int_m;
  matrix[kV,kS] pers_int_m;
  
  //visit-level (within subject) SDs (sigma2_y)
  real<lower=0> alpha_visit_s;
  real<lower=0> beta_1_MF_visit_s;
  real<lower=0> beta_1_MB_visit_s;
  real<lower=0> beta_2_visit_s;
  real<lower=0> pers_visit_s;
  
  //SDs of visit-level effects
  vector<lower=0>[kV+1] alpha_subj_s;
  vector<lower=0>[kV+1] beta_1_MF_subj_s;
  vector<lower=0>[kV+1] beta_1_MB_subj_s;
  vector<lower=0>[kV+1] beta_2_subj_s;
  vector<lower=0>[kV+1] pers_subj_s;
  
  //non-centered parameterization (ncp) variance effect per visit & subject
  matrix[nS,nV] alpha_visit_raw;
  matrix[nS,nV] beta_1_MF_visit_raw;
  matrix[nS,nV] beta_1_MB_visit_raw;
  matrix[nS,nV] beta_2_visit_raw;
  matrix[nS,nV] pers_visit_raw;
  
  //NCP variance effect on subj-level effects
  matrix[kV+1,nS] alpha_subj_raw;
  matrix[kV+1,nS] beta_1_MF_subj_raw;
  matrix[kV+1,nS] beta_1_MB_subj_raw;
  matrix[kV+1,nS] beta_2_subj_raw;
  matrix[kV+1,nS] pers_subj_raw;
  
  //Cholesky factors of correlation matrices for subj-level variances
  cholesky_factor_corr[kV+1] alpha_subj_L;
  cholesky_factor_corr[kV+1] beta_1_MF_subj_L;
  cholesky_factor_corr[kV+1] beta_1_MB_subj_L;
  cholesky_factor_corr[kV+1] beta_2_subj_L;
  cholesky_factor_corr[kV+1] pers_subj_L;
}

transformed parameters {
  //define transformed parameters
  matrix<lower=0,upper=1>[nS,nV] alpha;
  matrix[nS,nV] alpha_normal;
  matrix[nS,nV] beta_1_MF;
  matrix[nS,nV] beta_1_MB;
  matrix[nS,nV] beta_2;
  matrix[nS,nV] pers;
  
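  //build correlated subject-level random effects (intercept + one slope per visit-level variable)
  // from the SDs, the Cholesky factor of the correlation matrix, and the NCP raw draws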
  matrix[nS,kV+1] alpha_vars = (diag_pre_multiply(alpha_subj_s,alpha_subj_L)*alpha_subj_raw)';
  matrix[nS,kV+1] beta_1_MF_vars = (diag_pre_multiply(beta_1_MF_subj_s,
    beta_1_MF_subj_L)*beta_1_MF_subj_raw)';
  matrix[nS,kV+1] beta_1_MB_vars = (diag_pre_multiply(beta_1_MB_subj_s,
    beta_1_MB_subj_L)*beta_1_MB_subj_raw)';
  matrix[nS,kV+1] beta_2_vars = (diag_pre_multiply(beta_2_subj_s,
    beta_2_subj_L)*beta_2_subj_raw)';
  matrix[nS,kV+1] pers_vars = (diag_pre_multiply(pers_subj_s,pers_subj_L)*pers_subj_raw)';
  
  //create transformed parameters using non-centered parameterization for all
  // and logistic transformation for alpha (range: 0 to 1),
  // add in subject and visit-level effects as shifts in means
  for (s in 1:nS) {
    alpha_normal[s,] = alpha_m+alpha_visit_s*alpha_visit_raw[s,] + alpha_vars[s,1]; 
    beta_1_MF[s,] = beta_1_MF_m + beta_1_MF_visit_s*beta_1_MF_visit_raw[s,] + 
      beta_1_MF_vars[s,1];
    beta_1_MB[s,] = beta_1_MB_m + beta_1_MB_visit_s*beta_1_MB_visit_raw[s,] + 
      beta_1_MB_vars[s,1];
    beta_2[s,] = beta_2_m + beta_2_visit_s*beta_2_visit_raw[s,] + beta_2_vars[s,1];
    pers[s,] = pers_m + pers_visit_s*pers_visit_raw[s,] + pers_vars[s,1];
    
    //add subj-level main effects (once per subject, shared by all visits)
    for (k in 1:kS) { 
      alpha_normal[s,] += alpha_grp_m[k]*subj_vars[s,k];
      beta_1_MF[s,] += beta_1_MF_grp_m[k]*subj_vars[s,k];
      beta_1_MB[s,] += beta_1_MB_grp_m[k]*subj_vars[s,k];
      beta_2[s,] += beta_2_grp_m[k]*subj_vars[s,k];
      pers[s,] += pers_grp_m[k]*subj_vars[s,k];
    }
  
    //add visit-level main effects and cross-level interactions (non-missing visits only)
    for (v in 1:nV) {
      if (missing_visit[s,v]==0) {
        for (kk in 1:kV) { 
          //main effects of visit-level variables (fixed effect + subject-specific slope)
          alpha_normal[s,v] += visit_vars[s,v,kk]*(alpha_visit_grp_m[kk]+alpha_vars[s,kk+1]);
          beta_1_MF[s,v] += visit_vars[s,v,kk]*(beta_1_MF_visit_grp_m[kk]+beta_1_MF_vars[s,kk+1]);
          beta_1_MB[s,v] += visit_vars[s,v,kk]*(beta_1_MB_visit_grp_m[kk]+beta_1_MB_vars[s,kk+1]);
          beta_2[s,v] += visit_vars[s,v,kk]*(beta_2_visit_grp_m[kk]+beta_2_vars[s,kk+1]);
          pers[s,v] += visit_vars[s,v,kk]*(pers_visit_grp_m[kk]+pers_vars[s,kk+1]);
          
          //cross-level interactions
          for (k in 1:kS) { 
            alpha_normal[s,v] += subj_vars[s,k]*visit_vars[s,v,kk]*alpha_int_m[kk,k];
            beta_1_MF[s,v] += subj_vars[s,k]*visit_vars[s,v,kk]*beta_1_MF_int_m[kk,k];
            beta_1_MB[s,v] += subj_vars[s,k]*visit_vars[s,v,kk]*beta_1_MB_int_m[kk,k];
            beta_2[s,v] += subj_vars[s,k]*visit_vars[s,v,kk]*beta_2_int_m[kk,k];
            pers[s,v] += subj_vars[s,k]*visit_vars[s,v,kk]*pers_int_m[kk,k];
          }
        }
      }
    }
    //transform alpha to [0,1]
    alpha[s,] = inv_logit(alpha_normal[s,]); 
  }
}

model {
  //define variables
  int prev_choice;
  int tran_count;
  int tran_type[2];
  int unc_state;
  // real delta_1;
  // real delta_2;
  real Q_TD[2];
  real Q_MB[2];
  real Q_2[2,2];
  
  //define priors
  alpha_m ~ normal(0,2.5);
  beta_1_MF_m ~ normal(0,5);
  beta_1_MB_m ~ normal(0,5);
  beta_2_m ~ normal(0,5);
  pers_m ~ normal(0,2.5);
  
  alpha_grp_m ~ normal(0,1);
  beta_1_MF_grp_m ~ normal(0,1);
  beta_1_MB_grp_m ~ normal(0,1);
  beta_2_grp_m ~ normal(0,1);
  pers_grp_m ~ normal(0,1);
  
  alpha_visit_grp_m ~ normal(0,1);
  beta_1_MF_visit_grp_m ~ normal(0,1);
  beta_1_MB_visit_grp_m ~ normal(0,1);
  beta_2_visit_grp_m ~ normal(0,1);
  pers_visit_grp_m ~ normal(0,1);
  
  for (k in 1:kS) {
    alpha_int_m[,k] ~ normal(0,1);
    beta_1_MF_int_m[,k] ~ normal(0,1);
    beta_1_MB_int_m[,k] ~ normal(0,1);
    beta_2_int_m[,k] ~ normal(0,1);
    pers_int_m[,k] ~ normal(0,1);
  }
  
  alpha_visit_s ~ cauchy(0,2);
  beta_1_MF_visit_s ~ cauchy(0,2);
  beta_1_MB_visit_s ~ cauchy(0,2);
  beta_2_visit_s ~ cauchy(0,2);
  pers_visit_s ~ cauchy(0,2);
  
  for (s in 1:nS) {
    alpha_visit_raw[s,] ~ normal(0,1);
    beta_1_MF_visit_raw[s,] ~ normal(0,1);
    beta_1_MB_visit_raw[s,] ~ normal(0,1);
    beta_2_visit_raw[s,] ~ normal(0,1);
    pers_visit_raw[s,] ~ normal(0,1);
    
    to_vector(alpha_subj_raw[,s]) ~ normal(0,1);
    to_vector(beta_1_MF_subj_raw[,s]) ~ normal(0,1);
    to_vector(beta_1_MB_subj_raw[,s]) ~ normal(0,1);
    to_vector(beta_2_subj_raw[,s]) ~ normal(0,1);
    to_vector(pers_subj_raw[,s]) ~ normal(0,1);
  }
  
  alpha_subj_s ~ student_t(3,0,2);
  beta_1_MF_subj_s ~ student_t(3,0,3);
  beta_1_MB_subj_s ~ student_t(3,0,3);
  beta_2_subj_s ~ student_t(3,0,3);
  pers_subj_s ~ student_t(3,0,2);
  
  alpha_subj_L ~ lkj_corr_cholesky(1);
  beta_1_MF_subj_L ~ lkj_corr_cholesky(1);
  beta_1_MB_subj_L ~ lkj_corr_cholesky(1);
  beta_2_subj_L ~ lkj_corr_cholesky(1);
  pers_subj_L ~ lkj_corr_cholesky(1);
  
  for (s in 1:nS) {
    for (v in 1:nV) {
      
      //set initial values
      for (i in 1:2) {
        Q_TD[i]=.5;
        Q_MB[i]=.5;
        Q_2[1,i]=.5;
        Q_2[2,i]=.5;
        tran_type[i]=0;
      }
      prev_choice=0;
      
      for (t in 1:nT) {
        //use if not missing 1st stage choice
        if (missing_choice[s,t,1,v]==0) {
          choice[s,t,1,v] ~ bernoulli_logit(beta_1_MF[s,v]*(Q_TD[2]-Q_TD[1])
          +beta_1_MB[s,v]*(Q_MB[2]-Q_MB[1])+pers[s,v]*prev_choice);
          prev_choice = 2*choice[s,t,1,v]-1; //1 if choice 2, -1 if choice 1
          
          //use if not missing 2nd stage choice
          if (missing_choice[s,t,2,v]==0) {
            choice[s,t,2,v] ~ bernoulli_logit(beta_2[s,v]*
              (Q_2[state_2[s,t,v],2]-Q_2[state_2[s,t,v],1]));
            
            //use if not missing 2nd stage reward
            if (missing_reward[s,t,v]==0) {
              // //prediction errors
              // //note: choices are 0/1, +1 to make them 1/2 for indexing
              // delta_1 = Q_2[state_2[s,t,v],choice[s,t,2,v]+1]/alpha[s,v]-
              //   Q_TD[choice[s,t,1,v]+1]; 
              // delta_2 = reward[s,t,v]/alpha[s,v] - Q_2[state_2[s,t,v],choice[s,t,2,v]+1];
              
              //update transition counts: if choice=0 & state=1, or choice=1 & state=2, 
              //update 1st expectation of transition, otherwise update 2nd expectation
              tran_count = (state_2[s,t,v]-choice[s,t,1,v]-1) ? 2 : 1;
              tran_type[tran_count] = tran_type[tran_count] + 1;
              
              //update chosen values
              //Q_TD[choice[s,t,1,v]+1] = Q_TD[choice[s,t,1,v]+1] + 
              //  alpha[s,v]*(delta_1+lambda[s,v]*delta_2);
              Q_TD[choice[s,t,1,v]+1] = Q_TD[choice[s,t,1,v]+1]*(1-(alpha[s,v])) + reward[s,t,v];
                //+ alpha[s,v]*delta_1+delta_2;
              Q_2[state_2[s,t,v],choice[s,t,2,v]+1] = Q_2[state_2[s,t,v],choice[s,t,2,v]+1]*
                (1 -(alpha[s,v])) + reward[s,t,v];
              // + alpha[s,v]*delta_2;
  
              //update unchosen TD & second stage values
              Q_TD[(choice[s,t,1,v] ? 1 : 2)] = (1-alpha[s,v])*
                Q_TD[(choice[s,t,1,v] ? 1 : 2)];
              Q_2[state_2[s,t,v],(choice[s,t,2,v] ? 1 : 2)] = (1-alpha[s,v])*
                Q_2[state_2[s,t,v],(choice[s,t,2,v] ? 1 : 2)];
              unc_state = (state_2[s,t,v]-1) ? 1 : 2;
              Q_2[unc_state,1] = (1-alpha[s,v])*Q_2[unc_state,1];
              Q_2[unc_state,2] = (1-alpha[s,v])*Q_2[unc_state,2];
              
              //update model-based values
              Q_MB[1] = (tran_type[1] > tran_type[2]) ? (.7*fmax(Q_2[1,1],Q_2[1,2]) + 
                .3*fmax(Q_2[2,1],Q_2[2,2])) : (.3*fmax(Q_2[1,1],Q_2[1,2]) + 
                  .7*fmax(Q_2[2,1],Q_2[2,2]));
              Q_MB[2] = (tran_type[1] > tran_type[2]) ? (.3*fmax(Q_2[1,1],Q_2[1,2]) + 
                .7*fmax(Q_2[2,1],Q_2[2,2])) : (.7*fmax(Q_2[1,1],Q_2[1,2]) + 
                  .3*fmax(Q_2[2,1],Q_2[2,2]));
              
            } //if missing 2nd stage reward: do nothing
            //if missing 2nd stage choice or reward: still update 1st stage TD values, 
            //decay 2nd stage values
          } else if (missing_choice[s,t,2,v]==1||missing_reward[s,t,v]==1) { 
          // delta_1 = Q_2[state_2[s,t,v],choice[s,t,2,v]+1]-Q_TD[choice[s,t,1,v]+1];
          Q_TD[choice[s,t,1,v]+1] = Q_TD[choice[s,t,1,v]+1]*(1-(alpha[s,v])) + reward[s,t,v];
          //+ alpha[s,v]*delta_1;
          Q_TD[(choice[s,t,1,v] ? 1 : 2)] = (1-alpha[s,v])*Q_TD[(choice[s,t,1,v] ? 1 : 2)];
          Q_2[1,1] = (1-alpha[s,v])*Q_2[1,1];
          Q_2[1,2] = (1-alpha[s,v])*Q_2[1,2];
          Q_2[2,1] = (1-alpha[s,v])*Q_2[2,1];
          Q_2[2,2] = (1-alpha[s,v])*Q_2[2,2];
          //MB update of first stage values based on second stage values, so don't change
          
          }
        } else { //if missing 1st stage choice: decay all TD & 2nd stage values & 
        //update previous choice
        prev_choice=0;
        Q_TD[1] = (1-alpha[s,v])*Q_TD[1];
        Q_TD[2] = (1-alpha[s,v])*Q_TD[2];
        Q_2[1,1] = (1-alpha[s,v])*Q_2[1,1];
        Q_2[1,2] = (1-alpha[s,v])*Q_2[1,2];
        Q_2[2,1] = (1-alpha[s,v])*Q_2[2,1];
        Q_2[2,2] = (1-alpha[s,v])*Q_2[2,2];
        }
      }
    }
  }
}

generated quantities {
  //same code as above, with following changes: 
  // 1) values and choices used to calculate probability, rather than fitting values to choices
  // 2) no priors, etc.: uses estimated parameter values from the model block
  
  real log_lik[nS,nT,2,nV]; //log likelihood: keep this name so loo's extract_log_lik() finds it by default
  int prev_choice;
  int tran_count;
  int tran_type[2];
  int unc_state;
  // real delta_1;
  // real delta_2;
  real Q_TD[2];
  real Q_MB[2];
  real Q_2[2,2];
  
  corr_matrix[kV+1] alpha_cor = multiply_lower_tri_self_transpose(alpha_subj_L);
  corr_matrix[kV+1] beta_1_MF_cor = 
    multiply_lower_tri_self_transpose(beta_1_MF_subj_L);
  corr_matrix[kV+1] beta_1_MB_cor = 
    multiply_lower_tri_self_transpose(beta_1_MB_subj_L);
  corr_matrix[kV+1] beta_2_cor = 
    multiply_lower_tri_self_transpose(beta_2_subj_L);
  corr_matrix[kV+1] pers_cor = 
    multiply_lower_tri_self_transpose(pers_subj_L);
  
  for (s in 1:nS) {
    for (v in 1:nV) {
      for (i in 1:2) {
        Q_TD[i]=.5;
        Q_MB[i]=.5;
        Q_2[1,i]=.5;
        Q_2[2,i]=.5;
        tran_type[i]=0;
      }
      prev_choice=0;
      for (t in 1:nT) {
        if (missing_choice[s,t,1,v]==0) {
          log_lik[s,t,1,v] = bernoulli_logit_lpmf(choice[s,t,1,v] | beta_1_MF[s,v]*
            (Q_TD[2]-Q_TD[1])+beta_1_MB[s,v]*(Q_MB[2]-Q_MB[1])+pers[s,v]*prev_choice);
          prev_choice = 2*choice[s,t,1,v]-1; //1 if choice 2, -1 if choice 1
          
          if (missing_choice[s,t,2,v]==0) {
            log_lik[s,t,2,v] = bernoulli_logit_lpmf(choice[s,t,2,v] | beta_2[s,v]*
              (Q_2[state_2[s,t,v],2]-Q_2[state_2[s,t,v],1]));
            
            //use if not missing 2nd stage reward
            if (missing_reward[s,t,v]==0) {
              // //prediction errors
              // //note: choices are 0/1, +1 to make them 1/2 for indexing
              // delta_1 = Q_2[state_2[s,t,v],choice[s,t,2,v]+1]/alpha[s,v]-
              //   Q_TD[choice[s,t,1,v]+1];
              // delta_2 = reward[s,t,v]/alpha[s,v] - Q_2[state_2[s,t,v],choice[s,t,2,v]+1];
              
              //update transition counts: if choice=0 & state=1, or choice=1 & state=2, 
              //update 1st expectation of transition, otherwise update 2nd expectation
              tran_count = (state_2[s,t,v]-choice[s,t,1,v]-1) ? 2 : 1;
              tran_type[tran_count] = tran_type[tran_count] + 1;
              
              //update chosen values
              //Q_TD[choice[s,t,1,v]+1] = Q_TD[choice[s,t,1,v]+1] + alpha[s,v]*
              //  (delta_1+lambda[s,v]*delta_2);
              Q_TD[choice[s,t,1,v]+1] = Q_TD[choice[s,t,1,v]+1]*(1-(alpha[s,v])) + reward[s,t,v];
                //+ alpha[s,v]*delta_1+delta_2;
              Q_2[state_2[s,t,v],choice[s,t,2,v]+1] = Q_2[state_2[s,t,v],choice[s,t,2,v]+1]*
                (1 -(alpha[s,v])) + reward[s,t,v];
              // + alpha[s,v]*delta_2;
              
              //update unchosen TD & second stage values
              Q_TD[(choice[s,t,1,v] ? 1 : 2)] = (1-alpha[s,v])*Q_TD[(choice[s,t,1,v] ? 1 : 2)];
              Q_2[state_2[s,t,v],(choice[s,t,2,v] ? 1 : 2)] = (1-alpha[s,v])*
                Q_2[state_2[s,t,v],(choice[s,t,2,v] ? 1 : 2)];
              unc_state = (state_2[s,t,v]-1) ? 1 : 2;
              Q_2[unc_state,1] = (1-alpha[s,v])*Q_2[unc_state,1];
              Q_2[unc_state,2] = (1-alpha[s,v])*Q_2[unc_state,2];
              
              //update model-based values
              Q_MB[1] = (tran_type[1] > tran_type[2]) ? (.7*fmax(Q_2[1,1],Q_2[1,2]) + 
                .3*fmax(Q_2[2,1],Q_2[2,2])) : (.3*fmax(Q_2[1,1],Q_2[1,2]) + 
                  .7*fmax(Q_2[2,1],Q_2[2,2]));
              Q_MB[2] = (tran_type[1] > tran_type[2]) ? (.3*fmax(Q_2[1,1],Q_2[1,2]) + 
                .7*fmax(Q_2[2,1],Q_2[2,2])) : (.7*fmax(Q_2[1,1],Q_2[1,2]) + 
                  .3*fmax(Q_2[2,1],Q_2[2,2]));
              
            } //if missing 2nd stage reward: do nothing
           //if missing 2nd stage choice or reward: still update 1st stage TD values, 
           //decay 2nd stage values 
          } else if (missing_choice[s,t,2,v]==1||missing_reward[s,t,v]==1) { 
          log_lik[s,t,2,v] = 0;
          // delta_1 = Q_2[state_2[s,t,v],choice[s,t,2,v]+1]-Q_TD[choice[s,t,1,v]+1]; 
          Q_TD[choice[s,t,1,v]+1] = Q_TD[choice[s,t,1,v]+1]*(1-(alpha[s,v])) + reward[s,t,v];
          //+ alpha[s,v]*delta_1;
          Q_TD[(choice[s,t,1,v] ? 1 : 2)] = (1-alpha[s,v])*Q_TD[(choice[s,t,1,v] ? 1 : 2)];
          Q_2[1,1] = (1-alpha[s,v])*Q_2[1,1];
          Q_2[1,2] = (1-alpha[s,v])*Q_2[1,2];
          Q_2[2,1] = (1-alpha[s,v])*Q_2[2,1];
          Q_2[2,2] = (1-alpha[s,v])*Q_2[2,2];
          //MB update of first stage values based on second stage values, so don't change
          
          }
        } else { //if missing 1st stage choice: decay all TD & 2nd stage values & 
         //update previous choice
        prev_choice=0;
        log_lik[s,t,1,v] = 0;
        log_lik[s,t,2,v] = 0;
        Q_TD[1] = (1-alpha[s,v])*Q_TD[1];
        Q_TD[2] = (1-alpha[s,v])*Q_TD[2];
        Q_2[1,1] = (1-alpha[s,v])*Q_2[1,1];
        Q_2[1,2] = (1-alpha[s,v])*Q_2[1,2];
        Q_2[2,1] = (1-alpha[s,v])*Q_2[2,1];
        Q_2[2,2] = (1-alpha[s,v])*Q_2[2,2];
        }
      } 
    }
  }
}
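
To see what the transformed parameters block does with alpha_subj_s, alpha_subj_L and alpha_subj_raw, here is a NumPy analogue. It is a sketch outside Stan, and the SDs and correlation used are made-up values. Each row of the result plays the role of a row of alpha_vars: column 1 is the subject's random intercept, and the remaining columns are random slopes for the visit-level variables. With many simulated subjects, the sample SDs and correlations should match the inputs. This gives a quick check that diag_pre_multiply(s, L) * raw produces draws with covariance diag(s) R diag(s).

```python
import numpy as np


def correlated_subject_effects(sds, corr, n_subjects, seed=0):
    """NumPy analogue of (diag_pre_multiply(s, L) * raw)' for one parameter.

    sds:  (K,) SDs, like alpha_subj_s (K = kV + 1: intercept + one slope per visit variable)
    corr: (K, K) correlation matrix R; L is its lower Cholesky factor (like alpha_subj_L)
    Returns an (n_subjects, K) array, shaped like alpha_vars."""
    sds = np.asarray(sds, dtype=float)
    L = np.linalg.cholesky(np.asarray(corr, dtype=float))  # R = L @ L.T
    rng = np.random.default_rng(seed)
    raw = rng.standard_normal((sds.size, n_subjects))       # like alpha_subj_raw ~ normal(0, 1)
    return (np.diag(sds) @ L @ raw).T                        # diag(s) * L * raw, transposed


effects = correlated_subject_effects([1.0, 0.5], [[1.0, 0.6], [0.6, 1.0]], n_subjects=100_000)
print(effects.std(axis=0))                 # close to [1.0, 0.5]
print(np.corrcoef(effects, rowvar=False))  # off-diagonal close to 0.6
```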

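For parameter recovery or posterior predictive checks, it helps to have the same update rules available outside Stan. The Python sketch below re-implements the complete-data path of the model block for one visit and ignores the missing-data branches. Several settings are hypothetical simulation choices: the reward probabilities, the 0.7 chance of the "common" transition, and the assumption that choice 0 commonly leads to state 1. The Stan model itself only uses the 0.7/0.3 split when building Q_MB. Values follow the scaled form used in the code, Q = (1 - alpha) * Q + reward, and every unchosen value decays by (1 - alpha).

```python
import numpy as np


def inv_logit(x):
    return 0.5 * (1.0 + np.tanh(0.5 * x))


def simulate_visit(n_trials, alpha, beta_1_MF, beta_1_MB, beta_2, pers,
                   reward_probs, p_common=0.7, seed=0):
    """Simulate one visit of the two-step task with the model block's update rules.

    reward_probs: hypothetical (2, 2) array, P(reward | state index, 2nd-stage choice).
    Returns arrays coded like the Stan data: choices 0/1, state_2 in {1, 2}, reward 0/1."""
    rng = np.random.default_rng(seed)
    reward_probs = np.asarray(reward_probs, dtype=float)
    q_td = np.full(2, 0.5)          # model-free first-stage values (Q_TD)
    q_mb = np.full(2, 0.5)          # model-based first-stage values (Q_MB)
    q_2 = np.full((2, 2), 0.5)      # second-stage values (Q_2[state, choice])
    tran_type = np.zeros(2, dtype=int)
    prev_choice = 0                  # becomes -1/+1 after the first trial, as in the Stan code
    out = {k: np.zeros(n_trials, dtype=int) for k in ("choice_1", "state_2", "choice_2", "reward")}

    for t in range(n_trials):
        # first-stage choice: same linear predictor as the bernoulli_logit call
        eta_1 = (beta_1_MF * (q_td[1] - q_td[0]) + beta_1_MB * (q_mb[1] - q_mb[0])
                 + pers * prev_choice)
        c1 = int(rng.random() < inv_logit(eta_1))
        prev_choice = 2 * c1 - 1

        # transition (assumption: choice 0 -> state 1 and choice 1 -> state 2 are common)
        s2 = c1 if rng.random() < p_common else 1 - c1   # 0-based state index
        c2 = int(rng.random() < inv_logit(beta_2 * (q_2[s2, 1] - q_2[s2, 0])))
        r = int(rng.random() < reward_probs[s2, c2])

        # transition counts: (choice 0, state 1) or (choice 1, state 2) -> first counter
        tran_type[0 if s2 == c1 else 1] += 1

        # decay all TD and second-stage values, then add the reward to the chosen ones;
        # equivalent to the separate chosen/unchosen updates in the Stan code
        q_td *= 1.0 - alpha
        q_td[c1] += r
        q_2 *= 1.0 - alpha
        q_2[s2, c2] += r

        # model-based values from the best value in each second-stage state
        v = q_2.max(axis=1)
        if tran_type[0] > tran_type[1]:
            q_mb = np.array([0.7 * v[0] + 0.3 * v[1], 0.3 * v[0] + 0.7 * v[1]])
        else:
            q_mb = np.array([0.3 * v[0] + 0.7 * v[1], 0.7 * v[0] + 0.3 * v[1]])

        out["choice_1"][t], out["state_2"][t] = c1, s2 + 1
        out["choice_2"][t], out["reward"][t] = c2, r
    return out


sim = simulate_visit(200, alpha=0.3, beta_1_MF=1.0, beta_1_MB=2.0, beta_2=3.0, pers=0.2,
                     reward_probs=[[0.8, 0.2], [0.3, 0.6]], seed=1)
print(sim["choice_1"][:10], sim["state_2"][:10])
```

One way to test the model before moving on to group comparisons is to fit it to data simulated this way with known parameters and check that the estimates land near the true values.
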
General resources: below are a bunch of things I've found useful. Most are not Stan-specific, and they are at different levels depending on where you're at in terms of learning modeling.

Ten simple rules for computational modeling of behavioral data
Video tutorial on modeling behavioral data
Computational psychiatry tutorial on modeling individual differences
Using reinforcement learning models in social neuroscience
Trial-by-trial data analysis using computational models (note that you can download this from here)
Basic modeling tutorial

Best of luck,
Vanessa
