Hi all! I’m new to Stan (and to reinforcement learning) and am trying to fit a hierarchical mixture-of-agents model to 9-10 sessions of behavior in a two-armed bandit task for each of 3 rats. I’ve played around with the number of iterations and have cranked adapt_delta up to 0.99, but I’m still getting the “divergent transitions after warmup” warning. I think I may need to reparameterize the model, but I’m not sure how to do that. Any help would be very much appreciated!
The model code:
data {
  int n_rats;                                  // number of subjects
  int max_days;                                // maximum days per subject
  int max_trials;                              // maximum trials per day
  int num_days[n_rats];                        // actual number of days per subject
  int num_trials[n_rats,max_days];             // actual number of trials per subject/day
  int r[n_rats,max_days,max_trials];           // reward: -1 or 1
  int choice[n_rats,max_days,max_trials];      // L vs R poke: 2 is right, 1 is left
  real click_diff[n_rats,max_days,max_trials]; // click difference (+ means more on right)
  int prev_choice[n_rats,max_days,max_trials]; // 2 is right, 1 is left
}
// the model parameters
parameters {
  // group level
  vector[5] betam;                // group-level mean parameters
  // rat level
  vector[5] betas[n_rats];        // per-subject mean parameters
  vector<lower=0>[5] sigmaS;      // parameter SDs across subjects
  // session level
  vector[5] betad[sum(num_days)]; // per-day mean parameters
  vector<lower=0>[5] sigmaD;      // parameter SDs across days within subject
}
// the model itself
model {
  int dayindex;
  // priors
  sigmaS ~ normal(0, 1);
  sigmaD ~ normal(0, 1);
  betam ~ normal(0, 1);
  dayindex = 0;
  // loop over subjects
  for (s in 1:n_rats) {
    // each subject's parameters
    // betas[s] ~ multi_normal(betam, diag_matrix(sigmaS));
    betas[s] ~ normal(betam, sigmaS);
    // loop over days within subject
    for (d in 1:num_days[s]) {
      real betaQ;     // beta: q value
      real betaP;     // beta: perseveration
      real betaC;     // beta: clicks
      real bias;      // overall bias left or right
      real alpha;     // learning rate
      vector[2] q;    // value of each option
      vector[2] qeff; // sum of all the "agents"
      vector[2] P;    // perseveration input
      vector[2] C;    // click-difference input
      dayindex += 1;
      // draw this day's parameters
      betad[dayindex] ~ normal(betas[s], sigmaD);
      // unpack the 5-vector into the 5 params
      betaQ = betad[dayindex,1];
      betaP = betad[dayindex,2];
      betaC = betad[dayindex,3];
      bias  = betad[dayindex,4];
      alpha = Phi_approx(betad[dayindex,5] / 1.7); // 1.7 ≈ sqrt(3); makes the implied prior on alpha roughly uniform
      // initialize q values for left (1) and right (2) choices
      q[1] = 0;
      q[2] = 0;
      P[1] = 0;
      P[2] = 0;
      C[1] = 0;
      C[2] = 0;
      // loop over trials
      for (i in 1:num_trials[s,d]) {
        // calculate C
        C[1] = -click_diff[s,d,i]; // index 1 is negated because it's the leftward choice
        C[2] = click_diff[s,d,i];
        // calculate P
        P[prev_choice[s,d,i]] = 1;
        P[3 - prev_choice[s,d,i]] = -1;
        // a negative bias parameter gives a leftward bias
        qeff[1] = betaQ * q[1] + betaP * P[1] + betaC * C[1] - bias;
        qeff[2] = betaQ * q[2] + betaP * P[2] + betaC * C[2] + bias;
        // probability of choice
        choice[s,d,i] ~ categorical_logit(qeff);
        // update Q values from outcome
        q[choice[s,d,i]] = q[choice[s,d,i]] + alpha * (r[s,d,i] - q[choice[s,d,i]]);
      }
    }
  }
}
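In case it helps anyone point me in the right direction, here is my best guess at what a non-centered version of the subject and day levels might look like (the betas_raw / betad_raw parameters are names I made up, and I haven't verified this actually fixes the divergences):

parameters {
  vector[5] betam;                    // group-level means (unchanged)
  vector<lower=0>[5] sigmaS;          // subject-level SDs (unchanged)
  vector<lower=0>[5] sigmaD;          // day-level SDs (unchanged)
  vector[5] betas_raw[n_rats];        // standardized subject-level offsets
  vector[5] betad_raw[sum(num_days)]; // standardized day-level offsets
}
transformed parameters {
  vector[5] betas[n_rats];
  vector[5] betad[sum(num_days)];
  {
    int dayindex = 0;
    for (s in 1:n_rats) {
      // implies betas[s] ~ normal(betam, sigmaS)
      betas[s] = betam + sigmaS .* betas_raw[s];
      for (d in 1:num_days[s]) {
        dayindex += 1;
        // implies betad[dayindex] ~ normal(betas[s], sigmaD)
        betad[dayindex] = betas[s] + sigmaD .* betad_raw[dayindex];
      }
    }
  }
}
model {
  // priors as before
  betam ~ normal(0, 1);
  sigmaS ~ normal(0, 1);
  sigmaD ~ normal(0, 1);
  // standard normals on the raw offsets replace the two centered
  // sampling statements that were inside the subject/day loops
  for (s in 1:n_rats)
    betas_raw[s] ~ normal(0, 1);
  for (n in 1:sum(num_days))
    betad_raw[n] ~ normal(0, 1);
  // ... per-day unpacking and the trial loop stay exactly as above,
  // just without the betas/betad sampling statements ...
}

My understanding from the manual is that the sampler then works on standardized offsets, so the dependence between the group-level SDs and the lower-level parameters is removed from the posterior geometry. Does that look like the right approach here?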