Hello,
When I run a fairly simple RL model (albeit with a lot of data, ~18,000 trials), gradient evaluation takes about 1.5 seconds and the actual sampling takes many hours. I'm wondering whether the problem is the large data set or my implementation. Does anyone have recommendations for how to make it more efficient?
Note: this is the case whether I choose the lognormal or the gamma prior for the forgetting parameter.
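In case it helps narrow down where the gradient time goes, this is roughly how I was planning to profile it, assuming Stan/CmdStan 2.26 or later where profile() blocks are available (the labels are just placeholders; the loop body is the per-trial code from the model further down):

  profile("trial_loop") {
    for (t in 1:sum(Trial)) {
      // ... the per-trial value-update code from the model block below ...
    }
  }
  profile("likelihood") {
    current_choice ~ bernoulli_logit(deltaV);
  }

Here is the full model: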
data {
  int Nsubjs;
  int Trial[Nsubjs];
  int outcome[sum(Trial)];
  int choice[sum(Trial)];
  int stim1[sum(Trial)];
  int stim2[sum(Trial)];
  vector[sum(Trial)] feed;
  int nftrials[Nsubjs];
  int ind[Nsubjs];
  int<lower=0,upper=1> use_lognormal;
  int subject[sum(Trial)];
  vector[sum(Trial)] times;
  int<lower=0,upper=1> current_choice[sum(Trial)];
  int Nstim;
  int<lower=0,upper=1> verbose;
  vector[sum(Trial)] correct;
  matrix[Nstim, Nsubjs] reward;
}
transformed data {
  vector[sum(Trial)] time_passed;
  for (t in 1:sum(Trial)) {
    int n = subject[t];
    if (t > 1 && subject[t] == subject[t-1]) {
      // elapsed time since the previous trial, converted from ms to days
      time_passed[t] = (times[t] - times[t-1]) / (24 * 3600 * 1000);
    } else {
      time_passed[t] = 0;
    }
  }
}
parameters {
  // hyperparameters for alpha
  real<lower=0,upper=1> w;
  real<lower=0> kap;
  real<lower=0,upper=1> w_default;
  real<lower=0> kap_default;
  // per-subject parameters
  real<lower=0,upper=1> alpha_fast[Nsubjs];
  real<upper=0> alpha_rescale;
  real<lower=0,upper=1> default_parameter[Nsubjs];
  // gamma hyperprior for lambda
  real<lower=0> lambdafast_shape[use_lognormal ? 0 : 1];
  real<lower=0> lambdafast_rate[use_lognormal ? 0 : 1];
  // lognormal hyperprior for lambda
  real mu_lambda_fast[use_lognormal ? 1 : 0];
  real<lower=0> sigma_lambda_fast[use_lognormal ? 1 : 0];
  real<lower=0> lambda_fast[Nsubjs];
  real lambda_rescale;
  // beta slow
  real<lower=0> shape_beta_slow[use_lognormal ? 0 : 1];
  real<lower=0> rate_beta_slow[use_lognormal ? 0 : 1];
  real mu_beta_slow[use_lognormal ? 1 : 0];
  real<lower=0> sigma_beta_slow[use_lognormal ? 1 : 0];
  vector<lower=0>[Nsubjs] beta_slow;
  // beta fast
  real mu_beta_fast[use_lognormal ? 1 : 0];
  real<lower=0> sigma_beta_fast[use_lognormal ? 1 : 0];
  real<lower=0> shape_beta_fast[use_lognormal ? 0 : 1];
  real<lower=0> rate_beta_fast[use_lognormal ? 0 : 1];
  vector<lower=0>[Nsubjs] beta_fast;
}
transformed parameters {
  vector<lower=0,upper=1>[Nsubjs] alpha_slow;
  vector[Nsubjs] lambda_slow;
  for (n in 1:Nsubjs) {
    alpha_slow[n] = alpha_fast[n] * exp(alpha_rescale);
    lambda_slow[n] = lambda_fast[n] * lambda_rescale;
  }
}
model {
  vector[2] current_fast = rep_vector(0, 2);
  vector[2] current_slow = rep_vector(0, 2);
  real a;
  real b;
  real k;
  real a_default;
  real b_default;
  real k_default;
  real PE_slow = 0;
  real PE_fast = 0;
  vector[sum(Trial)] deltaV = rep_vector(0, sum(Trial));
  vector[Nstim] ev_fast = rep_vector(0, Nstim);
  vector[Nstim] ev_slow = rep_vector(0, Nstim);
  // alpha hyperpriors
  target += beta_lpdf(w | 1, 1);
  target += gamma_lpdf(kap | 1, 1);
  k = kap + 2;
  a = w * (k - 2) + 1;
  b = (1 - w) * (k - 2) + 1;
  target += beta_lpdf(w_default | 1, 1);
  target += gamma_lpdf(kap_default | 1, 1);
  k_default = kap_default + 2;
  a_default = w_default * (k_default - 2) + 1;
  b_default = (1 - w_default) * (k_default - 2) + 1;
  // rescaling
  target += normal_lpdf(alpha_rescale | 0, 5);
  target += normal_lpdf(lambda_rescale | 0, 10);
  if (!use_lognormal) {
    target += gamma_lpdf(lambdafast_shape | 1, 1);
    target += gamma_lpdf(lambdafast_rate | 1, 1);
    target += gamma_lpdf(shape_beta_fast | 1, 1);
    target += gamma_lpdf(rate_beta_fast | 1, 1);
    target += gamma_lpdf(shape_beta_slow | 1, 1);
    target += gamma_lpdf(rate_beta_slow | 1, 1);
    for (i in 1:Nsubjs) {
      target += beta_lpdf(alpha_fast[i] | a, b);
      target += beta_lpdf(default_parameter[i] | a_default, b_default);
      target += gamma_lpdf(lambda_fast[i] | lambdafast_shape, lambdafast_rate);
      target += gamma_lpdf(beta_fast[i] | shape_beta_fast, rate_beta_fast);
      target += gamma_lpdf(beta_slow[i] | shape_beta_slow, rate_beta_slow);
    }
  } else {
    target += normal_lpdf(mu_lambda_fast | 0, 3);
    target += cauchy_lpdf(sigma_lambda_fast | 0, 2.5);
    target += normal_lpdf(mu_beta_slow | 0, 3);
    target += cauchy_lpdf(sigma_beta_slow | 0, 2.5);
    target += normal_lpdf(mu_beta_fast | 0, 3);
    target += cauchy_lpdf(sigma_beta_fast | 0, 2.5);
    for (i in 1:Nsubjs) {
      target += lognormal_lpdf(lambda_fast[i] | mu_lambda_fast, sigma_lambda_fast);
      target += lognormal_lpdf(beta_fast[i] | mu_beta_fast, sigma_beta_fast);
      target += lognormal_lpdf(beta_slow[i] | mu_beta_slow, sigma_beta_slow);
      target += beta_lpdf(alpha_fast[i] | a, b);
      target += beta_lpdf(default_parameter[i] | a_default, b_default);
    }
  }
  for (t in 1:sum(Trial)) {
    int n = subject[t];
    if (t > 1 && subject[t] == subject[t-1]) {
      // expected values decay exponentially toward their baselines with elapsed time
      ev_fast = (ev_fast - .5) * exp(-lambda_fast[n] * time_passed[t]) + .5;
      ev_slow = (ev_slow - default_parameter[n]) * exp(-lambda_slow[n] * time_passed[t]) + default_parameter[n];
    } else {
      // first trial for a subject: reset expected values
      ev_fast = rep_vector(.5, Nstim);
      ev_slow = rep_vector(.5, Nstim);
    }
    current_slow[1] = ev_slow[stim1[t] + 1];
    current_slow[2] = ev_slow[stim2[t] + 1];
    current_fast[1] = ev_fast[stim1[t] + 1];
    current_fast[2] = ev_fast[stim2[t] + 1];
    deltaV[t] = beta_fast[n] * (current_fast[2] - current_fast[1]) + beta_slow[n] * (current_slow[2] - current_slow[1]);
    if (feed[t] == 1) {
      // prediction-error updates on feedback trials
      PE_fast = outcome[t] - ev_fast[choice[t] + 1];
      PE_slow = outcome[t] - ev_slow[choice[t] + 1];
      ev_slow[choice[t] + 1] += alpha_slow[n] * PE_slow;
      ev_fast[choice[t] + 1] += alpha_fast[n] * PE_fast;
    }
  }
  current_choice ~ bernoulli_logit(deltaV);
}
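One thing I was wondering about is whether the per-subject prior loops could simply be vectorized, e.g. for the gamma branch something like the following (keeping my variable names and indexing the size-1 hyperparameter arrays; I haven't checked whether it actually changes the timing):

  target += beta_lpdf(alpha_fast | a, b);
  target += beta_lpdf(default_parameter | a_default, b_default);
  target += gamma_lpdf(lambda_fast | lambdafast_shape[1], lambdafast_rate[1]);
  target += gamma_lpdf(beta_fast | shape_beta_fast[1], rate_beta_fast[1]);
  target += gamma_lpdf(beta_slow | shape_beta_slow[1], rate_beta_slow[1]);

The main per-trial loop seems harder to deal with, since ev_fast and ev_slow are updated sequentially, so my guess is that's where most of the 1.5 s per gradient evaluation goes. Any suggestions appreciated.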