Welcome to the community!
To accompany @mingqian.guo's more technical feedback, here are a few practical suggestions from someone who is not familiar with these sorts of models.
First, could you share the data generation code? That might reveal some discrepancies between the data generation and model.
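(In case it helps, here's a rough, hypothetical sketch of what a simulator matching the simplified single-subject model further down this post might look like. The per-arm reward probabilities p_rwd, the Bernoulli reward draw, and passing in a "true" alphaR are assumptions on my part, so adjust it to however your data were actually generated; since it has no parameters block, it would be run with algorithm=fixed_param.)

data {
  int<lower = 1> NT;                                            // Number of trials
  int<lower = 1> max_NR;                                        // Max number of replays across all trials
  array[NT] int<lower = 1, upper = max_NR> NR;                  // Number of replays in each trial
  int<lower = 1> N_ACTIONS;                                     // Number of arms
  array[NT, max_NR] int<lower = 1, upper = N_ACTIONS> rp_arms;  // Replayed arms (taken as given)
  array[NT, max_NR] int rp_rwd;                                 // Rewards assigned to replayed arms (taken as given)
  real<lower = -0.5, upper = 0.5> alphaR;                       // "True" replay learning rate used to simulate
  vector<lower = 0, upper = 1>[N_ACTIONS] p_rwd;                // ASSUMED per-arm reward probabilities
}
generated quantities {
  array[NT] int c;  // Simulated choices
  array[NT] int r;  // Simulated rewards
  {
    vector[N_ACTIONS] Q = rep_vector(0, N_ACTIONS);
    for (t in 1:NT) {
      c[t] = categorical_logit_rng(3 * Q);  // Same fixed beta = 3 as in the model
      for (rp_i in 1:NR[t]) {
        Q[rp_arms[t, rp_i]] += alphaR * (rp_rwd[t, rp_i] - Q[rp_arms[t, rp_i]]);
      }
      r[t] = bernoulli_rng(p_rwd[c[t]]);   // ASSUMED Bernoulli reward for the chosen arm
      Q[c[t]] += 0.5 * (r[t] - Q[c[t]]);   // Same fixed alphaD = 0.5 as in the model
    }
  }
}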
Second, I would suggest you start by simplifying the model code as much as possible to make it clearer where any issues may be arising. For example, it looks like the data are…
- replays…
- nested in trials…
- nested in sessions…
- nested in subjects.
It looks like you are already analyzing only the first subject by hard-coding 1 in several places (e.g. c[1,ss,t] ~ categorical_logit(3 * Q);). Could you start with a model that just looks at one subject and one session, iterating over trials and replays? I’ve tried this below, but double-check my work!
Third, the problem might be clearer if you’re able to see how Q
changes across replays and trials. I’ve tried to do that in the code below (again, check my work).
Good luck!
data {
  int<lower = 1> NT;                                            // Number of trials
  int<lower = 1> max_NR;                                        // Max number of replays across all trials
  array[NT] int<lower = 1, upper = max_NR> NR;                  // Number of replays in each trial
  int<lower = 1> N_ACTIONS;                                     // Number of arms
  array[NT, max_NR] int<lower = 1, upper = N_ACTIONS> rp_arms;  // Replayed arms {1, ..., N_ACTIONS}
  array[NT, max_NR] int rp_rwd;                                 // Rewards assigned to replayed arms
  array[NT] int<lower = 1, upper = N_ACTIONS> c;                // Arm choices {1, ..., N_ACTIONS}
  array[NT] int<lower = 0, upper = 1> r;                        // Reward {0, 1}
}
transformed data {
  int NR_total = 0;  // Total number of replays
  int NQ;            // Number of recorded sets of Q (initial + one per replay + one per trial)
  for (t in 1:NT) {
    NR_total += NR[t];
  }
  NQ = NR_total + 1 + NT;
}
parameters {
  real alphaRm;  // Replay learning rate, unconstrained scale
}
model {
  alphaRm ~ normal(0, 1);
  vector[N_ACTIONS] Q = rep_vector(0, N_ACTIONS);  // Initialize Q-values for this subject with zero
  real alphaR = Phi_approx(alphaRm) - 0.5;         // Replay learning rate in (-0.5, 0.5)
  for (t in 1:NT) {  // Loop over trials
    // Choice (softmax)
    c[t] ~ categorical_logit(3 * Q);  // fixed beta = 3 TO MAKE IT EASIER
    // Replay updates
    for (rp_i in 1:NR[t]) {
      Q[rp_arms[t, rp_i]] += alphaR * (rp_rwd[t, rp_i] - Q[rp_arms[t, rp_i]]);
    }
    // Q-learning update for the chosen arm
    Q[c[t]] += 0.5 * (r[t] - Q[c[t]]);  // fixed alphaD = 0.5 TO MAKE IT EASIER
  }
}
generated quantities {
  real alphaRm_phied = Phi_approx(alphaRm) - 0.5;  // Replay learning rate, constrained scale
  array[NQ] vector[N_ACTIONS] Q_set;               // Q-values after every update
  array[NQ] vector[N_ACTIONS] P_set;               // Probability of selecting each arm
  {
    vector[N_ACTIONS] Q = rep_vector(0, N_ACTIONS);
    real alphaR = Phi_approx(alphaRm) - 0.5;
    int count = 1;
    Q_set[1] = Q;  // Initial Q-values, before any trials
    for (t in 1:NT) {  // Loop over trials, mirroring the model block
      for (rp_i in 1:NR[t]) {
        Q[rp_arms[t, rp_i]] += alphaR * (rp_rwd[t, rp_i] - Q[rp_arms[t, rp_i]]);
        count += 1;
        Q_set[count] = Q;  // Q-values after this replay
      }
      // Q-learning update for the chosen arm
      Q[c[t]] += 0.5 * (r[t] - Q[c[t]]);  // fixed alphaD = 0.5 TO MAKE IT EASIER
      count += 1;
      Q_set[count] = Q;  // Q-values after this trial's choice update
    }
  }
  for (n in 1:NQ) {
    P_set[n] = softmax(3 * Q_set[n]);  // Same fixed beta = 3 as in the model
  }
}