Based on @wds15's info and some thought about how best to structure my model, I cleaned up the indexing and variable dimensions, and things run MUCH faster now. I think my current model (pasted below) is good enough to work with. The reduce_sum version in my previous post returned accurate parameter estimates, but ran 4-6x slower on one core per chain than my non-reduce_sum code, which would have negated any parallelization benefit. The current version runs only 10-15% slower than the original on one core, which I can live with, and is already faster than the original when I run two cores per chain. Thanks again for the help!
functions {
// Partial log-likelihood of the first- and second-stage choices for one slice
// of subjects; written as the partial-sum function handed to reduce_sum.
//
// Arguments:
//   subj_slice     slice of the subject-id array supplied by reduce_sum. Inside
//                  this function it is indexed 1..(end-start+1), NOT start..end,
//                  so it must never be indexed with start or end.
//   start, end     positions (in the full, unsliced array) of the first and last
//                  subject covered by this call
//   choice_long    choices flattened to [row, stage]; rows are expected in
//                  (subject, visit, trial) order with trials varying fastest
//   reward, state_2, missing_choice, choice
//                  trial-level data, indexed by the subject's global position
//   missing_visit  accepted for interface compatibility; not read here
//   *_subj_raw, *_visit_raw, *_m, *_subj_s, *_visit_s
//                  non-centered pieces of the subject/visit-level parameters
//   nT, nV         trials per visit and visits per subject
//
// Returns the summed Bernoulli-logit log-probability of both stages' choices
// for subjects start..end. Trials whose first-stage choice is flagged missing
// get a linear predictor of 0 and therefore add only a constant.
real partial_sum(int[] subj_slice,
                 int start, int end,
                 int[,] choice_long, int[,,] reward, int[,,] state_2,
                 int[,,,] missing_choice,
                 int[,] missing_visit, int[,,,] choice,
                 vector alpha_subj_raw, vector beta_1_MF_subj_raw, vector beta_1_MB_subj_raw,
                 vector beta_2_subj_raw, vector pers_subj_raw, matrix alpha_visit_raw,
                 matrix beta_1_MF_visit_raw, matrix beta_1_MB_visit_raw,
                 matrix beta_2_visit_raw, matrix pers_visit_raw,
                 int nT, int nV,
                 real alpha_m, real beta_1_MF_m, real beta_1_MB_m, real beta_2_m, real pers_m,
                 real alpha_subj_s, real beta_1_MF_subj_s, real beta_1_MB_subj_s,
                 real beta_2_subj_s, real pers_subj_s, real alpha_visit_s, real beta_1_MF_visit_s,
                 real beta_1_MB_visit_s, real beta_2_visit_s, real pers_visit_s) {
  // Number of subjects in this slice. Derived from start/end because the slice
  // itself is re-indexed from 1 (subj_slice[start] is out of range or wrong
  // whenever start > 1).
  int g_length = end - start + 1;
  int n_obs = g_length * nV * nT;
  // Rows of choice_long that belong to subjects start..end.
  int first_row = (start - 1) * nT * nV + 1;
  int last_row = end * nT * nV;
  // Subject/visit-level parameters for the (subject, visit) being processed.
  real alpha;      // learning rate in [0,1]
  real keep;       // 1 - alpha: fraction of each value retained per trial
  real beta_1_MF;
  real beta_1_MB;
  real beta_2;
  real pers;
  // Running state of the learner within one visit.
  int subj;        // global subject position
  int pos;         // position of the current trial in eta_1 / eta_2
  int c1;          // first-stage choice (0/1)
  int c2;          // second-stage choice (0/1)
  int st;          // second-stage state (1/2)
  int prev_choice;
  int tran_count;
  int tran_type[2];
  int unc_state;
  real best_1;
  real best_2;
  real Q_TD[2];
  real Q_MB[2];
  real Q_2[2,2];
  // Linear predictors (logit scale), one entry per trial in the slice.
  vector[n_obs] eta_1;
  vector[n_obs] eta_2;

  for (s in 1:g_length) {
    // Address subjects by position so the data lines up with the choice_long
    // rows selected above; this equals the subject id when s_id is 1..nS.
    subj = start + s - 1;
    for (v in 1:nV) {
      // Non-centered parameterization: mean + visit-level shift + subject-level shift.
      alpha = inv_logit(alpha_m +
                        alpha_visit_s * alpha_visit_raw[subj,v] +
                        alpha_subj_s * alpha_subj_raw[subj]);
      keep = 1 - alpha;
      beta_1_MF = beta_1_MF_m +
                  beta_1_MF_visit_s * beta_1_MF_visit_raw[subj,v] +
                  beta_1_MF_subj_s * beta_1_MF_subj_raw[subj];
      beta_1_MB = beta_1_MB_m +
                  beta_1_MB_visit_s * beta_1_MB_visit_raw[subj,v] +
                  beta_1_MB_subj_s * beta_1_MB_subj_raw[subj];
      beta_2 = beta_2_m +
               beta_2_visit_s * beta_2_visit_raw[subj,v] +
               beta_2_subj_s * beta_2_subj_raw[subj];
      pers = pers_m +
             pers_visit_s * pers_visit_raw[subj,v] +
             pers_subj_s * pers_subj_raw[subj];

      // Every visit starts from the same initial values.
      for (i in 1:2) {
        Q_TD[i] = .5;
        Q_MB[i] = .5;
        Q_2[1,i] = .5;
        Q_2[2,i] = .5;
        tran_type[i] = 0;
      }
      prev_choice = 0;

      for (t in 1:nT) {
        pos = (s - 1) * nV * nT + (v - 1) * nT + t;
        if (missing_choice[subj,t,1,v] == 0) {
          c1 = choice[subj,t,1,v];
          c2 = choice[subj,t,2,v];
          st = state_2[subj,t,v];

          // Predictors use the values held BEFORE this trial's update.
          eta_1[pos] = beta_1_MF * (Q_TD[2] - Q_TD[1]) +
                       beta_1_MB * (Q_MB[2] - Q_MB[1]) +
                       pers * prev_choice;
          eta_2[pos] = beta_2 * (Q_2[st,2] - Q_2[st,1]);

          prev_choice = 2 * c1 - 1; // 1 if choice 2, -1 if choice 1

          // Transition counts: choice=0 & state=1, or choice=1 & state=2,
          // count towards type 1; anything else counts towards type 2.
          tran_count = (st == c1 + 1) ? 1 : 2;
          tran_type[tran_count] += 1;

          // Chosen options: decay, then add the reward.
          Q_TD[c1 + 1] = Q_TD[c1 + 1] * keep + reward[subj,t,v];
          Q_2[st, c2 + 1] = Q_2[st, c2 + 1] * keep + reward[subj,t,v];

          // Unchosen options (index 2 - c maps 0 -> 2 and 1 -> 1): decay only.
          Q_TD[2 - c1] = keep * Q_TD[2 - c1];
          Q_2[st, 2 - c2] = keep * Q_2[st, 2 - c2];

          // Both options in the second-stage state that was not visited.
          unc_state = 3 - st;
          Q_2[unc_state,1] = keep * Q_2[unc_state,1];
          Q_2[unc_state,2] = keep * Q_2[unc_state,2];

          // Model-based values: weight the best second-stage value of each
          // state by .7/.3 according to which transition type has been seen
          // more often (ties fall into the else branch).
          best_1 = fmax(Q_2[1,1], Q_2[1,2]);
          best_2 = fmax(Q_2[2,1], Q_2[2,2]);
          if (tran_type[1] > tran_type[2]) {
            Q_MB[1] = .7 * best_1 + .3 * best_2;
            Q_MB[2] = .3 * best_1 + .7 * best_2;
          } else {
            Q_MB[1] = .3 * best_1 + .7 * best_2;
            Q_MB[2] = .7 * best_1 + .3 * best_2;
          }
        } else {
          // Missing first-stage choice: zero predictors, reset the previous
          // choice, and decay all TD and second-stage values.
          eta_1[pos] = 0;
          eta_2[pos] = 0;
          prev_choice = 0;
          for (i in 1:2) {
            Q_TD[i] = keep * Q_TD[i];
            Q_2[1,i] = keep * Q_2[1,i];
            Q_2[2,i] = keep * Q_2[2,i];
          }
        }
      }
    }
  }
  return bernoulli_logit_lpmf(choice_long[first_row:last_row, 1] | eta_1) +
         bernoulli_logit_lpmf(choice_long[first_row:last_row, 2] | eta_2);
}
}
data {
// Design sizes; every array below is dimensioned from these.
int<lower=1> nT; //trials per visit
int<lower=1> nS; //# of subjects
int<lower=2> nV; //# of visits per subject
// Binary choices indexed [subject, trial, stage, visit];
// stage 1 is the first-stage choice, stage 2 the second-stage choice.
int<lower=0,upper=1> choice[nS,nT,2,nV];
// Binary reward indexed [subject, trial, visit].
int<lower=0,upper=1> reward[nS,nT,nV];
// Second-stage state (1 or 2) indexed [subject, trial, visit].
int<lower=1,upper=2> state_2[nS,nT,nV];
// Missingness flags with the same layout as choice. A trial is processed as
// observed only when its stage-1 flag equals 0; the stage-2 flag is not read
// by the code in this file.
int missing_choice[nS,nT,2,nV];
// Subject ids; this is the array reduce_sum slices over.
int s_id[nS]; //seq(1,nS,by=1)
// Per-subject, per-visit flags; passed to partial_sum but not read there.
int missing_visit[nS,nV];
}
transformed data {
  // Flattened copy of the choices: one row per (subject, visit, trial), with
  // trials varying fastest, then visits, then subjects. Column 1 holds the
  // first-stage choice and column 2 the second-stage choice.
  int choice_long[nS*nT*nV,2];
  {
    int k = 0; // running row index; walks choice_long in the order described above
    for (s in 1:nS) {
      for (v in 1:nV) {
        for (t in 1:nT) {
          k += 1;
          choice_long[k,1] = choice[s,t,1,v];
          choice_long[k,2] = choice[s,t,2,v];
        }
      }
    }
  }
}
parameters {
// Every subject/visit-level quantity is assembled elsewhere in the program as
//   mean (*_m) + visit SD (*_visit_s) * visit_raw + subject SD (*_subj_s) * subj_raw,
// so the *_raw parameters here are the standardized (non-centered) effects.
//group-level means (y00)
// alpha_m is unconstrained because the sum above is passed through inv_logit
// to obtain the learning rate; the three beta means are constrained positive.
real alpha_m;
real<lower=0> beta_1_MF_m;
real<lower=0> beta_1_MB_m;
real<lower=0> beta_2_m;
real pers_m;
// subj-level variance
// (used as scale multipliers on the *_subj_raw effects)
real<lower=0> alpha_subj_s;
real<lower=0> beta_1_MF_subj_s;
real<lower=0> beta_1_MB_subj_s;
real<lower=0> beta_2_subj_s;
real<lower=0> pers_subj_s;
//NCP variance effect on subj-level effects
// one standardized effect per subject
vector[nS] alpha_subj_raw;
vector[nS] beta_1_MF_subj_raw;
vector[nS] beta_1_MB_subj_raw;
vector[nS] beta_2_subj_raw;
vector[nS] pers_subj_raw;
//visit-level (within subject) SDs (sigma2_y)
real<lower=0> alpha_visit_s;
real<lower=0> beta_1_MF_visit_s;
real<lower=0> beta_1_MB_visit_s;
real<lower=0> beta_2_visit_s;
real<lower=0> pers_visit_s;
//non-centered parameterization (ncp) variance effect per visit & subject
// one standardized effect per [subject, visit]
matrix[nS,nV] alpha_visit_raw;
matrix[nS,nV] beta_1_MF_visit_raw;
matrix[nS,nV] beta_1_MB_visit_raw;
matrix[nS,nV] beta_2_visit_raw;
matrix[nS,nV] pers_visit_raw;
}
//REDUCE SUM: moved to function
// transformed parameters {
// //define transformed parameters
// matrix<lower=0,upper=1>[nS,nV] alpha;
// matrix[nS,nV] alpha_normal;
// matrix[nS,nV] beta_1_MF;
// matrix[nS,nV] beta_1_MB;
// matrix[nS,nV] beta_2;
// matrix[nS,nV] pers;
//
// //create transformed parameters using non-centered parameterization for all
// // and logistic transformation for alpha (range: 0 to 1),
// // add in subject and visit-level effects as shifts in means
// for (s in 1:nS) {
// alpha_normal[s,] = alpha_m+alpha_visit_s*alpha_visit_raw[s,] + alpha_subj_s*alpha_subj_raw[s];
// beta_1_MF[s,] = beta_1_MF_m + beta_1_MF_visit_s*beta_1_MF_visit_raw[s,] +
// beta_1_MF_subj_s*beta_1_MF_subj_raw[s];
// beta_1_MB[s,] = beta_1_MB_m + beta_1_MB_visit_s*beta_1_MB_visit_raw[s,] +
// beta_1_MB_subj_s*beta_1_MB_subj_raw[s];
// beta_2[s,] = beta_2_m + beta_2_visit_s*beta_2_visit_raw[s,] + beta_2_subj_s*beta_2_subj_raw[s];
// pers[s,] = pers_m + pers_visit_s*pers_visit_raw[s,] + pers_subj_s*pers_subj_raw[s];
//
// //transform alpha to [0,1]
// alpha[s,] = inv_logit(alpha_normal[s,]);
// }
// }
model {
  // Priors for all parameters, followed by the choice likelihood evaluated in
  // parallel over subjects via reduce_sum(partial_sum, ...).

  // Grain size handed to reduce_sum. NOTE(review): per the Stan documentation a
  // value of 1 leaves the chunking to the scheduler; tune if profiling suggests it.
  int grainsize=1;
  //REDUCE SUM: moved to function
  // //define variables
  // //anything used in likelihood needs a value per trial
  // vector[nT] prev_choice[nS,nV];
  // int tran_count;
  // int tran_type[2];
  // int unc_state;
  // real Q_TD[2];
  // real Q_MB[2];
  // real Q_2[2,2];
  // vector[nT] Q_TD_diff[nS,nV];
  // vector[nT] Q_MB_diff[nS,nV];
  // vector[nT] Q_2_diff[nS,nV];

  //define priors
  // group-level means (half-normal for the parameters declared with lower=0)
  alpha_m ~ normal(0,2.5);
  beta_1_MF_m ~ normal(0,5);
  beta_1_MB_m ~ normal(0,5);
  beta_2_m ~ normal(0,5);
  pers_m ~ normal(0,2.5);

  // visit-level SDs (half-t, since they are declared with lower=0)
  alpha_visit_s ~ student_t(3,0,2);
  beta_1_MF_visit_s ~ student_t(3,0,2);
  beta_1_MB_visit_s ~ student_t(3,0,2);
  beta_2_visit_s ~ student_t(3,0,2);
  pers_visit_s ~ student_t(3,0,2);

  // standardized visit-level effects: one vectorized statement per matrix
  // instead of a loop over rows (same density, fewer function calls)
  to_vector(alpha_visit_raw) ~ std_normal();
  to_vector(beta_1_MF_visit_raw) ~ std_normal();
  to_vector(beta_1_MB_visit_raw) ~ std_normal();
  to_vector(beta_2_visit_raw) ~ std_normal();
  to_vector(pers_visit_raw) ~ std_normal();

  // standardized subject-level effects
  alpha_subj_raw ~ std_normal();
  beta_1_MF_subj_raw ~ std_normal();
  beta_1_MB_subj_raw ~ std_normal();
  beta_2_subj_raw ~ std_normal();
  pers_subj_raw ~ std_normal();

  // subject-level SDs (half-t)
  alpha_subj_s ~ student_t(3,0,2);
  beta_1_MF_subj_s ~ student_t(3,0,3);
  beta_1_MB_subj_s ~ student_t(3,0,3);
  beta_2_subj_s ~ student_t(3,0,3);
  pers_subj_s ~ student_t(3,0,2);

  // Likelihood: s_id is the sliced argument; everything after grainsize is
  // forwarded unchanged to every partial_sum call.
  target += reduce_sum(partial_sum,s_id,grainsize,
    choice_long, reward,state_2, missing_choice, missing_visit,choice,
    alpha_subj_raw, beta_1_MF_subj_raw, beta_1_MB_subj_raw,beta_2_subj_raw, pers_subj_raw,
    alpha_visit_raw, beta_1_MF_visit_raw, beta_1_MB_visit_raw, beta_2_visit_raw, pers_visit_raw,
    nT, nV, alpha_m, beta_1_MF_m, beta_1_MB_m, beta_2_m, pers_m, alpha_subj_s,
    beta_1_MF_subj_s, beta_1_MB_subj_s, beta_2_subj_s, pers_subj_s, alpha_visit_s,
    beta_1_MF_visit_s, beta_1_MB_visit_s, beta_2_visit_s, pers_visit_s);

  //REDUCE SUM: moved to function
  // (original serial likelihood, kept for reference)
  // for (s in 1:nS) {
  // for (v in 1:nV) {
  //
  // //set initial values
  // for (i in 1:2) {
  // Q_TD[i]=.5;
  // Q_MB[i]=.5;
  // Q_2[1,i]=.5;
  // Q_2[2,i]=.5;
  // tran_type[i]=0;
  // }
  // prev_choice[s,v,1]=0;
  //
  // for (t in 1:nT) {
  // //use if not missing 1st stage choice
  // if (missing_choice[s,t,1,v]==0) {
  //
  // //fill in values used to predict choice
  // if (t<nT) prev_choice[s,v,t+1] = 2*choice[s,t,1,v]-1; //1 if choice 2, -1 if choice 1
  // Q_TD_diff[s,v,t]=Q_TD[2]-Q_TD[1];
  // Q_MB_diff[s,v,t]=Q_MB[2]-Q_MB[1];
  // Q_2_diff[s,v,t]=Q_2[state_2[s,t,v],2]-Q_2[state_2[s,t,v],1];
  //
  // //update transition counts: if choice=0 & state=1, or choice=1 & state=2,
  // //update 1st expectation of transition, otherwise update 2nd expectation
  // tran_count = (state_2[s,t,v]-choice[s,t,1,v]-1) ? 2 : 1;
  // tran_type[tran_count] = tran_type[tran_count] + 1;
  //
  // //update chosen values
  // Q_TD[choice[s,t,1,v]+1] = Q_TD[choice[s,t,1,v]+1]*(1-(alpha[s,v]))
  // + reward[s,t,v];
  // Q_2[state_2[s,t,v],choice[s,t,2,v]+1] = Q_2[state_2[s,t,v],choice[s,t,2,v]+1]*
  // (1 -(alpha[s,v])) + reward[s,t,v];
  //
  // //update unchosen TD & second stage values
  // Q_TD[(choice[s,t,1,v] ? 1 : 2)] = (1-alpha[s,v])*
  // Q_TD[(choice[s,t,1,v] ? 1 : 2)];
  // Q_2[state_2[s,t,v],(choice[s,t,2,v] ? 1 : 2)] = (1-alpha[s,v])*
  // Q_2[state_2[s,t,v],(choice[s,t,2,v] ? 1 : 2)];
  // unc_state = (state_2[s,t,v]-1) ? 1 : 2;
  // Q_2[unc_state,1] = (1-alpha[s,v])*Q_2[unc_state,1];
  // Q_2[unc_state,2] = (1-alpha[s,v])*Q_2[unc_state,2];
  //
  // //update model-based values
  // Q_MB[1] = (tran_type[1] > tran_type[2]) ? (.7*fmax(Q_2[1,1],Q_2[1,2]) +
  // .3*fmax(Q_2[2,1],Q_2[2,2])) : (.3*fmax(Q_2[1,1],Q_2[1,2]) +
  // .7*fmax(Q_2[2,1],Q_2[2,2]));
  // Q_MB[2] = (tran_type[1] > tran_type[2]) ? (.3*fmax(Q_2[1,1],Q_2[1,2]) +
  // .7*fmax(Q_2[2,1],Q_2[2,2])) : (.7*fmax(Q_2[1,1],Q_2[1,2]) +
  // .3*fmax(Q_2[2,1],Q_2[2,2]));
  //
  // } else { //if missing trial: decay all TD & 2nd stage values,
  // //update previous choice, and set trial's Q values to 0
  // if (t<nT) prev_choice[s,v,t+1]=0;
  // Q_TD_diff[s,v,t]=0;
  // Q_MB_diff[s,v,t]=0;
  // Q_2_diff[s,v,t]=0;
  // Q_TD[1] = (1-alpha[s,v])*Q_TD[1];
  // Q_TD[2] = (1-alpha[s,v])*Q_TD[2];
  // Q_2[1,1] = (1-alpha[s,v])*Q_2[1,1];
  // Q_2[1,2] = (1-alpha[s,v])*Q_2[1,2];
  // Q_2[2,1] = (1-alpha[s,v])*Q_2[2,1];
  // Q_2[2,2] = (1-alpha[s,v])*Q_2[2,2];
  // }
  // }
  // choice[s,,1,v] ~ bernoulli_logit(beta_1_MF[s,v]*Q_TD_diff[s,v]
  // +beta_1_MB[s,v]*Q_MB_diff[s,v] +pers[s,v]*prev_choice[s,v]);
  // choice[s,,2,v] ~ bernoulli_logit(beta_2[s,v]*Q_2_diff[s,v]);
  // }
  //
  // }
}
generated quantities {
  // Replays the learning model with the sampled parameter values and records
  // the pointwise log-likelihood of every observed choice:
  //  1) choices are scored against the model's values rather than fitted
  //  2) no priors - only the estimated parameter values are used
  real log_lik[nS,nT,2,nV]; //log likelihood- must be named this
  // Learner state; declared at block level, so the values left after the last
  // subject/visit are written to the output as well.
  int prev_choice;
  int tran_count;
  int tran_type[2];
  int unc_state;
  real Q_TD[2];
  real Q_MB[2];
  real Q_2[2,2];
  //REDUCE SUM: subject/visit-level parameters are rebuilt here because the
  //transformed parameters block no longer defines them
  matrix<lower=0,upper=1>[nS,nV] alpha;
  matrix[nS,nV] alpha_normal;
  matrix[nS,nV] beta_1_MF;
  matrix[nS,nV] beta_1_MB;
  matrix[nS,nV] beta_2;
  matrix[nS,nV] pers;

  for (s in 1:nS) {
    for (v in 1:nV) {
      real keep; // 1 - alpha[s,v]: fraction of each value retained per trial

      // Non-centered parameterization: group mean plus visit-level and
      // subject-level shifts; alpha is mapped to [0,1] with the logistic function.
      alpha_normal[s,v] = alpha_m + alpha_visit_s*alpha_visit_raw[s,v] +
                          alpha_subj_s*alpha_subj_raw[s];
      beta_1_MF[s,v] = beta_1_MF_m + beta_1_MF_visit_s*beta_1_MF_visit_raw[s,v] +
                       beta_1_MF_subj_s*beta_1_MF_subj_raw[s];
      beta_1_MB[s,v] = beta_1_MB_m + beta_1_MB_visit_s*beta_1_MB_visit_raw[s,v] +
                       beta_1_MB_subj_s*beta_1_MB_subj_raw[s];
      beta_2[s,v] = beta_2_m + beta_2_visit_s*beta_2_visit_raw[s,v] +
                    beta_2_subj_s*beta_2_subj_raw[s];
      pers[s,v] = pers_m + pers_visit_s*pers_visit_raw[s,v] +
                  pers_subj_s*pers_subj_raw[s];
      alpha[s,v] = inv_logit(alpha_normal[s,v]);
      keep = 1 - alpha[s,v];

      // Each visit starts from the same initial values.
      for (i in 1:2) {
        Q_TD[i] = .5;
        Q_MB[i] = .5;
        Q_2[1,i] = .5;
        Q_2[2,i] = .5;
        tran_type[i] = 0;
      }
      prev_choice = 0;

      for (t in 1:nT) {
        if (missing_choice[s,t,1,v] == 0) {
          int c1 = choice[s,t,1,v];  // first-stage choice (0/1)
          int c2 = choice[s,t,2,v];  // second-stage choice (0/1)
          int st = state_2[s,t,v];   // second-stage state (1/2)
          int rew = reward[s,t,v];
          real best_1;
          real best_2;

          // Score both choices with the values held before this trial's update.
          log_lik[s,t,1,v] = bernoulli_logit_lpmf(c1 | beta_1_MF[s,v]*
            (Q_TD[2]-Q_TD[1]) + beta_1_MB[s,v]*(Q_MB[2]-Q_MB[1]) + pers[s,v]*prev_choice);
          log_lik[s,t,2,v] = bernoulli_logit_lpmf(c2 | beta_2[s,v]*
            (Q_2[st,2]-Q_2[st,1]));
          prev_choice = 2*c1 - 1; //1 if choice 2, -1 if choice 1

          // Transition counts: choice=0 & state=1, or choice=1 & state=2,
          // count towards type 1; everything else counts towards type 2.
          if (st - c1 - 1 == 0) {
            tran_count = 1;
          } else {
            tran_count = 2;
          }
          tran_type[tran_count] += 1;

          // Chosen options: decay, then add the reward.
          Q_TD[c1+1] = Q_TD[c1+1]*keep + rew;
          Q_2[st,c2+1] = Q_2[st,c2+1]*keep + rew;

          // Unchosen options (2 - c maps 0 -> 2 and 1 -> 1): decay only.
          Q_TD[2-c1] = keep*Q_TD[2-c1];
          Q_2[st,2-c2] = keep*Q_2[st,2-c2];

          // Both options of the second-stage state that was not visited.
          unc_state = 3 - st;
          for (i in 1:2) {
            Q_2[unc_state,i] = keep*Q_2[unc_state,i];
          }

          // Model-based values: .7/.3 mix of each state's best second-stage
          // value, oriented by the more frequent transition type (ties take
          // the else branch).
          best_1 = fmax(Q_2[1,1],Q_2[1,2]);
          best_2 = fmax(Q_2[2,1],Q_2[2,2]);
          if (tran_type[1] > tran_type[2]) {
            Q_MB[1] = .7*best_1 + .3*best_2;
            Q_MB[2] = .3*best_1 + .7*best_2;
          } else {
            Q_MB[1] = .3*best_1 + .7*best_2;
            Q_MB[2] = .7*best_1 + .3*best_2;
          }
        } else {
          // Missing first-stage choice: nothing to score, reset the previous
          // choice, and decay all TD and second-stage values.
          prev_choice = 0;
          log_lik[s,t,1,v] = 0;
          log_lik[s,t,2,v] = 0;
          for (i in 1:2) {
            Q_TD[i] = keep*Q_TD[i];
            Q_2[1,i] = keep*Q_2[1,i];
            Q_2[2,i] = keep*Q_2[2,i];
          }
        }
      }
    }
  }
}