functions {
  /**
   * Partial log density for reduce_sum: contribution of units start..end
   * (first array dimension of every sliced argument) at time index w.
   *
   * For every unit n in start:end this adds
   *   1. the latent-state density: X[n] ~ normal(mu_prior_x, sigma_v) when
   *      w == 1, otherwise X[n] ~ normal(A * X_prev[n] + B * U[n, w-1], Q);
   *   2. the densities tying the five unconstrained Go/NoGo parameters to the
   *      latent state through rows 1..5 of C with scale sigma_r;
   *   3. if idx_gng_obs[n] != 0, the Bernoulli likelihood of the observed
   *      presses under the Go/NoGo learning model.
   *
   * @param dummy        slice of the placeholder array; only its range matters
   * @param start        first unit index of this slice (inclusive)
   * @param end          last unit index of this slice (inclusive)
   * @param w            time index being evaluated (selects the U lag and the
   *                     initial-state branch)
   * @param X            latent states at time w, indexed [unit, state dim]
   * @param X_prev       latent states at time w-1, indexed [unit, state dim];
   *                     not read when w == 1
   * @param mu_prior_x   mean of the initial latent state
   * @param sigma_v      scale of the initial latent state
   * @param A            state transition matrix
   * @param B            input matrix applied to U
   * @param U            exogenous inputs, indexed [unit, time, input]
   * @param Q            scale of the state transition noise
   * @param C            loading matrix; rows 1..5 are read here
   * @param sigma_r      scale linking C * X to the unconstrained parameters
   * @param pressed_gng  press indicator, indexed [unit, block, trial]
   * @param cue_gng      cue index (used to index length-4 vectors),
   *                     indexed [unit, block, trial]
   * @param outcome_gng  outcome, indexed [unit, block, trial]
   * @param Tr_gng       number of trials, indexed [unit, block]
   * @param idx_gng_obs  nonzero when task data are to be used for the unit
   * @param b_gng_pr, pi_gng_pr, xi_gng_pr, ep_gng_pr, rho_gng_pr
   *                     unconstrained task parameters per unit
   * @param b_gng, pi_gng, xi_gng, ep_gng, rho_gng
   *                     transformed task parameters per unit
   * @param initV_gng    initial value for all value vectors
   * @param Bl           number of blocks
   * @return summed log density over units start..end
   */
  real partial_sum(real[] dummy, int start, int end,
                   int w, real[,] X, real[,] X_prev,
                   vector mu_prior_x, real sigma_v,
                   matrix A, matrix B, real[,,] U, real Q,
                   matrix C, real sigma_r,
                   int[,,] pressed_gng, int[,,] cue_gng, real[,,] outcome_gng,
                   int[,] Tr_gng, int[] idx_gng_obs,
                   real[] b_gng_pr, real[] pi_gng_pr, real[] xi_gng_pr,
                   real[] ep_gng_pr, real[] rho_gng_pr,
                   real[] b_gng, real[] pi_gng, real[] xi_gng,
                   real[] ep_gng, real[] rho_gng,
                   vector initV_gng, int Bl) {
    real lt = 0;

    // Iterate over the whole slice so the result is correct for any grainsize,
    // not only when each slice holds a single unit.
    for (n in start:end) {
      vector[size(X[n])] x_n = to_vector(X[n]);

      // Latent state: initial distribution at w == 1, otherwise a linear
      // transition from the previous state driven by the lagged input.
      if (w == 1) {
        lt += normal_lpdf(x_n | mu_prior_x, sigma_v);
      } else {
        lt += normal_lpdf(x_n | A * to_vector(X_prev[n]) + B * to_vector(U[n, w - 1]), Q);
      }

      // Unconstrained task parameters are centred on rows of C times the state.
      lt += normal_lpdf(xi_gng_pr[n]  | C[1, ] * x_n, sigma_r);
      lt += normal_lpdf(ep_gng_pr[n]  | C[2, ] * x_n, sigma_r);
      lt += normal_lpdf(b_gng_pr[n]   | C[3, ] * x_n, sigma_r);
      lt += normal_lpdf(pi_gng_pr[n]  | C[4, ] * x_n, sigma_r);
      lt += normal_lpdf(rho_gng_pr[n] | C[5, ] * x_n, sigma_r);

      if (idx_gng_obs[n] != 0) {
        vector[4] wv_g;   // action weight for go
        vector[4] wv_ng;  // action weight for nogo
        vector[4] qv_g;   // Q value for go
        vector[4] qv_ng;  // Q value for nogo
        vector[4] sv;     // stimulus value
        vector[4] pGo;    // prob of go (press)

        for (bl in 1:Bl) {
          // All values are reset at the start of every block.
          wv_g  = initV_gng;
          wv_ng = initV_gng;
          qv_g  = initV_gng;
          qv_ng = initV_gng;
          sv    = initV_gng;

          for (t in 1:Tr_gng[n, bl]) {
            int c = cue_gng[n, bl, t];
            real scaled_outcome = rho_gng[n] * outcome_gng[n, bl, t];

            wv_g[c]  = qv_g[c] + b_gng[n] + pi_gng[n] * sv[c];
            wv_ng[c] = qv_ng[c];  // qv_ng is always equal to wv_ng (regardless of action)

            // Choice probability mixed with lapse noise xi.
            pGo[c] = inv_logit(wv_g[c] - wv_ng[c]);
            pGo[c] *= (1 - xi_gng[n]);
            pGo[c] += xi_gng[n] / 2;

            lt += bernoulli_lpmf(pressed_gng[n, bl, t] | pGo[c]);

            // After feedback, update the stimulus value for this cue.
            sv[c] += ep_gng[n] * (scaled_outcome - sv[c]);

            // Update the value of the action that was actually taken.
            if (pressed_gng[n, bl, t]) {
              qv_g[c] += ep_gng[n] * (scaled_outcome - qv_g[c]);
            } else {
              qv_ng[c] += ep_gng[n] * (scaled_outcome - qv_ng[c]);
            }
          } // end of t loop
        } // end of bl loop
      }
    }
    return lt;
  }
}
data {
  int W;                 // size of the second (time) dimension of all arrays
  int N;                 // size of the first dimension; reduce_sum slices over it
  int Xdim;              // latent state dimension
  int exo_q_num;         // number of exogenous inputs in U
  real U[N,W,exo_q_num];
  // Go/NoGo
  int idx_gng_obs[N,W];  // nonzero => task likelihood is evaluated for [n, w]
  int P_gng;             // number of rows of C; rows 1..5 are indexed, so must be >= 5
  int Bl;                // number of blocks
  int T_max_gng;         // trial dimension of the task arrays
  int Tr_gng[N,W,Bl];    // number of trials actually used per block
  int cue_gng[N,W,Bl,T_max_gng];
  int pressed_gng[N,W,Bl, T_max_gng];
  real outcome_gng[N, W, Bl, T_max_gng];
}
transformed data {
  // Placeholder sliced by reduce_sum; only its length matters. Initialised so
  // it never carries NaN values.
  real dummy[N] = rep_array(0.0, N);
  real Q = 1;
  int cauchy_alpha = 5;
  int num_par = P_gng;
  int grainsize = 1;
  vector[4] initV_gng = rep_vector(0.0, 4);
}
parameters {
  // Scales must be positive: they are used as normal standard deviations.
  real<lower=0> sigma_x;
  real<lower=0> sigma_v;
  real<lower=0> sigma_r;
  real<lower=0> sigma_a;
  real<lower=0> sigma_b;
  real<lower=0> sigma_c;
  vector[Xdim] mu_prior_x;
  real X[N,W,Xdim];
  matrix[Xdim, Xdim] A;
  matrix[Xdim, exo_q_num] B;
  matrix[num_par, Xdim] C;
  real xi_gng_pr[N,W];
  real ep_gng_pr[N,W];
  real b_gng_pr[N,W];
  real pi_gng_pr[N,W];
  real rho_gng_pr[N,W];
}
transformed parameters {
  real xi_gng[N,W];
  real ep_gng[N,W];
  real b_gng[N,W];
  real pi_gng[N,W];
  real rho_gng[N,W];
  // rho is mapped to the positive reals.
  for (n in 1:N) {
    rho_gng[n,] = exp(rho_gng_pr[n,]);
  }
  // xi and ep are mapped to (0, 1).
  for (n in 1:N) {
    for (w in 1:W) {
      xi_gng[n,w] = Phi_approx(xi_gng_pr[n,w]);
      ep_gng[n,w] = Phi_approx(ep_gng_pr[n,w]);
    }
  }
  // b and pi are used untransformed.
  b_gng = b_gng_pr;
  pi_gng = pi_gng_pr;
}
model {
  sigma_x ~ cauchy(0, cauchy_alpha);  // scale of the prior on mu_prior_x
  sigma_v ~ cauchy(0, cauchy_alpha);  // scale of the initial latent state
  sigma_r ~ cauchy(0, cauchy_alpha);  // scale linking C * X to the task parameters
  sigma_a ~ cauchy(0, cauchy_alpha);  // scale of the prior on A
  sigma_b ~ cauchy(0, cauchy_alpha);  // scale of the prior on B
  sigma_c ~ cauchy(0, cauchy_alpha);  // scale of the prior on C
  mu_prior_x ~ normal(0, sigma_x);    // prior on the initial state mean
  // priors on A, B, C
  to_vector(A) ~ normal(0, sigma_a);
  to_vector(B) ~ normal(0, sigma_b);
  to_vector(C) ~ normal(0, sigma_c);
  for (w in 1:W) {
    // Index of the lagged state; clamped to 1 at w == 1, where partial_sum
    // does not read X_prev.
    int w_prev = max(w - 1, 1);
    target += reduce_sum_static(partial_sum, dummy, grainsize,
                                w, X[,w,], X[,w_prev,],
                                mu_prior_x, sigma_v, A, B, U, Q, C, sigma_r,
                                pressed_gng[,w,,], cue_gng[,w,,], outcome_gng[,w,,],
                                Tr_gng[,w,], idx_gng_obs[,w],
                                b_gng_pr[,w], pi_gng_pr[,w], xi_gng_pr[,w],
                                ep_gng_pr[,w], rho_gng_pr[,w],
                                b_gng[,w], pi_gng[,w], xi_gng[,w],
                                ep_gng[,w], rho_gng[,w],
                                initV_gng, Bl);
  }
}
"""