Asking for advice on making RL code more efficient

Hi Everyone,

I am new to Stan and trying to bring the magic to my team. We wrote a simple model to estimate two parameters for a Q-learner. Here, a player is asked to choose among N cards (or arms), each leading to a reward with some drifting probability. The model has two parameters (a learning rate and a softmax temperature). While the parameters recover well (>.9 correlation between recovered and true parameters), we need to fit this to big data, and it seems to be painfully slow (up to 15 hours or more).
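
For concreteness, the model is the standard delta-rule update with a softmax choice rule (alpha is the learning rate and beta the softmax inverse temperature, matching the code below):

$$Q_{t+1}(a_t) = Q_t(a_t) + \alpha \,\big(r_t - Q_t(a_t)\big), \qquad P(a_t = a) = \frac{\exp\big(\beta \, Q_t(a)\big)}{\sum_{a'} \exp\big(\beta \, Q_t(a')\big)}$$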

Would you have any advice on how to write better-optimized code? Is there anything I can code differently that would help this run faster? Any advice would be really fantastic. We got this far on our own, but I can’t really find a way to make it more efficient.

Thanks a million,
Nitzan


  
data {
  //General fixed parameters for the experiment/models
  int<lower = 1> Nsubjects;           //number of subjects
  int<lower = 1> Ntrials;             //number of trials without missing data, typically the maximum of trials per subject
  int<lower = 1> Ntrials_per_subject[Nsubjects];  //number of trials left for each subject after data omission

  //Behavioral data
  int<lower = 0> action[Nsubjects,Ntrials];        //index of which arm was pulled coded 1 to 4
  int<lower = 0> reward[Nsubjects,Ntrials];            //outcome of bandit arm pull

}
transformed data{
    int<lower = 1> Nparameters; //number of parameters
    int<lower = 2> Narms;       //number of overall alternatives

    Nparameters=2;
    Narms      =4;
}

parameters {

  //population level parameters 
  vector[Nparameters] mu;                    //vector with the population level mean for each model parameter
  vector<lower=0>[Nparameters] tau;          //vector of random-effect standard deviations (scales) for each model parameter
  cholesky_factor_corr[Nparameters] L_Omega; //lower triangle of a correlation matrix to be used for the random effect of the model parameters
  
  //subject level parameters
  vector[Nparameters] auxiliary_parameters[Nsubjects]; 
}


transformed parameters {

      //population level
      matrix[Nparameters,Nparameters] sigma_matrix;
      
      //individuals level
      real alpha[Nsubjects];
      real beta[Nsubjects];

      //additional variables
      real  log_lik[Nsubjects,Ntrials];
      vector<lower=0, upper=1>[Narms] Qcard;

     //pre-assignment
      sigma_matrix = diag_pre_multiply(tau, (L_Omega*L_Omega')); //L_Omega*L_Omega' gives us Omega (the corr matrix)
      sigma_matrix = diag_post_multiply(sigma_matrix, tau);      //diag(tau)*Omega*diag(tau) gives us sigma_matrix (the cov matrix)
      
      log_lik=rep_array(0,Nsubjects,Ntrials);
      
      
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  for (subject in 1:Nsubjects){

        alpha[subject]              = inv_logit(auxiliary_parameters[subject][1]);
        beta[subject]               = exp(auxiliary_parameters[subject][2]);
        
        //pre-assignment of Qvalues
        for (a in 1:Narms)   Qcard[a]    = 0;

        //trial by trial loop        
        for (trial in 1:Ntrials_per_subject[subject]){
            
            //likelihood function (softmax)
            log_lik[subject,trial]=log_softmax(Qcard*beta[subject])[action[subject,trial]];

            //Qvalues update
            Qcard[action[subject,trial]] += alpha[subject] * (reward[subject,trial] - Qcard[action[subject,trial]]);
            
        } 
  }
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
}
model {
  
  // population level priors (hyper-parameters)
  mu  ~ normal(0, 5);             // mu is a vector of length Nparameters with the population mean (i.e., location) for each model parameter
  tau ~ cauchy(0, 1);             // tau is the vector of between-subject standard deviations
  L_Omega ~ lkj_corr_cholesky(2); // L_Omega is the Cholesky factor of the correlation matrix; an LKJ shape of 2 pulls the off-diagonal correlations towards zero

  // individual level priors (subject parameters)
  auxiliary_parameters ~ multi_normal(mu, sigma_matrix);

  target += sum(log_lik);

}


Hi,
when a model is slow, the first thing to check is whether you are seeing any problems with the model (bad fit to data, bad diagnostics, weird structures in the pairs plot, …), as models that behave badly are often very slow.

Your parameter-recovery correlation is a good first check, but it is often useful to also check coverage - e.g. do x% posterior intervals include the true values roughly x% of the time, for various x? (This is related to SBC - simulation-based calibration.)

If there are no problems with the model itself, there is a bunch of stuff you can try to gain speed-ups the “hard way”:

  • Since you seem to already have the Cholesky decomposition of sigma_matrix, multi_normal_cholesky is preferred (multi_normal will decompose the matrix internally)
  • You also seem to be building sigma_matrix in two separate assignments; with the Cholesky version you would not need the full covariance matrix at all
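For example (untested sketch; everything else stays as in the model above), the subject-level prior could become:
model {
  // diag_pre_multiply(tau, L_Omega) is the Cholesky factor of the covariance matrix,
  // so no explicit sigma_matrix and no internal decomposition is needed
  auxiliary_parameters ~ multi_normal_cholesky(mu, diag_pre_multiply(tau, L_Omega));
}
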
  • You could use a non-centered parametrization for auxiliary_parameters with something like (I didn’t test the code, but I hope the general idea is clear):
parameters {
  matrix[Nparameters, Nsubjects] z;   // standard-normal "raw" parameters
}

transformed parameters {
  // each column is one subject; downstream, index as auxiliary_parameters[1, subject] etc.
  matrix[Nparameters, Nsubjects] auxiliary_parameters
      = rep_matrix(mu, Nsubjects) + diag_pre_multiply(tau, L_Omega) * z;
}

model {
  to_vector(z) ~ std_normal();  // implies auxiliary_parameters ~ multi_normal(mu, sigma_matrix)
}

  • Since the calculation decomposes neatly by participant, you may benefit from using reduce_sum and using multiple cores per chain.
  • The log_lik matrix is likely quite big, so storing it may be expensive. Moving the calculation to the model block and incrementing target instead of storing it could help.
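A minimal sketch of that (untested; it reuses alpha, beta and the data from the model above and replaces target += sum(log_lik)):
model {
  // priors as before, then increment the target directly instead of storing log_lik
  for (subject in 1:Nsubjects) {
    vector[Narms] Qcard = rep_vector(0, Narms);
    for (trial in 1:Ntrials_per_subject[subject]) {
      target += log_softmax(Qcard * beta[subject])[action[subject, trial]];
      Qcard[action[subject, trial]] += alpha[subject] * (reward[subject, trial] - Qcard[action[subject, trial]]);
    }
  }
}
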
  • The cauchy(0,1) prior on tau is likely more heavy-tailed than you want - maybe there is domain knowledge that would let you restrict the a priori between-subject variation a bit more?
  • In log_softmax(Qcard*beta[subject])[action[subject,trial]] you compute the whole softmax vector but then use just one element of it. A minor speedup is likely possible with (please double check I didn’t mess up the simplification):
vector[Narms] Qscaled = Qcard*beta[subject];
log_lik[subject,trial] = Qscaled[action[subject,trial]] - log_sum_exp(Qscaled);

In all cases, it is useful to time your model a few times to see whether you are actually making progress. Also, since 2.26, Stan supports profiling, so you can see how costly individual sections of your code are. It makes no sense to optimize the parts that are already fast: see e.g. Profiling Stan programs with CmdStanR • cmdstanr
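
As a minimal illustration of the profiling syntax (the section names are arbitrary; in cmdstanr the per-section timings can then be read with fit$profiles()):
model {
  profile("hyperpriors") {
    mu ~ normal(0, 5);
    tau ~ cauchy(0, 1);
    L_Omega ~ lkj_corr_cholesky(2);
  }
  profile("random_effects") {
    auxiliary_parameters ~ multi_normal(mu, sigma_matrix);
  }
  profile("likelihood") {
    target += sum(log_lik);
  }
}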

If you believe the posterior is well approximated by a multivariate normal, you may also try the built-in ADVI, but you should first check that the results don’t change much between ADVI and MCMC on datasets where MCMC is feasible.

Hope at least some of this will help.


Martin,

This is extremely helpful - thank you for taking the time.
One follow up please:

Since the calculation decomposes neatly by participant, you may benefit from using reduce_sum and using multiple cores per chain.

How would you go about implementing something like this? If I understand correctly, reduce_sum will automatically partition the data (which might be problematic here, since every subject needs to be estimated across all of their trials). Do you think it would make sense to try reduce_sum_static, or did you have a different way in mind?

I promise to update, and I can already say that moving the log_lik calculation to the model block made the process about twice as fast.

Thanks,
Nitzan.

I am on a cellphone, so I cannot Google very well, but there was a very nice tutorial for reduce_sum, I think by Sebastian Weber (@wds15), that should cover the basics. The main point is to find a way to split your data and parameters into blocks that don’t have big overlaps, to avoid the need to transfer the same data/parameters to multiple threads. In your case, it appears that the log likelihood per participant can be computed independently of the other participants and should thus fit the reduce_sum paradigm naturally.
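
From memory, the general shape would be something like the following (completely untested, so please treat it as a sketch only; partial_log_lik and subject_index are just illustrative names, and you also need to build the model with threading enabled and set threads_per_chain when sampling):
functions {
  // summed log likelihood for the subjects with indices start to end
  real partial_log_lik(int[] subject_slice, int start, int end,
                       int[,] action, int[,] reward, int[] Ntrials_per_subject,
                       real[] alpha, real[] beta, int Narms) {
    real lp = 0;
    for (subject in start:end) {
      vector[Narms] Qcard = rep_vector(0, Narms);
      for (trial in 1:Ntrials_per_subject[subject]) {
        lp += log_softmax(Qcard * beta[subject])[action[subject, trial]];
        Qcard[action[subject, trial]] += alpha[subject] * (reward[subject, trial] - Qcard[action[subject, trial]]);
      }
    }
    return lp;
  }
}
transformed data {
  // added to the existing transformed data block: a dummy array for reduce_sum to slice over
  int subject_index[Nsubjects];
  for (subject in 1:Nsubjects) subject_index[subject] = subject;
}
model {
  // priors as before; grainsize = 1 lets the scheduler pick the chunk sizes
  int grainsize = 1;
  target += reduce_sum(partial_log_lik, subject_index, grainsize,
                       action, reward, Ntrials_per_subject, alpha, beta, Narms);
}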

If you are stuck finding the instructions or following them, feel free to ask for clarifications.


To follow up, in case other users want to work with RL Q-learning code in Stan - hBayesDM has really fantastic implementations, and I have personally learned a lot from reading through their RL code.
