Reduce_sum with multiple data likelihoods

Hi all,
(sorry, accidentally pressed ‘create topic’ on a draft of this…)
I’m excited to try out reduce_sum on some models I’ve been running, but I’m running into issues figuring out how to set up the reduce_sum function when I have two data likelihoods. I’m working with a model of participants’ behavior on a task, where each trial on the task consists of two choices, both of which are used to fit the model. A simplified non-reduce_sum version of my model looks like the following:

data {
  int<lower=1> nT; //trials per visit
  int<lower=1> nS; //# of subjects
  int<lower=2> nV; //# of visits per subject
  int<lower=0,upper=1> choice[nS,nT,2,nV]; 
  int<lower=0,upper=1> reward[nS,nT,nV];
  int<lower=1,upper=2> state_2[nS,nT,nV]; 
  int missing_choice[nS,nT,2,nV];
  int missing_reward[nS,nT,nV];
}
parameters {
  //[skipped] define parameters here
}

transformed parameters {
  //define transformed parameters
  matrix<lower=0,upper=1>[nS,nV] alpha;
  matrix[nS,nV] alpha_normal;
  matrix[nS,nV] beta_1_MF;
  matrix[nS,nV] beta_1_MB;
  matrix[nS,nV] beta_2;
  matrix[nS,nV] pers;

  //[skipped] create transformed parameters using non-centered parameterization for all
  // and logistic transformation for alpha (range: 0 to 1)
 
}
model {
//[skipped] priors and local variables defined here
//

for (s in 1:nS) { //loop over participants
for (v in 1:nV) { //loop over visits
    for (t in 1:nT) { //loop over trials

       //first choice
        choice[s,t,1,v] ~ bernoulli_logit(beta_1_MF[s,v]*(Q_TD[2]-Q_TD[1])
          +beta_1_MB[s,v]*(Q_MB[2]-Q_MB[1])+pers[s,v]*prev_choice);
        prev_choice = 2*choice[s,t,1,v]-1; //1 if choice 2, -1 if choice 1

       //second choice
        choice[s,t,2,v] ~ bernoulli_logit(beta_2[s,v]*
              (Q_2[state_2[s,t,v],2]-Q_2[state_2[s,t,v],1]));

       //update transition counts: if choice=0 & state=1, or choice=1 & state=2, 
              //update 1st expectation of transition, otherwise update 2nd expectation
              tran_count = (state_2[s,t,v]-choice[s,t,1,v]-1) ? 2 : 1;
              tran_type[tran_count] = tran_type[tran_count] + 1;
              
              //update chosen values
              Q_TD[choice[s,t,1,v]+1] = Q_TD[choice[s,t,1,v]+1]*(1-(alpha[s,v])) + reward[s,t,v];
              Q_2[state_2[s,t,v],choice[s,t,2,v]+1] = Q_2[state_2[s,t,v],choice[s,t,2,v]+1]*
                (1 -(alpha[s,v])) + reward[s,t,v];

              //update unchosen TD & second stage values
              Q_TD[(choice[s,t,1,v] ? 1 : 2)] = (1-alpha[s,v])*
                Q_TD[(choice[s,t,1,v] ? 1 : 2)];
              Q_2[state_2[s,t,v],(choice[s,t,2,v] ? 1 : 2)] = (1-alpha[s,v])*
                Q_2[state_2[s,t,v],(choice[s,t,2,v] ? 1 : 2)];
              unc_state = (state_2[s,t,v]-1) ? 1 : 2;
              Q_2[unc_state,1] = (1-alpha[s,v])*Q_2[unc_state,1];
              Q_2[unc_state,2] = (1-alpha[s,v])*Q_2[unc_state,2];
              
              //update model-based values
              Q_MB[1] = (tran_type[1] > tran_type[2]) ? (.7*fmax(Q_2[1,1],Q_2[1,2]) + 
                .3*fmax(Q_2[2,1],Q_2[2,2])) : (.3*fmax(Q_2[1,1],Q_2[1,2]) + 
                  .7*fmax(Q_2[2,1],Q_2[2,2]));
              Q_MB[2] = (tran_type[1] > tran_type[2]) ? (.3*fmax(Q_2[1,1],Q_2[1,2]) + 
                .7*fmax(Q_2[2,1],Q_2[2,2])) : (.7*fmax(Q_2[1,1],Q_2[1,2]) + 
                  .3*fmax(Q_2[2,1],Q_2[2,2]));
    } //end trial loop
  } //end visit loop
} //end subject loop
}

I am trying to figure out how to most efficiently use reduce_sum on the

 choice[s,t,1,v] ~ bernoulli_logit(beta_1_MF[s,v]*(Q_TD[2]-Q_TD[1])
          +beta_1_MB[s,v]*(Q_MB[2]-Q_MB[1])+pers[s,v]*prev_choice);

 choice[s,t,2,v] ~ bernoulli_logit(beta_2[s,v]*
              (Q_2[state_2[s,t,v],2]-Q_2[state_2[s,t,v],1]));

likelihoods, splitting by participants (s). Is it possible to just return the sum of the bernoulli_logit_lpmfs in the reduce_sum function for these? I was going to write separate functions for each of them, but then the rest of the model would have to stay outside the function, which, based on my reading, would reduce the speedup provided by parallelization.

Thanks!


I hope the many fmax calls are with respect to data in the model!

To your question: I think the answer is yes… reduce_sum is for the evaluation of a large sum, nothing more, but that's usually enough.
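As a rough illustration (hypothetical names, and ignoring the trial-level value updates in your model), a partial_sum that contributes both likelihood terms for its slice of subjects can simply add the two lpmf values and return that:

functions {
  // slice over subjects; each partial sum adds both choice likelihoods for its subjects
  real partial_sum(int[,] choice_slice, // choices for subjects start:end, dims [n_slice, 2]
                   int start, int end,
                   vector eta_1,        // first-choice logits, one per subject (full length nS)
                   vector eta_2) {      // second-choice logits, one per subject (full length nS)
    return bernoulli_logit_lpmf(choice_slice[, 1] | eta_1[start:end])
           + bernoulli_logit_lpmf(choice_slice[, 2] | eta_2[start:end]);
  }
}

The model block would then call something like target += reduce_sum(partial_sum, choice, grainsize, eta_1, eta_2); with choice declared as int choice[nS,2].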

Thanks! So the function would return something like

bernoulli_logit_lpmf(slice_choice[,t,1,v] | 
          beta_1_MF[start:end,v]*(Q_TD[2]-Q_TD[1])
            +beta_1_MB[start:end,v]*(Q_MB[2]-Q_MB[1])
            +pers[start:end,v]*prev_choice)
  + bernoulli_logit_lpmf(slice_choice[,t,2,v] | 
          beta_2[start:end,v]*
            (Q_2[state_2[start:end,t,v],2]-Q_2[state_2[start:end,t,v],1]))
?
(just started working on reduce_sum, so the syntax may not be totally correct)

re: fmax… I wish I could say yes. I inherited the code and wasn’t sure how else to code that part, but I’ll see if I can figure something out.
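One option I may experiment with - purely a sketch, not anything validated for this model - is a temperature-scaled log_sum_exp, which stays differentiable and approaches the hard max as k grows:

functions {
  // smooth approximation to fmax(a, b); larger k pushes it toward the hard max
  real soft_max(real a, real b, real k) {
    return log_sum_exp(k * a, k * b) / k;
  }
}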

Ok, I think I have things working with reduce_sum (hooray!) - at least, I am getting similar parameter estimates as with the non-reduce_sum version. I have a couple of clarifying questions. I also added the full model at the end in case it's useful for others trying to figure out reduce_sum with more complicated hierarchical models, and would of course welcome any bug catches or suggestions.

  1. Can the first argument to the partial_sum function be any variable, as long as it is a 1xN array, where N is the dimension I want reduce_sum to slice over? Most of my variables are formatted as 2-, 3-, or 4-D arrays where the first dimension is subjects (what I would like to slice over) and the remaining dimensions are visits, trials, and choice number, as applicable, with each subsequent dimension essentially nested within the previous. It looks like things evaluated in the likelihood need to be vectorized in a 1-D array, which required me to reformat the data into long format, so there is no longer an indication of how the data should be sliced. I instead created a new variable (s_id) to serve this purpose and made it the first argument to the partial_sum function. The examples all have the outcome variable in this role, but I noticed in the logistic regression example https://mc-stan.org/docs/2_23/stan-users-guide/reduce-sum.html it's noted the data could be sliced by the x or y variables ("In this example, y was chosen to be sliced over because there is one term in the summation per value of y. Technically x would have worked as well. Use whatever conceptually makes the most sense for a given model, e.g. slice over independent terms like conditionally independent observations or groups of observations as in hierarchical models.")
  2. In general, for hierarchical data, will variables need to be reformatted to 1-D arrays to work with reduce_sum? This makes indexing much messier, but I couldn't figure out another way to get things to work.

Thanks for the help, and for these features - I’m looking forward to being able to take full advantage of cluster computing with stan!

Full code:

functions {
  real partial_sum(int[] subj_slice,
  int start, int end,
  int[,] choice_long,int[,,] reward, int[,,] state_2, int[,,,] missing_choice, 
  int[,] missing_visit,
  vector alpha_subj_raw, vector beta_1_MF_subj_raw, vector beta_1_MB_subj_raw, 
  vector beta_2_subj_raw, vector pers_subj_raw, matrix alpha_visit_raw, 
  matrix beta_1_MF_visit_raw, matrix beta_1_MB_visit_raw, 
  matrix beta_2_visit_raw, matrix pers_visit_raw,
  int nT, int nV,
  real alpha_m, real beta_1_MF_m, real beta_1_MB_m, real beta_2_m, real pers_m, 
  real alpha_subj_s, real beta_1_MF_subj_s, real beta_1_MB_subj_s, 
  real beta_2_subj_s, real pers_subj_s, real alpha_visit_s, real beta_1_MF_visit_s, 
  real beta_1_MB_visit_s, real beta_2_visit_s, real pers_visit_s) {
    
    int g_length=subj_slice[end]-subj_slice[start]+1; //number of subjects in slice
    
    //define transformed parameters
    vector[g_length*nV*nT] alpha;
    vector[g_length*nV*nT] alpha_normal;
    vector[g_length*nV*nT] beta_1_MF;
    vector[g_length*nV*nT] beta_1_MB;
    vector[g_length*nV*nT] beta_2;
    vector[g_length*nV*nT] pers;
    
    //variables for model: anything used in likelihood needs a value per trial
    vector[g_length*nV*nT] prev_choice;
    int tran_count;
    int tran_type[2];
    int unc_state;
    real Q_TD[2];
    real Q_MB[2];
    real Q_2[2,2];
    vector[g_length*nV*nT] Q_TD_diff;
    vector[g_length*nV*nT] Q_MB_diff;
    vector[g_length*nV*nT] Q_2_diff;
    
    //transformed parameters
    for (s in 1:g_length) { 
      for (v in 1:nV) {
        for (t in 1:nT) {
      alpha_normal[(s-1)*nV*nT+(v-1)*nT+t] = alpha_m + 
        alpha_visit_s*alpha_visit_raw[subj_slice[start]+s-1,v] + 
        alpha_subj_s*alpha_subj_raw[subj_slice[start]+s-1]; 
      beta_1_MF[(s-1)*nV*nT+(v-1)*nT+t] = beta_1_MF_m + 
        beta_1_MF_visit_s*beta_1_MF_visit_raw[subj_slice[start]+s-1,v] + 
        beta_1_MF_subj_s*beta_1_MF_subj_raw[subj_slice[start]+s-1];
      beta_1_MB[(s-1)*nV*nT+(v-1)*nT+t] = beta_1_MB_m + 
        beta_1_MB_visit_s*beta_1_MB_visit_raw[subj_slice[start]+s-1,v] + 
        beta_1_MB_subj_s*beta_1_MB_subj_raw[subj_slice[start]+s-1];
      beta_2[(s-1)*nV*nT+(v-1)*nT+t] = beta_2_m + 
        beta_2_visit_s*beta_2_visit_raw[subj_slice[start]+s-1,v] + 
        beta_2_subj_s*beta_2_subj_raw[subj_slice[start]+s-1];
      pers[(s-1)*nV*nT+(v-1)*nT+t] = pers_m + 
        pers_visit_s*pers_visit_raw[subj_slice[start]+s-1,v] + 
        pers_subj_s*pers_subj_raw[subj_slice[start]+s-1];
        }
      
      alpha = inv_logit(alpha_normal);
      
      //model
      // for (v in 1:nV) {
        //set initial values
        for (i in 1:2) {
          Q_TD[i]=.5;
          Q_MB[i]=.5;
          Q_2[1,i]=.5;
          Q_2[2,i]=.5;
          tran_type[i]=0;
        }
        prev_choice[(s-1)*nV*nT+(v-1)*nT+1]=0;
        
        for (t in 1:nT) {
          //use if not missing 1st stage choice
          if (missing_choice[subj_slice[start+s-1],t,1,v]==0) {
            
            //fill in values used to predict choice
            if (t<nT) prev_choice[(s-1)*nV*nT+(v-1)*nT+t+1] = 
              2*choice_long[(subj_slice[start+s-1]-1)*nV*nT+(v-1)*nT+t,1]-1; 
              //1 if choice 2, -1 if choice 1
            Q_TD_diff[(s-1)*nV*nT+(v-1)*nT+t]=Q_TD[2]-Q_TD[1];
            Q_MB_diff[(s-1)*nV*nT+(v-1)*nT+t]=Q_MB[2]-Q_MB[1];
            Q_2_diff[(s-1)*nV*nT+(v-1)*nT+t]=Q_2[state_2[subj_slice[start+s-1],t,v],2]-
              Q_2[state_2[subj_slice[start+s-1],t,v],1];
            
            //update transition counts: if choice=0 & state=1, or choice=1 & state=2, 
            //update 1st expectation of transition, otherwise update 2nd expectation
            tran_count = (state_2[subj_slice[start+s-1],t,v]-
              choice_long[(subj_slice[start+s-1]-1)*nV*nT+(v-1)*nT+t,1]-1) ? 2 : 1;
            tran_type[tran_count] = tran_type[tran_count] + 1;
            
            //update chosen values
            Q_TD[choice_long[(subj_slice[start+s-1]-1)*nV*nT+(v-1)*nT+t,1]+1] = 
              Q_TD[choice_long[(subj_slice[start+s-1]-1)*nV*nT+(v-1)*nT+t,1]+1]*
              (1-(alpha[(s-1)*nV*nT+(v-1)*nT+t])) 
              + reward[subj_slice[start+s-1],t,v];
            Q_2[state_2[subj_slice[start+s-1],t,v],
                choice_long[(subj_slice[start+s-1]-1)*nV*nT+(v-1)*nT+t,2]+1] = 
              Q_2[state_2[subj_slice[start+s-1],t,v],
                choice_long[(subj_slice[start+s-1]-1)*nV*nT+(v-1)*nT+t,2]+1]*
              (1 -(alpha[(s-1)*nV*nT+(v-1)*nT+t])) + reward[subj_slice[start+s-1],t,v];
            
            //update unchosen TD & second stage values
            Q_TD[(choice_long[(subj_slice[start+s-1]-1)*nV*nT+(v-1)*nT+t,1] ? 1 : 2)] = 
              (1-alpha[(s-1)*nV*nT+(v-1)*nT+t])*
              Q_TD[(choice_long[(subj_slice[start+s-1]-1)*nV*nT+(v-1)*nT+t,1] ? 1 : 2)];
            Q_2[state_2[subj_slice[start+s-1],t,v],
                (choice_long[(subj_slice[start+s-1]-1)*nV*nT+(v-1)*nT+t,2] ? 1 : 2)] = 
              (1-alpha[(s-1)*nV*nT+(v-1)*nT+t])*
              Q_2[state_2[subj_slice[start+s-1],t,v],
                (choice_long[(subj_slice[start+s-1]-1)*nV*nT+(v-1)*nT+t,2] ? 1 : 2)];
            unc_state = (state_2[subj_slice[start+s-1],t,v]-1) ? 1 : 2;
            Q_2[unc_state,1] = (1-alpha[(s-1)*nV*nT+(v-1)*nT+t])*Q_2[unc_state,1];
            Q_2[unc_state,2] = (1-alpha[(s-1)*nV*nT+(v-1)*nT+t])*Q_2[unc_state,2];
            
            //update model-based values
            Q_MB[1] = (tran_type[1] > tran_type[2]) ? (.7*fmax(Q_2[1,1],Q_2[1,2]) + 
            .3*fmax(Q_2[2,1],Q_2[2,2])) : (.3*fmax(Q_2[1,1],Q_2[1,2]) + 
            .7*fmax(Q_2[2,1],Q_2[2,2]));
            Q_MB[2] = (tran_type[1] > tran_type[2]) ? (.3*fmax(Q_2[1,1],Q_2[1,2]) + 
            .7*fmax(Q_2[2,1],Q_2[2,2])) : (.7*fmax(Q_2[1,1],Q_2[1,2]) + 
            .3*fmax(Q_2[2,1],Q_2[2,2]));
            
          } else { //if missing trial: decay all TD & 2nd stage values, 
          //update previous choice, and set trial's Q values to 0
          if (t<nT) prev_choice[(s-1)*nV*nT+(v-1)*nT+t+1]=0;
          Q_TD_diff[(s-1)*nV*nT+(v-1)*nT+t]=0;
          Q_MB_diff[(s-1)*nV*nT+(v-1)*nT+t]=0;
          Q_2_diff[(s-1)*nV*nT+(v-1)*nT+t]=0;
          Q_TD[1] = (1-alpha[(s-1)*nV*nT+(v-1)*nT+t])*Q_TD[1];
          Q_TD[2] = (1-alpha[(s-1)*nV*nT+(v-1)*nT+t])*Q_TD[2];
          Q_2[1,1] = (1-alpha[(s-1)*nV*nT+(v-1)*nT+t])*Q_2[1,1];
          Q_2[1,2] = (1-alpha[(s-1)*nV*nT+(v-1)*nT+t])*Q_2[1,2];
          Q_2[2,1] = (1-alpha[(s-1)*nV*nT+(v-1)*nT+t])*Q_2[2,1];
          Q_2[2,2] = (1-alpha[(s-1)*nV*nT+(v-1)*nT+t])*Q_2[2,2];
          }
        }
      }
    }
    
    return (bernoulli_logit_lpmf(choice_long[((start-1)*nT*nV+1):(end*nT*nV),1] | 
      beta_1_MF .* Q_TD_diff + 
      beta_1_MB .* Q_MB_diff + 
      pers .* prev_choice) + 
    (bernoulli_logit_lpmf(choice_long[((start-1)*nT*nV+1):(end*nT*nV),2] | 
      beta_2 .* Q_2_diff)));
  }
  
}

data {
  int<lower=1> nT; //trials per visit
  int<lower=1> nS; //# of subjects
  int<lower=2> nV; //# of visits per subject
  int<lower=0,upper=1> choice[nS,nT,2,nV]; 
  int<lower=0,upper=1> reward[nS,nT,nV];
  int<lower=1,upper=2> state_2[nS,nT,nV]; 
  int missing_choice[nS,nT,2,nV];
  int s_id[nS]; //seq(1,nS,by=1)
  int missing_visit[nS,nV];
}

transformed data {
  int choice_long[nS*nT*nV,2];
  for (s in 1:nS) {
    for (v in 1:nV) {
      for (t in 1:nT) {
        choice_long[(s-1)*nV*nT+(v-1)*nT+t,1]=choice[s,t,1,v];
        choice_long[(s-1)*nV*nT+(v-1)*nT+t,2]=choice[s,t,2,v];
      }
    }
  }
}

parameters {
  //group-level means (y00)
  real alpha_m;
  real<lower=0> beta_1_MF_m;
  real<lower=0> beta_1_MB_m;
  real<lower=0> beta_2_m;
  real pers_m;
  
  // subj-level variance 
  real<lower=0> alpha_subj_s;
  real<lower=0> beta_1_MF_subj_s;
  real<lower=0> beta_1_MB_subj_s;
  real<lower=0> beta_2_subj_s;
  real<lower=0> pers_subj_s;

  //NCP variance effect on subj-level effects
  vector[nS] alpha_subj_raw;
  vector[nS] beta_1_MF_subj_raw;
  vector[nS] beta_1_MB_subj_raw;
  vector[nS] beta_2_subj_raw;
  vector[nS] pers_subj_raw;

  //visit-level (within subject) SDs (sigma2_y)
  real<lower=0> alpha_visit_s;
  real<lower=0> beta_1_MF_visit_s;
  real<lower=0> beta_1_MB_visit_s;
  real<lower=0> beta_2_visit_s;
  real<lower=0> pers_visit_s;

  //non-centered parameterization (ncp) variance effect per visit & subject
  matrix[nS,nV] alpha_visit_raw;
  matrix[nS,nV] beta_1_MF_visit_raw;
  matrix[nS,nV] beta_1_MB_visit_raw;
  matrix[nS,nV] beta_2_visit_raw;
  matrix[nS,nV] pers_visit_raw;
}

//REDUCE SUM: moved to function
// transformed parameters {
//   //define transformed parameters
//   matrix<lower=0,upper=1>[nS,nV] alpha;
//   matrix[nS,nV] alpha_normal;
//   matrix[nS,nV] beta_1_MF;
//   matrix[nS,nV] beta_1_MB;
//   matrix[nS,nV] beta_2;
//   matrix[nS,nV] pers;
//
//   //create transformed parameters using non-centered parameterization for all
//   // and logistic transformation for alpha (range: 0 to 1),
//   // add in subject and visit-level effects as shifts in means
//   for (s in 1:nS) {
//     alpha_normal[s,] = alpha_m+alpha_visit_s*alpha_visit_raw[s,] + alpha_subj_s*alpha_subj_raw[s]; 
//     beta_1_MF[s,] = beta_1_MF_m + beta_1_MF_visit_s*beta_1_MF_visit_raw[s,] + 
//       beta_1_MF_subj_s*beta_1_MF_subj_raw[s];
//     beta_1_MB[s,] = beta_1_MB_m + beta_1_MB_visit_s*beta_1_MB_visit_raw[s,] + 
//       beta_1_MB_subj_s*beta_1_MB_subj_raw[s];
//     beta_2[s,] = beta_2_m + beta_2_visit_s*beta_2_visit_raw[s,] + beta_2_subj_s*beta_2_subj_raw[s];
//     pers[s,] = pers_m + pers_visit_s*pers_visit_raw[s,] + pers_subj_s*pers_subj_raw[s];
//     
//     //transform alpha to [0,1]
//     alpha[s,] = inv_logit(alpha_normal[s,]); 
//   }
// }

model {
  int grainsize=1;
  //REDUCE SUM: moved to function
  // //define variables
  // //anything used in likelihood needs a value per trial
  // vector[nT] prev_choice[nS,nV];
  // int tran_count;
  // int tran_type[2];
  // int unc_state;
  // real Q_TD[2];
  // real Q_MB[2];
  // real Q_2[2,2];
  // vector[nT] Q_TD_diff[nS,nV];
  // vector[nT] Q_MB_diff[nS,nV];
  // vector[nT] Q_2_diff[nS,nV];
  
  //define priors
  alpha_m ~ normal(0,2.5);
  beta_1_MF_m ~ normal(0,5);
  beta_1_MB_m ~ normal(0,5);
  beta_2_m ~ normal(0,5);
  pers_m ~ normal(0,2.5);
  
  alpha_visit_s ~ student_t(3,0,2);
  beta_1_MF_visit_s ~ student_t(3,0,2);
  beta_1_MB_visit_s ~ student_t(3,0,2);
  beta_2_visit_s ~ student_t(3,0,2);
  pers_visit_s ~ student_t(3,0,2);
  
  for (s in 1:nS) {
    alpha_visit_raw[s,] ~ normal(0,1);
    beta_1_MF_visit_raw[s,] ~ normal(0,1);
    beta_1_MB_visit_raw[s,] ~ normal(0,1);
    beta_2_visit_raw[s,] ~ normal(0,1);
    pers_visit_raw[s,] ~ normal(0,1);
  }
  
  alpha_subj_raw ~ normal(0,1);
  beta_1_MF_subj_raw ~ normal(0,1);
  beta_1_MB_subj_raw ~ normal(0,1);
  beta_2_subj_raw ~ normal(0,1);
  pers_subj_raw ~ normal(0,1);
  
  alpha_subj_s ~ student_t(3,0,2);
  beta_1_MF_subj_s ~ student_t(3,0,3);
  beta_1_MB_subj_s ~ student_t(3,0,3);
  beta_2_subj_s ~ student_t(3,0,3);
  pers_subj_s ~ student_t(3,0,2);
  
  target += reduce_sum(partial_sum,s_id,grainsize,
  choice_long, reward,state_2, missing_choice, missing_visit,
  alpha_subj_raw, beta_1_MF_subj_raw, beta_1_MB_subj_raw,beta_2_subj_raw, pers_subj_raw, 
  alpha_visit_raw, beta_1_MF_visit_raw, beta_1_MB_visit_raw, beta_2_visit_raw, pers_visit_raw,
  nT, nV, alpha_m, beta_1_MF_m, beta_1_MB_m, beta_2_m, pers_m, alpha_subj_s, 
  beta_1_MF_subj_s, beta_1_MB_subj_s, beta_2_subj_s, pers_subj_s, alpha_visit_s, 
  beta_1_MF_visit_s, beta_1_MB_visit_s, beta_2_visit_s, pers_visit_s);
  
  //REDUCE SUM: moved to function
  // for (s in 1:nS) {
  //   for (v in 1:nV) {
  //     
  //     //set initial values
  //     for (i in 1:2) {
  //       Q_TD[i]=.5;
  //       Q_MB[i]=.5;
  //       Q_2[1,i]=.5;
  //       Q_2[2,i]=.5;
  //       tran_type[i]=0;
  //     }
  //     prev_choice[s,v,1]=0;
  //     
  //     for (t in 1:nT) {
  //       //use if not missing 1st stage choice
  //       if (missing_choice[s,t,1,v]==0) {
  //         
  //         //fill in values used to predict choice
  //         if (t<nT) prev_choice[s,v,t+1] = 2*choice[s,t,1,v]-1; //1 if choice 2, -1 if choice 1
  //         Q_TD_diff[s,v,t]=Q_TD[2]-Q_TD[1];
  //         Q_MB_diff[s,v,t]=Q_MB[2]-Q_MB[1];
  //         Q_2_diff[s,v,t]=Q_2[state_2[s,t,v],2]-Q_2[state_2[s,t,v],1];
  //         
  //         //update transition counts: if choice=0 & state=1, or choice=1 & state=2, 
  //         //update 1st expectation of transition, otherwise update 2nd expectation
  //         tran_count = (state_2[s,t,v]-choice[s,t,1,v]-1) ? 2 : 1;
  //         tran_type[tran_count] = tran_type[tran_count] + 1;
  //         
  //         //update chosen values
  //         Q_TD[choice[s,t,1,v]+1] = Q_TD[choice[s,t,1,v]+1]*(1-(alpha[s,v])) 
  //           + reward[s,t,v];
  //         Q_2[state_2[s,t,v],choice[s,t,2,v]+1] = Q_2[state_2[s,t,v],choice[s,t,2,v]+1]*
  //         (1 -(alpha[s,v])) + reward[s,t,v];
  // 
  //         //update unchosen TD & second stage values
  //         Q_TD[(choice[s,t,1,v] ? 1 : 2)] = (1-alpha[s,v])*
  //         Q_TD[(choice[s,t,1,v] ? 1 : 2)];
  //         Q_2[state_2[s,t,v],(choice[s,t,2,v] ? 1 : 2)] = (1-alpha[s,v])*
  //         Q_2[state_2[s,t,v],(choice[s,t,2,v] ? 1 : 2)];
  //         unc_state = (state_2[s,t,v]-1) ? 1 : 2;
  //         Q_2[unc_state,1] = (1-alpha[s,v])*Q_2[unc_state,1];
  //         Q_2[unc_state,2] = (1-alpha[s,v])*Q_2[unc_state,2];
  //         
  //         //update model-based values
  //         Q_MB[1] = (tran_type[1] > tran_type[2]) ? (.7*fmax(Q_2[1,1],Q_2[1,2]) + 
  //         .3*fmax(Q_2[2,1],Q_2[2,2])) : (.3*fmax(Q_2[1,1],Q_2[1,2]) + 
  //         .7*fmax(Q_2[2,1],Q_2[2,2]));
  //         Q_MB[2] = (tran_type[1] > tran_type[2]) ? (.3*fmax(Q_2[1,1],Q_2[1,2]) + 
  //         .7*fmax(Q_2[2,1],Q_2[2,2])) : (.7*fmax(Q_2[1,1],Q_2[1,2]) + 
  //         .3*fmax(Q_2[2,1],Q_2[2,2]));
  //         
  //       } else { //if missing trial: decay all TD & 2nd stage values, 
  //       //update previous choice, and set trial's Q values to 0
  //       if (t<nT) prev_choice[s,v,t+1]=0;
  //       Q_TD_diff[s,v,t]=0;
  //       Q_MB_diff[s,v,t]=0;
  //       Q_2_diff[s,v,t]=0;
  //       Q_TD[1] = (1-alpha[s,v])*Q_TD[1];
  //       Q_TD[2] = (1-alpha[s,v])*Q_TD[2];
  //       Q_2[1,1] = (1-alpha[s,v])*Q_2[1,1];
  //       Q_2[1,2] = (1-alpha[s,v])*Q_2[1,2];
  //       Q_2[2,1] = (1-alpha[s,v])*Q_2[2,1];
  //       Q_2[2,2] = (1-alpha[s,v])*Q_2[2,2];
  //       }
  //     }
  //     choice[s,,1,v] ~ bernoulli_logit(beta_1_MF[s,v]*Q_TD_diff[s,v]
  //       +beta_1_MB[s,v]*Q_MB_diff[s,v] +pers[s,v]*prev_choice[s,v]);
  //     choice[s,,2,v] ~ bernoulli_logit(beta_2[s,v]*Q_2_diff[s,v]);
  //   }
  //   
  // }
}

generated quantities {
  //same code as above, with following changes: 
  // 1) values and choices used to calculate probability, rather than fitting values to choices
  // 2) no priors, etc. - uses estimated parameter values from model block
  
  real log_lik[nS,nT,2,nV]; //log likelihood- must be named this
  int prev_choice;
  int tran_count;
  int tran_type[2];
  int unc_state;
  real Q_TD[2];
  real Q_MB[2];
  real Q_2[2,2];
  
  //REDUCE SUM:  add transformed parameters here since no longer defined above 
  //define transformed parameters
  matrix<lower=0,upper=1>[nS,nV] alpha;
  matrix[nS,nV] alpha_normal;
  matrix[nS,nV] beta_1_MF;
  matrix[nS,nV] beta_1_MB;
  matrix[nS,nV] beta_2;
  matrix[nS,nV] pers;
  
  //create transformed parameters using non-centered parameterization for all
  // and logistic transformation for alpha (range: 0 to 1),
  // add in subject and visit-level effects as shifts in means
  for (s in 1:nS) {
    alpha_normal[s,] = alpha_m+alpha_visit_s*alpha_visit_raw[s,] + alpha_subj_s*alpha_subj_raw[s];
    beta_1_MF[s,] = beta_1_MF_m + beta_1_MF_visit_s*beta_1_MF_visit_raw[s,] +
    beta_1_MF_subj_s*beta_1_MF_subj_raw[s];
    beta_1_MB[s,] = beta_1_MB_m + beta_1_MB_visit_s*beta_1_MB_visit_raw[s,] +
    beta_1_MB_subj_s*beta_1_MB_subj_raw[s];
    beta_2[s,] = beta_2_m + beta_2_visit_s*beta_2_visit_raw[s,] + beta_2_subj_s*beta_2_subj_raw[s];
    pers[s,] = pers_m + pers_visit_s*pers_visit_raw[s,] + pers_subj_s*pers_subj_raw[s];

    //transform alpha to [0,1]
    alpha[s,] = inv_logit(alpha_normal[s,]);
  }

  for (s in 1:nS) {
    for (v in 1:nV) {
      for (i in 1:2) {
        Q_TD[i]=.5;
        Q_MB[i]=.5;
        Q_2[1,i]=.5;
        Q_2[2,i]=.5;
        tran_type[i]=0;
      }
      prev_choice=0;
      for (t in 1:nT) {
        if (missing_choice[s,t,1,v]==0) {
          log_lik[s,t,1,v] = bernoulli_logit_lpmf(choice[s,t,1,v] | beta_1_MF[s,v]*
          (Q_TD[2]-Q_TD[1])+beta_1_MB[s,v]*(Q_MB[2]-Q_MB[1])+pers[s,v]*prev_choice);
          prev_choice = 2*choice[s,t,1,v]-1; //1 if choice 2, -1 if choice 1
          
          log_lik[s,t,2,v] = bernoulli_logit_lpmf(choice[s,t,2,v] | beta_2[s,v]*
          (Q_2[state_2[s,t,v],2]-Q_2[state_2[s,t,v],1]));
          
          //update transition counts: if choice=0 & state=1, or choice=1 & state=2, 
          //update 1st expectation of transition, otherwise update 2nd expectation
          tran_count = (state_2[s,t,v]-choice[s,t,1,v]-1) ? 2 : 1;
          tran_type[tran_count] = tran_type[tran_count] + 1;
          
          //update chosen values
          Q_TD[choice[s,t,1,v]+1] = Q_TD[choice[s,t,1,v]+1]*(1-(alpha[s,v])) + reward[s,t,v];
          Q_2[state_2[s,t,v],choice[s,t,2,v]+1] = Q_2[state_2[s,t,v],choice[s,t,2,v]+1]*
          (1 -(alpha[s,v])) + reward[s,t,v];

          //update unchosen TD & second stage values
          Q_TD[(choice[s,t,1,v] ? 1 : 2)] = (1-alpha[s,v])*Q_TD[(choice[s,t,1,v] ? 1 : 2)];
          Q_2[state_2[s,t,v],(choice[s,t,2,v] ? 1 : 2)] = (1-alpha[s,v])*
          Q_2[state_2[s,t,v],(choice[s,t,2,v] ? 1 : 2)];
          unc_state = (state_2[s,t,v]-1) ? 1 : 2;
          Q_2[unc_state,1] = (1-alpha[s,v])*Q_2[unc_state,1];
          Q_2[unc_state,2] = (1-alpha[s,v])*Q_2[unc_state,2];
          
          //update model-based values
          Q_MB[1] = (tran_type[1] > tran_type[2]) ? (.7*fmax(Q_2[1,1],Q_2[1,2]) + 
          .3*fmax(Q_2[2,1],Q_2[2,2])) : (.3*fmax(Q_2[1,1],Q_2[1,2]) + 
          .7*fmax(Q_2[2,1],Q_2[2,2]));
          Q_MB[2] = (tran_type[1] > tran_type[2]) ? (.3*fmax(Q_2[1,1],Q_2[1,2]) + 
          .7*fmax(Q_2[2,1],Q_2[2,2])) : (.7*fmax(Q_2[1,1],Q_2[1,2]) + 
          .3*fmax(Q_2[2,1],Q_2[2,2]));
          
        } else { //if missing 1st stage choice: decay all TD & 2nd stage values & 
        //update previous choice
        prev_choice=0;
        log_lik[s,t,1,v] = 0;
        log_lik[s,t,2,v] = 0;
        Q_TD[1] = (1-alpha[s,v])*Q_TD[1];
        Q_TD[2] = (1-alpha[s,v])*Q_TD[2];
        Q_2[1,1] = (1-alpha[s,v])*Q_2[1,1];
        Q_2[1,2] = (1-alpha[s,v])*Q_2[1,2];
        Q_2[2,1] = (1-alpha[s,v])*Q_2[2,1];
        Q_2[2,2] = (1-alpha[s,v])*Q_2[2,2];
        }
      } 
    }
  }
}


The first argument can have any dimension. Say you have N items, then

matrix[5,5] large[N];

is fine… or

real large[N, 5, 5];

is also ok.

No. It’s just that the first argument must be an array - but it can be an array of any data structure… and you can actually put a dummy argument in the first position if that’s easier for you.
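For instance (just a sketch with made-up names), you can slice over an array of subject IDs and use start/end - or the IDs themselves - to pull out the matching rows of the actual data:

functions {
  real partial_sum(int[] subj_slice,  // dummy slicing argument: subject IDs start..end
                   int start, int end,
                   int[,] y,          // full data, y[subject, trial]
                   matrix eta) {      // full per-trial logits, eta[subject, trial]
    real lp = 0;
    for (i in 1:(end - start + 1)) {
      int s = subj_slice[i];          // equals start + i - 1 when s_id = 1:nS
      lp += bernoulli_logit_lpmf(y[s, ] | to_vector(eta[s, ]));
    }
    return lp;
  }
}

called as target += reduce_sum(partial_sum, s_id, grainsize, y, eta); - the subject IDs carry no information beyond their position; they only define what gets sliced.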

I see, thanks! So it slices over the first dimension of the first argument?

But the likelihood will still need to be returned as a single sum, right? That was the other part that was tripping me up, since I'm working with a hierarchical model of time-series data - my higher levels (subjects, visits) are conditionally independent but my lower levels (trials) are not. I had been evaluating the likelihood in a loop, but things need to be vectorized for reduce_sum.

More generally: if I have variables in the format real[N,X,Y], where Y is nested in X and X is nested in N, and I can slice over N (or X) but not Y - is there another way to return the likelihood other than reformatting things into the format real[N*X*Y]?

But conditional on your parameters everything should be independent…provided the model works that way. So if each loop iteration is independent of the other loop iterations, then you are just building up a large sum.

And yes, reduce_sum simply slices over rows.
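If you do flatten, the index arithmetic is just the usual row-major mapping; a tiny helper (only a sketch, mirroring the indexing already in your code) keeps it readable:

functions {
  // position of (subject s, visit v, trial t) in the long vector,
  // with trials varying fastest, then visits, then subjects
  int long_index(int s, int v, int t, int nV, int nT) {
    return (s - 1) * nV * nT + (v - 1) * nT + t;
  }
}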

Yeah, I think that's what I'm struggling with - I have some parts of my model that are not independent (i.e., if I switched around the order of the trials within a visit/subject, the results would change) and some that are conditionally independent (visits/subjects), and I'm trying to make sure I slice over the right things. But I think I mostly get things now. The other issue I'm struggling with in terms of the array dimensions has more to do with practical coding problems and multiplying 2-D and 3-D arrays, and I'll figure those out. Thanks!

Based on @wds15's info and after thinking about how best to structure my model, I cleaned up the indexing/variable dimensions some and things run MUCH faster now - I think my current model (pasted below) is good enough to work with. The reduce_sum version in my previous post worked in terms of returning accurate parameter estimates, but it ran 4-6x slower on one core per chain than my non-reduce_sum code, which would have negated any parallelization benefits. The current version runs only 10-15% slower than the original on one core, which I can live with, and is already faster than the original when I run two cores per chain. Thanks again for the help!

functions {
  real partial_sum(int[] subj_slice,
  int start, int end,
  int[,] choice_long,int[,,] reward, int[,,] state_2, int[,,,] missing_choice, 
  int[,] missing_visit, int[,,,] choice,
  vector alpha_subj_raw, vector beta_1_MF_subj_raw, vector beta_1_MB_subj_raw, 
  vector beta_2_subj_raw, vector pers_subj_raw, matrix alpha_visit_raw, 
  matrix beta_1_MF_visit_raw, matrix beta_1_MB_visit_raw, 
  matrix beta_2_visit_raw, matrix pers_visit_raw,
  int nT, int nV,
  real alpha_m, real beta_1_MF_m, real beta_1_MB_m, real beta_2_m, real pers_m, 
  real alpha_subj_s, real beta_1_MF_subj_s, real beta_1_MB_subj_s, 
  real beta_2_subj_s, real pers_subj_s, real alpha_visit_s, real beta_1_MF_visit_s, 
  real beta_1_MB_visit_s, real beta_2_visit_s, real pers_visit_s) {
    
    int g_length=subj_slice[end]-subj_slice[start]+1; //number of subjects in slice
    
    //define transformed parameters
    matrix[g_length,nV] alpha;
    matrix[g_length,nV] alpha_normal;
    matrix[g_length,nV] beta_1_MF;
    matrix[g_length,nV] beta_1_MB;
    matrix[g_length,nV] beta_2;
    matrix[g_length,nV] pers;
    
    //variables for model: anything used in likelihood needs a value per trial
    int prev_choice;
    int tran_count;
    int tran_type[2];
    int unc_state;
    real Q_TD[2];
    real Q_MB[2];
    real Q_2[2,2];
    //add 1-D arrays for inputs into likelihood
    vector[g_length*nV*nT] SF_TD;
    vector[g_length*nV*nT] SF_MB;
    vector[g_length*nV*nT] SF_P;
    vector[g_length*nV*nT] SF_2;
    
    //transformed parameters
    for (s in 1:g_length) { 
      for (v in 1:nV) {
        alpha_normal[s,v] = alpha_m + 
          alpha_visit_s*alpha_visit_raw[subj_slice[start]+s-1,v] + 
          alpha_subj_s*alpha_subj_raw[subj_slice[start]+s-1]; 
        beta_1_MF[s,v] = beta_1_MF_m + 
          beta_1_MF_visit_s*beta_1_MF_visit_raw[subj_slice[start]+s-1,v] + 
          beta_1_MF_subj_s*beta_1_MF_subj_raw[subj_slice[start]+s-1];
        beta_1_MB[s,v] = beta_1_MB_m + 
          beta_1_MB_visit_s*beta_1_MB_visit_raw[subj_slice[start]+s-1,v] + 
          beta_1_MB_subj_s*beta_1_MB_subj_raw[subj_slice[start]+s-1];
        beta_2[s,v] = beta_2_m + 
          beta_2_visit_s*beta_2_visit_raw[subj_slice[start]+s-1,v] + 
          beta_2_subj_s*beta_2_subj_raw[subj_slice[start]+s-1];
        pers[s,v] = pers_m + 
          pers_visit_s*pers_visit_raw[subj_slice[start]+s-1,v] + 
          pers_subj_s*pers_subj_raw[subj_slice[start]+s-1];
        // }
        alpha[s,v] = inv_logit(alpha_normal[s,v]);
        
        //model
        // for (v in 1:nV) {
          //set initial values
          for (i in 1:2) {
            Q_TD[i]=.5;
            Q_MB[i]=.5;
            Q_2[1,i]=.5;
            Q_2[2,i]=.5;
            tran_type[i]=0;
          }
          prev_choice=0;
          
          for (t in 1:nT) {
            //use if not missing 1st stage choice
            if (missing_choice[subj_slice[start]+s-1,t,1,v]==0) {
              
              //fill in values used to predict choice
              SF_TD[(s-1)*nV*nT+(v-1)*nT+t]=beta_1_MF[s,v]*(Q_TD[2]-Q_TD[1]);
              SF_MB[(s-1)*nV*nT+(v-1)*nT+t]=beta_1_MB[s,v]*(Q_MB[2]-Q_MB[1]);
              SF_P[(s-1)*nV*nT+(v-1)*nT+t]=pers[s,v]*prev_choice;
              SF_2[(s-1)*nV*nT+(v-1)*nT+t]=beta_2[s,v]*(Q_2[state_2[subj_slice[start]+s-1,t,v],2]-
              Q_2[state_2[subj_slice[start]+s-1,t,v],1]);
              
              prev_choice = 2*choice[(subj_slice[start]+s-1),t,1,v]-1; 
              //1 if choice 2, -1 if choice 1
              
              //update transition counts: if choice=0 & state=1, or choice=1 & state=2, 
              //update 1st expectation of transition, otherwise update 2nd expectation
              tran_count = (state_2[subj_slice[start]+s-1,t,v]-
              choice[(subj_slice[start]+s-1),t,1,v]-1) ? 2 : 1;
              tran_type[tran_count] = tran_type[tran_count] + 1;
              
              //update chosen values
              Q_TD[choice[subj_slice[start]+s-1,t,1,v]+1] = 
              Q_TD[choice[subj_slice[start]+s-1,t,1,v]+1]*(1-(alpha[s,v])) 
              + reward[subj_slice[start]+s-1,t,v];
              Q_2[state_2[subj_slice[start]+s-1,t,v],choice[subj_slice[start]+s-1,t,2,v]+1] = 
              Q_2[state_2[subj_slice[start]+s-1,t,v],choice[subj_slice[start]+s-1,t,2,v]+1]*
              (1 -(alpha[s,v])) + reward[subj_slice[start]+s-1,t,v];
              
              //update unchosen TD & second stage values
              Q_TD[(choice[subj_slice[start]+s-1,t,1,v] ? 1 : 2)] = 
              (1-alpha[s,v])*Q_TD[(choice[subj_slice[start]+s-1,t,1,v] ? 1 : 2)];
              Q_2[state_2[subj_slice[start]+s-1,t,v],(choice[subj_slice[start]+s-1,t,2,v] ? 1 : 2)] = 
              (1-alpha[s,v])*Q_2[state_2[subj_slice[start]+s-1,t,v],
              (choice[subj_slice[start]+s-1,t,2,v] ? 1 : 2)];
              unc_state = (state_2[subj_slice[start]+s-1,t,v]-1) ? 1 : 2;
              Q_2[unc_state,1] = (1-alpha[s,v])*Q_2[unc_state,1];
              Q_2[unc_state,2] = (1-alpha[s,v])*Q_2[unc_state,2];
              
              //update model-based values
              Q_MB[1] = (tran_type[1] > tran_type[2]) ? (.7*fmax(Q_2[1,1],Q_2[1,2]) + 
              .3*fmax(Q_2[2,1],Q_2[2,2])) : (.3*fmax(Q_2[1,1],Q_2[1,2]) + 
              .7*fmax(Q_2[2,1],Q_2[2,2]));
              Q_MB[2] = (tran_type[1] > tran_type[2]) ? (.3*fmax(Q_2[1,1],Q_2[1,2]) + 
              .7*fmax(Q_2[2,1],Q_2[2,2])) : (.7*fmax(Q_2[1,1],Q_2[1,2]) + 
              .3*fmax(Q_2[2,1],Q_2[2,2]));
              
            } else { //if missing trial: decay all TD & 2nd stage values, 
            //update previous choice, and set trial's Q values to 0
            
            SF_TD[(s-1)*nV*nT+(v-1)*nT+t]=0;
            SF_MB[(s-1)*nV*nT+(v-1)*nT+t]=0;
            SF_P[(s-1)*nV*nT+(v-1)*nT+t]=0;
            SF_2[(s-1)*nV*nT+(v-1)*nT+t]=0;
            prev_choice=0;
            Q_TD[1] = (1-alpha[s,v])*Q_TD[1];
            Q_TD[2] = (1-alpha[s,v])*Q_TD[2];
            Q_2[1,1] = (1-alpha[s,v])*Q_2[1,1];
            Q_2[1,2] = (1-alpha[s,v])*Q_2[1,2];
            Q_2[2,1] = (1-alpha[s,v])*Q_2[2,1];
            Q_2[2,2] = (1-alpha[s,v])*Q_2[2,2];
            }
          }
      }
    }
    
    return (bernoulli_logit_lpmf(choice_long[((start-1)*nT*nV+1):(end*nT*nV),1] | 
    SF_TD + SF_MB + SF_P) + 
    bernoulli_logit_lpmf(choice_long[((start-1)*nT*nV+1):(end*nT*nV),2] | SF_2));
  }
  
}

data {
  int<lower=1> nT; //trials per visit
  int<lower=1> nS; //# of subjects
  int<lower=2> nV; //# of visits per subject
  int<lower=0,upper=1> choice[nS,nT,2,nV]; 
  int<lower=0,upper=1> reward[nS,nT,nV];
  int<lower=1,upper=2> state_2[nS,nT,nV]; 
  int missing_choice[nS,nT,2,nV];
  int s_id[nS]; //seq(1,nS,by=1)
  int missing_visit[nS,nV];
}

transformed data {
  int choice_long[nS*nT*nV,2];
  for (s in 1:nS) {
    for (v in 1:nV) {
      for (t in 1:nT) {
        choice_long[(s-1)*nV*nT+(v-1)*nT+t,1]=choice[s,t,1,v];
        choice_long[(s-1)*nV*nT+(v-1)*nT+t,2]=choice[s,t,2,v];
      }
    }
  }
}

parameters {
  //group-level means (y00)
  real alpha_m;
  real<lower=0> beta_1_MF_m;
  real<lower=0> beta_1_MB_m;
  real<lower=0> beta_2_m;
  real pers_m;
  
  // subj-level variance 
  real<lower=0> alpha_subj_s;
  real<lower=0> beta_1_MF_subj_s;
  real<lower=0> beta_1_MB_subj_s;
  real<lower=0> beta_2_subj_s;
  real<lower=0> pers_subj_s;

  //NCP variance effect on subj-level effects
  vector[nS] alpha_subj_raw;
  vector[nS] beta_1_MF_subj_raw;
  vector[nS] beta_1_MB_subj_raw;
  vector[nS] beta_2_subj_raw;
  vector[nS] pers_subj_raw;

  //visit-level (within subject) SDs (sigma2_y)
  real<lower=0> alpha_visit_s;
  real<lower=0> beta_1_MF_visit_s;
  real<lower=0> beta_1_MB_visit_s;
  real<lower=0> beta_2_visit_s;
  real<lower=0> pers_visit_s;

  //non-centered parameterization (ncp) variance effect per visit & subject
  matrix[nS,nV] alpha_visit_raw;
  matrix[nS,nV] beta_1_MF_visit_raw;
  matrix[nS,nV] beta_1_MB_visit_raw;
  matrix[nS,nV] beta_2_visit_raw;
  matrix[nS,nV] pers_visit_raw;
}

//REDUCE SUM: moved to function
// transformed parameters {
//   //define transformed parameters
//   matrix<lower=0,upper=1>[nS,nV] alpha;
//   matrix[nS,nV] alpha_normal;
//   matrix[nS,nV] beta_1_MF;
//   matrix[nS,nV] beta_1_MB;
//   matrix[nS,nV] beta_2;
//   matrix[nS,nV] pers;
//
//   //create transformed parameters using non-centered parameterization for all
//   // and logistic transformation for alpha (range: 0 to 1),
//   // add in subject and visit-level effects as shifts in means
//   for (s in 1:nS) {
//     alpha_normal[s,] = alpha_m+alpha_visit_s*alpha_visit_raw[s,] + alpha_subj_s*alpha_subj_raw[s]; 
//     beta_1_MF[s,] = beta_1_MF_m + beta_1_MF_visit_s*beta_1_MF_visit_raw[s,] + 
//       beta_1_MF_subj_s*beta_1_MF_subj_raw[s];
//     beta_1_MB[s,] = beta_1_MB_m + beta_1_MB_visit_s*beta_1_MB_visit_raw[s,] + 
//       beta_1_MB_subj_s*beta_1_MB_subj_raw[s];
//     beta_2[s,] = beta_2_m + beta_2_visit_s*beta_2_visit_raw[s,] + beta_2_subj_s*beta_2_subj_raw[s];
//     pers[s,] = pers_m + pers_visit_s*pers_visit_raw[s,] + pers_subj_s*pers_subj_raw[s];
//     
//     //transform alpha to [0,1]
//     alpha[s,] = inv_logit(alpha_normal[s,]); 
//   }
// }

model {
  int grainsize=1;
  //REDUCE SUM: moved to function
  // //define variables
  // //anything used in likelihood needs a value per trial
  // vector[nT] prev_choice[nS,nV];
  // int tran_count;
  // int tran_type[2];
  // int unc_state;
  // real Q_TD[2];
  // real Q_MB[2];
  // real Q_2[2,2];
  // vector[nT] Q_TD_diff[nS,nV];
  // vector[nT] Q_MB_diff[nS,nV];
  // vector[nT] Q_2_diff[nS,nV];
  
  //define priors
  alpha_m ~ normal(0,2.5);
  beta_1_MF_m ~ normal(0,5);
  beta_1_MB_m ~ normal(0,5);
  beta_2_m ~ normal(0,5);
  pers_m ~ normal(0,2.5);
  
  alpha_visit_s ~ student_t(3,0,2);
  beta_1_MF_visit_s ~ student_t(3,0,2);
  beta_1_MB_visit_s ~ student_t(3,0,2);
  beta_2_visit_s ~ student_t(3,0,2);
  pers_visit_s ~ student_t(3,0,2);
  
  for (s in 1:nS) {
    alpha_visit_raw[s,] ~ normal(0,1);
    beta_1_MF_visit_raw[s,] ~ normal(0,1);
    beta_1_MB_visit_raw[s,] ~ normal(0,1);
    beta_2_visit_raw[s,] ~ normal(0,1);
    pers_visit_raw[s,] ~ normal(0,1);
  }
  
  alpha_subj_raw ~ normal(0,1);
  beta_1_MF_subj_raw ~ normal(0,1);
  beta_1_MB_subj_raw ~ normal(0,1);
  beta_2_subj_raw ~ normal(0,1);
  pers_subj_raw ~ normal(0,1);
  
  alpha_subj_s ~ student_t(3,0,2);
  beta_1_MF_subj_s ~ student_t(3,0,3);
  beta_1_MB_subj_s ~ student_t(3,0,3);
  beta_2_subj_s ~ student_t(3,0,3);
  pers_subj_s ~ student_t(3,0,2);
  
  target += reduce_sum(partial_sum,s_id,grainsize,
  choice_long, reward,state_2, missing_choice, missing_visit,choice,
  alpha_subj_raw, beta_1_MF_subj_raw, beta_1_MB_subj_raw,beta_2_subj_raw, pers_subj_raw, 
  alpha_visit_raw, beta_1_MF_visit_raw, beta_1_MB_visit_raw, beta_2_visit_raw, pers_visit_raw,
  nT, nV, alpha_m, beta_1_MF_m, beta_1_MB_m, beta_2_m, pers_m, alpha_subj_s, 
  beta_1_MF_subj_s, beta_1_MB_subj_s, beta_2_subj_s, pers_subj_s, alpha_visit_s, 
  beta_1_MF_visit_s, beta_1_MB_visit_s, beta_2_visit_s, pers_visit_s);
  
  //REDUCE SUM: moved to function
  // for (s in 1:nS) {
  //   for (v in 1:nV) {
  //     
  //     //set initial values
  //     for (i in 1:2) {
  //       Q_TD[i]=.5;
  //       Q_MB[i]=.5;
  //       Q_2[1,i]=.5;
  //       Q_2[2,i]=.5;
  //       tran_type[i]=0;
  //     }
  //     prev_choice[s,v,1]=0;
  //     
  //     for (t in 1:nT) {
  //       //use if not missing 1st stage choice
  //       if (missing_choice[s,t,1,v]==0) {
  //         
  //         //fill in values used to predict choice
  //         if (t<nT) prev_choice[s,v,t+1] = 2*choice[s,t,1,v]-1; //1 if choice 2, -1 if choice 1
  //         Q_TD_diff[s,v,t]=Q_TD[2]-Q_TD[1];
  //         Q_MB_diff[s,v,t]=Q_MB[2]-Q_MB[1];
  //         Q_2_diff[s,v,t]=Q_2[state_2[s,t,v],2]-Q_2[state_2[s,t,v],1];
  //         
  //         //update transition counts: if choice=0 & state=1, or choice=1 & state=2, 
  //         //update 1st expectation of transition, otherwise update 2nd expectation
  //         tran_count = (state_2[s,t,v]-choice[s,t,1,v]-1) ? 2 : 1;
  //         tran_type[tran_count] = tran_type[tran_count] + 1;
  //         
  //         //update chosen values
  //         Q_TD[choice[s,t,1,v]+1] = Q_TD[choice[s,t,1,v]+1]*(1-(alpha[s,v])) 
  //           + reward[s,t,v];
  //         Q_2[state_2[s,t,v],choice[s,t,2,v]+1] = Q_2[state_2[s,t,v],choice[s,t,2,v]+1]*
  //         (1 -(alpha[s,v])) + reward[s,t,v];
  // 
  //         //update unchosen TD & second stage values
  //         Q_TD[(choice[s,t,1,v] ? 1 : 2)] = (1-alpha[s,v])*
  //         Q_TD[(choice[s,t,1,v] ? 1 : 2)];
  //         Q_2[state_2[s,t,v],(choice[s,t,2,v] ? 1 : 2)] = (1-alpha[s,v])*
  //         Q_2[state_2[s,t,v],(choice[s,t,2,v] ? 1 : 2)];
  //         unc_state = (state_2[s,t,v]-1) ? 1 : 2;
  //         Q_2[unc_state,1] = (1-alpha[s,v])*Q_2[unc_state,1];
  //         Q_2[unc_state,2] = (1-alpha[s,v])*Q_2[unc_state,2];
  //         
  //         //update model-based values
  //         Q_MB[1] = (tran_type[1] > tran_type[2]) ? (.7*fmax(Q_2[1,1],Q_2[1,2]) + 
  //         .3*fmax(Q_2[2,1],Q_2[2,2])) : (.3*fmax(Q_2[1,1],Q_2[1,2]) + 
  //         .7*fmax(Q_2[2,1],Q_2[2,2]));
  //         Q_MB[2] = (tran_type[1] > tran_type[2]) ? (.3*fmax(Q_2[1,1],Q_2[1,2]) + 
  //         .7*fmax(Q_2[2,1],Q_2[2,2])) : (.7*fmax(Q_2[1,1],Q_2[1,2]) + 
  //         .3*fmax(Q_2[2,1],Q_2[2,2]));
  //         
  //       } else { //if missing trial: decay all TD & 2nd stage values, 
  //       //update previous choice, and set trial's Q values to 0
  //       if (t<nT) prev_choice[s,v,t+1]=0;
  //       Q_TD_diff[s,v,t]=0;
  //       Q_MB_diff[s,v,t]=0;
  //       Q_2_diff[s,v,t]=0;
  //       Q_TD[1] = (1-alpha[s,v])*Q_TD[1];
  //       Q_TD[2] = (1-alpha[s,v])*Q_TD[2];
  //       Q_2[1,1] = (1-alpha[s,v])*Q_2[1,1];
  //       Q_2[1,2] = (1-alpha[s,v])*Q_2[1,2];
  //       Q_2[2,1] = (1-alpha[s,v])*Q_2[2,1];
  //       Q_2[2,2] = (1-alpha[s,v])*Q_2[2,2];
  //       }
  //     }
  //     choice[s,,1,v] ~ bernoulli_logit(beta_1_MF[s,v]*Q_TD_diff[s,v]
  //       +beta_1_MB[s,v]*Q_MB_diff[s,v] +pers[s,v]*prev_choice[s,v]);
  //     choice[s,,2,v] ~ bernoulli_logit(beta_2[s,v]*Q_2_diff[s,v]);
  //   }
  //   
  // }
}

generated quantities {
  //same code as above, with following changes: 
  // 1) values and choices used to calculate probability, rather than fitting values to choices
  // 2) no priors, etc. - uses estimated parameter values from model block
  
  real log_lik[nS,nT,2,nV]; //log likelihood- must be named this
  int prev_choice;
  int tran_count;
  int tran_type[2];
  int unc_state;
  real Q_TD[2];
  real Q_MB[2];
  real Q_2[2,2];
  
  //REDUCE SUM:  add transformed parameters here since no longer defined above 
  //define transformed parameters
  matrix<lower=0,upper=1>[nS,nV] alpha;
  matrix[nS,nV] alpha_normal;
  matrix[nS,nV] beta_1_MF;
  matrix[nS,nV] beta_1_MB;
  matrix[nS,nV] beta_2;
  matrix[nS,nV] pers;
  
  //create transformed parameters using non-centered parameterization for all
  // and logistic transformation for alpha (range: 0 to 1),
  // add in subject and visit-level effects as shifts in means
  for (s in 1:nS) {
    alpha_normal[s,] = alpha_m+alpha_visit_s*alpha_visit_raw[s,] + alpha_subj_s*alpha_subj_raw[s];
    beta_1_MF[s,] = beta_1_MF_m + beta_1_MF_visit_s*beta_1_MF_visit_raw[s,] +
    beta_1_MF_subj_s*beta_1_MF_subj_raw[s];
    beta_1_MB[s,] = beta_1_MB_m + beta_1_MB_visit_s*beta_1_MB_visit_raw[s,] +
    beta_1_MB_subj_s*beta_1_MB_subj_raw[s];
    beta_2[s,] = beta_2_m + beta_2_visit_s*beta_2_visit_raw[s,] + beta_2_subj_s*beta_2_subj_raw[s];
    pers[s,] = pers_m + pers_visit_s*pers_visit_raw[s,] + pers_subj_s*pers_subj_raw[s];

    //transform alpha to [0,1]
    alpha[s,] = inv_logit(alpha_normal[s,]);
  }

  for (s in 1:nS) {
    for (v in 1:nV) {
      for (i in 1:2) {
        Q_TD[i]=.5;
        Q_MB[i]=.5;
        Q_2[1,i]=.5;
        Q_2[2,i]=.5;
        tran_type[i]=0;
      }
      prev_choice=0;
      for (t in 1:nT) {
        if (missing_choice[s,t,1,v]==0) {
          log_lik[s,t,1,v] = bernoulli_logit_lpmf(choice[s,t,1,v] | beta_1_MF[s,v]*
          (Q_TD[2]-Q_TD[1])+beta_1_MB[s,v]*(Q_MB[2]-Q_MB[1])+pers[s,v]*prev_choice);
          prev_choice = 2*choice[s,t,1,v]-1; //1 if choice 2, -1 if choice 1
          
          log_lik[s,t,2,v] = bernoulli_logit_lpmf(choice[s,t,2,v] | beta_2[s,v]*
          (Q_2[state_2[s,t,v],2]-Q_2[state_2[s,t,v],1]));
          
          //update transition counts: if choice=0 & state=1, or choice=1 & state=2, 
          //update 1st expectation of transition, otherwise update 2nd expectation
          tran_count = (state_2[s,t,v]-choice[s,t,1,v]-1) ? 2 : 1;
          tran_type[tran_count] = tran_type[tran_count] + 1;
          
          //update chosen values
          Q_TD[choice[s,t,1,v]+1] = Q_TD[choice[s,t,1,v]+1]*(1-(alpha[s,v])) + reward[s,t,v];
          Q_2[state_2[s,t,v],choice[s,t,2,v]+1] = Q_2[state_2[s,t,v],choice[s,t,2,v]+1]*
          (1 -(alpha[s,v])) + reward[s,t,v];

          //update unchosen TD & second stage values
          Q_TD[(choice[s,t,1,v] ? 1 : 2)] = (1-alpha[s,v])*Q_TD[(choice[s,t,1,v] ? 1 : 2)];
          Q_2[state_2[s,t,v],(choice[s,t,2,v] ? 1 : 2)] = (1-alpha[s,v])*
          Q_2[state_2[s,t,v],(choice[s,t,2,v] ? 1 : 2)];
          unc_state = (state_2[s,t,v]-1) ? 1 : 2;
          Q_2[unc_state,1] = (1-alpha[s,v])*Q_2[unc_state,1];
          Q_2[unc_state,2] = (1-alpha[s,v])*Q_2[unc_state,2];
          
          //update model-based values
          Q_MB[1] = (tran_type[1] > tran_type[2]) ? (.7*fmax(Q_2[1,1],Q_2[1,2]) + 
          .3*fmax(Q_2[2,1],Q_2[2,2])) : (.3*fmax(Q_2[1,1],Q_2[1,2]) + 
          .7*fmax(Q_2[2,1],Q_2[2,2]));
          Q_MB[2] = (tran_type[1] > tran_type[2]) ? (.3*fmax(Q_2[1,1],Q_2[1,2]) + 
          .7*fmax(Q_2[2,1],Q_2[2,2])) : (.7*fmax(Q_2[1,1],Q_2[1,2]) + 
          .3*fmax(Q_2[2,1],Q_2[2,2]));
          
        } else { //if missing 1st stage choice: decay all TD & 2nd stage values & 
        //update previous choice
        prev_choice=0;
        log_lik[s,t,1,v] = 0;
        log_lik[s,t,2,v] = 0;
        Q_TD[1] = (1-alpha[s,v])*Q_TD[1];
        Q_TD[2] = (1-alpha[s,v])*Q_TD[2];
        Q_2[1,1] = (1-alpha[s,v])*Q_2[1,1];
        Q_2[1,2] = (1-alpha[s,v])*Q_2[1,2];
        Q_2[2,1] = (1-alpha[s,v])*Q_2[2,1];
        Q_2[2,2] = (1-alpha[s,v])*Q_2[2,2];
        }
      } 
    }
  }
}

Does reduce_sum have a benefit for automatic differentiation?

yes

You should replace this, for each parameter, with

to_vector(xxx) ~ std_normal()
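For example, for the visit-level raw-parameter matrices (assuming those loops are what's being referred to), this collapses the subject loop in the model block to one statement per parameter:

// instead of: for (s in 1:nS) { alpha_visit_raw[s,] ~ normal(0,1); ... }
to_vector(alpha_visit_raw) ~ std_normal();
to_vector(beta_1_MF_visit_raw) ~ std_normal();
to_vector(beta_1_MB_visit_raw) ~ std_normal();
to_vector(beta_2_visit_raw) ~ std_normal();
to_vector(pers_visit_raw) ~ std_normal();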

thanks!