RL model - extremely slow gradient evaluation

Hello,

When I run a fairly simple RL model (albeit with a lot of data, ~18,000 trials), gradient evaluation takes about 1.5 seconds and the actual sampling takes many hours. I’m wondering whether the problem is the large data set or my implementation. Does anyone have any recommendations for how to make it more efficient?

Note: this is the case whether I choose the lognormal or the gamma prior for the forgetting parameter.

data {
  int Nsubjs;                                 // number of subjects
  int Trial[Nsubjs];                          // number of trials per subject
  int outcome[sum(Trial)];                    // feedback outcome on each trial (used for prediction errors)
  int choice[sum(Trial)];                     // chosen stimulus (0-based index)
  int stim1[sum(Trial)];                      // first presented stimulus (0-based index)
  int stim2[sum(Trial)];                      // second presented stimulus (0-based index)
  vector[sum(Trial)] feed;                    // 1 if feedback was given on the trial
  int nftrials[Nsubjs];                       // (not used below)
  int ind[Nsubjs];                            // (not used below)
  int<lower=0,upper=1> use_lognormal;         // 1 = lognormal hyperpriors, 0 = gamma hyperpriors
  int subject[sum(Trial)];                    // subject index for each trial
  vector[sum(Trial)] times;                   // trial timestamps (converted to days below, assuming ms)
  int<lower=0,upper=1> current_choice[sum(Trial)];  // 1 if stim2 was chosen
  int Nstim;                                  // number of stimuli
  int<lower=0,upper=1> verbose;               // (not used below)
  vector[sum(Trial)] correct;                 // (not used below)
  matrix[Nstim, Nsubjs] reward;               // (not used below)
}

transformed data {
  // time elapsed since the previous trial, in days;
  // reset to 0 on the first trial of each subject
  vector[sum(Trial)] time_passed;
  for (t in 1:sum(Trial)) {
    if (t > 1 && subject[t] == subject[t-1]) {
      time_passed[t] = (times[t] - times[t-1]) / (24 * 3600 * 1000);
    } else {
      time_passed[t] = 0;
    }
  }
}

parameters {
  // hyperparameters for alpha
  
  real<lower=0,upper=1>  w; 
  real<lower=0> kap; 
  
  real<lower=0,upper=1>  w_default; 
  real<lower=0> kap_default; 
  
  //per subject parameter 
  real<lower=0,upper=1> alpha_fast [Nsubjs];
  real<upper=0> alpha_rescale;
  
  real<lower=0,upper=1> default_parameter [Nsubjs];
  
  // gamma for lambda
  real<lower=0> lambdafast_shape[use_lognormal ? 0:1];
  real<lower=0> lambdafast_rate [use_lognormal ? 0:1];
  // for lognormal
  real mu_lambda_fast [use_lognormal ? 1:0];
  real <lower=0> sigma_lambda_fast[use_lognormal ? 1:0];
  
  real <lower=0> lambda_fast [Nsubjs];
  real lambda_rescale; 
  
  // beta slow
  real<lower=0> shape_beta_slow[use_lognormal ? 0:1];
  real<lower=0>   rate_beta_slow[use_lognormal ? 0:1];
  real mu_beta_slow[use_lognormal ? 1:0]; 
  real<lower=0> sigma_beta_slow[use_lognormal ? 1:0];
  
  vector<lower=0>[Nsubjs] beta_slow;
  //beta fast
  real mu_beta_fast[use_lognormal ? 1:0];
  real<lower=0> sigma_beta_fast[use_lognormal ? 1:0]; 
  real<lower=0> shape_beta_fast[use_lognormal ? 0:1];
  real<lower=0> rate_beta_fast [use_lognormal ? 0:1];
  
  vector<lower=0>[Nsubjs] beta_fast;
  
  
  
}
transformed parameters{
  // slow-system parameters are rescaled versions of the fast ones;
  // alpha_rescale <= 0, so exp(alpha_rescale) <= 1 keeps alpha_slow in [0,1]
  vector<lower=0,upper=1>[Nsubjs] alpha_slow;
  vector[Nsubjs] lambda_slow;
  for (n in 1:Nsubjs) {
    alpha_slow[n] = alpha_fast[n] * exp(alpha_rescale);
    lambda_slow[n] = lambda_fast[n] * lambda_rescale;
  }
}

model {
  vector[2] current_fast =rep_vector(0,2);
  vector[2] current_slow =rep_vector(0,2);
  real a; 
  real b; 
  real k; 
  real a_default; 
  real b_default; 
  real k_default;
  real PE_slow=0;
  real PE_fast=0; 
  vector [sum(Trial)] deltaV=rep_vector(0,sum(Trial)); 
  vector[Nstim] ev_fast=rep_vector(0,Nstim);  
  vector[Nstim] ev_slow=rep_vector(0,Nstim);

    //alpha
    target +=  beta_lpdf(w|1,1);
    target +=  gamma_lpdf(kap|1,1);
    k= kap+2;
    a=w*(k-2)+1;
    b=(1-w)*(k-2)+1;
    
    target +=  beta_lpdf(w_default|1,1);
    target +=  gamma_lpdf(kap_default|1,1);
    k_default= kap_default+2;
    a_default=w_default*(k_default-2)+1;
    b_default=(1-w_default)*(k_default-2)+1;
    
    
    //rescaling
    target += normal_lpdf(alpha_rescale |0,5);
    target += normal_lpdf(lambda_rescale |0,10);
    
    
    if(!use_lognormal){
      target += gamma_lpdf(lambdafast_shape|1,1);
      target += gamma_lpdf(lambdafast_rate |1,1);
      
      target += gamma_lpdf(shape_beta_fast|1,1);
      target += gamma_lpdf(rate_beta_fast |1,1);
      
      target += gamma_lpdf(shape_beta_slow|1,1);
      target += gamma_lpdf(rate_beta_slow |1,1);
      
      
      for(i in 1:Nsubjs){
        target += beta_lpdf(alpha_fast[i] | a,b);
        target += beta_lpdf(default_parameter[i]|a_default,b_default);
        target += gamma_lpdf(lambda_fast[i]|lambdafast_shape,lambdafast_rate);
        target += gamma_lpdf(beta_fast[i]|shape_beta_fast,rate_beta_fast);
        target += gamma_lpdf(beta_slow[i]|shape_beta_slow,rate_beta_slow);
      }
      
    }
    else{
      
      target += normal_lpdf(mu_lambda_fast|0,3);
      target += cauchy_lpdf(sigma_lambda_fast| 0,2.5);
      
      target += normal_lpdf(mu_beta_slow|0,3);
      target +=  cauchy_lpdf(sigma_beta_slow |0,2.5);
      
      target +=  normal_lpdf(mu_beta_fast |0,3);
      target +=  cauchy_lpdf(sigma_beta_fast| 0,2.5);
      
      for(i in 1:Nsubjs){
        target += lognormal_lpdf(lambda_fast[i] |mu_lambda_fast,sigma_lambda_fast);
        target += lognormal_lpdf(beta_fast[i]|mu_beta_fast,sigma_beta_fast);
        target +=  lognormal_lpdf(beta_slow[i] |mu_beta_slow,sigma_beta_slow);
        target += beta_lpdf(alpha_fast[i] | a,b);
        target += beta_lpdf(default_parameter[i]|a_default,b_default);
      }
    }
    
    
    
    for (t in 1:sum(Trial)){
      int n = subject[t];
      if(t>1 && subject[t]==subject[t-1])  {
        ev_fast = (ev_fast-.5)* exp(-lambda_fast[n]*time_passed[t])+.5;
        ev_slow = (ev_slow-default_parameter[n])* exp(-lambda_slow[n]*time_passed[t])+default_parameter[n];
      } else{
        ev_fast=rep_vector(.5,Nstim);
        ev_slow=rep_vector(.5,Nstim);
      }
      
      current_slow[1] = ev_slow[stim1[t]+1];
      current_slow[2] = ev_slow[stim2[t]+1];
      current_fast[1] = ev_fast[stim1[t]+1];
      current_fast[2] = ev_fast[stim2[t]+1];
      
      deltaV[t]  = beta_fast[n]*(current_fast[2] - current_fast[1]) + beta_slow[n]*(current_slow[2] - current_slow[1]);
      
      if(feed[t]==1){
        PE_fast = outcome[t] - ev_fast[choice[t]+1];
        PE_slow = outcome[t] - ev_slow[choice[t]+1];
        ev_slow[choice[t]+1] +=  alpha_slow[n] * PE_slow;
        ev_fast[choice[t]+1] +=  alpha_fast[n] * PE_fast;
      }
    }
    
    current_choice ~ bernoulli_logit(deltaV);
  
}

Sounds like a lot of data… if you have multiple cores available to use per chain and are willing to use cmdstan, then you could try reduce_sum here.
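
Since the trial loop is sequential within a subject but independent across subjects, the natural unit to slice over is the subject. Here's a rough, untested sketch of the idea in Stan, simplified to the fast system only and with all the priors left where they are now; subj_index (just 1:Nsubjs), t_start, and t_end (each subject's first and last row of the long per-trial arrays) are things you'd have to add in transformed data:

functions {
  // log likelihood for a slice of subjects; reduce_sum slices the first argument
  real partial_sum(int[] subj_slice, int start, int end,
                   int[] t_start, int[] t_end,
                   int[] current_choice, int[] choice, int[] outcome,
                   int[] stim1, int[] stim2, vector feed, vector time_passed,
                   int Nstim, vector beta_fast, real[] lambda_fast, real[] alpha_fast) {
    real lp = 0;
    for (s in 1:num_elements(subj_slice)) {
      int n = subj_slice[s];
      vector[Nstim] ev = rep_vector(0.5, Nstim);
      for (t in t_start[n]:t_end[n]) {
        real dV;
        if (t > t_start[n])
          ev = (ev - 0.5) * exp(-lambda_fast[n] * time_passed[t]) + 0.5;
        dV = beta_fast[n] * (ev[stim2[t] + 1] - ev[stim1[t] + 1]);
        lp += bernoulli_logit_lpmf(current_choice[t] | dV);
        if (feed[t] == 1)
          ev[choice[t] + 1] += alpha_fast[n] * (outcome[t] - ev[choice[t] + 1]);
      }
    }
    return lp;
  }
}

and then in the model block, in place of the big trial loop and the final bernoulli_logit statement:

  // grainsize = 1 lets the scheduler pick the chunk sizes
  target += reduce_sum(partial_sum, subj_index, 1,
                       t_start, t_end, current_choice, choice, outcome,
                       stim1, stim2, feed, time_passed, Nstim,
                       beta_fast, lambda_fast, alpha_fast);

The real model would also pass the slow-system parameters (beta_slow, lambda_slow, alpha_slow, default_parameter) as extra shared arguments. If I remember right, with cmdstanr you also need to compile with threading enabled (cpp_options = list(stan_threads = TRUE)) and set threads_per_chain when sampling.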


When you test the model on a subset of data, is it giving values for EV and PE that you would expect? Are you able to simulate reasonable behavior?

I’m not sure what’s going on with the ev_* indexing - I may just need to look at the code more closely, but in your model block it looks like, for each trial, you are 1) moving the whole ev_* vector toward a constant/parameter at a rate set by another parameter and the time passed, 2) using those just-decayed ev_* values to define the current values that get fit to the choice, and then 3) adding alpha * PE to the chosen stimulus’s ev_* value, which is only used from trial t+1 onward?

It’s a lot of data, but IME 1.5 seconds for a gradient evaluation is a lot for an RL model.


Hi Vanessa,

  1. In terms of the EVs, the PEs, and simulated behavior, it does all look more or less reasonable.

  2. Yes, on each trial the whole EV vector (105 stimuli) inches closer to a constant/parameter. The model is then fit to the choice, and, based on the outcome of that choice, the chosen stimulus’s EV is updated with the prediction error.
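
  Back of the envelope, that decay step alone touches about 18,000 trials × 105 stimuli × 2 value vectors ≈ 3.8 million EV entries per gradient evaluation, each of which autodiff has to track, so I would guess that is where most of the 1.5 s goes.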

And I agree with you, especially since when I took out half the parameters (one alpha instead of two, etc.) it still takes a crazy number of hours to run. Furthermore, in the past the same model generally ran fairly quickly.

Apart from the model issues: is it possible that models will run slower depending on 1) the rstan version I’m using, or 2) a power usage issue? I constantly get “very high” power usage from RStudio.

Hmm, so you have run the same code & data on earlier versions of rstan and it did not take this long? I don’t have a ton of expertise in Stan run time issues - maybe @mike-lawrence or others could help? If it had always taken this long (or if this were a new model), I would suspect an issue with your RL code, but if I understand correctly that doesn’t seem to be the case.


Hm, I definitely don’t have any expertise in this realm specifically, so while I’ve looked at the code, it’s not immediately obvious to me what’s going on. So I’ll simply interject my standard recommendation to look for redundancy in the data input to the model (esp. design matrices), which usually means that there are redundant computations being done that can be avoided by identifying unique entries and indexing as shown here. Since it’s a binomial outcome, you should also look at whether there are any chunks of data for which you can predict that deltaV is constant (you don’t need to know what the constant value actually will be, just that it will be the same value for all elements in the chunk); in that case you can speed things up substantially by using sufficient statistics.
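
For example, any trial where neither presented stimulus has yet received feedback has the same deltaV (zero, in fact, since both fast and both slow values are still equal), so all such trials for a subject could be pooled. More generally, if you can identify chunks of trials that share a deltaV and build hypothetical chunk-level data on the R side (n_chunk, chunk_size, chunk_successes, plus whatever is needed to compute deltaV_chunk once per chunk), the likelihood becomes one binomial term per chunk, something like:

  // hypothetical sufficient-statistics likelihood: chunk k contains
  // chunk_size[k] trials, chunk_successes[k] of which had current_choice == 1,
  // and deltaV_chunk[k] is computed once from that chunk's (shared) EVs
  for (k in 1:n_chunk)
    target += binomial_logit_lpmf(chunk_successes[k] | chunk_size[k], deltaV_chunk[k]);

Whether that buys you much depends on how often deltaV actually stays constant across consecutive trials, which only you can tell from the design.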
