Hi Everyone,

I am new to Stan and trying to bring the magic to my team. We wrote simple code to estimate two parameters for a Q-learner. Here, a player is asked to choose among N cards (or arms), each leading to reward with some drifting probability. The model has two parameters (a learning rate and a softmax temperature). While the model and parameters recover well (>.9 correlation between recovered and true parameters), we need to fit this to big data, and it seems to be painfully slow (up to 15 hours or more).

Would you have any advice on how to write better-optimized code? Is there anything I can code differently that would help this run faster? Any advice would be really fantastic. We got this far on our own, but I can’t find a way to make this more efficient.

Thanks a million,
Nitzan

``````
data {
//General fixed parameters for the experiment/models
int<lower = 1> Nsubjects;           //number of subjects
int<lower = 1> Ntrials;             //number of trials without missing data, typically the maximum of trials per subject
int<lower = 1> Ntrials_per_subject[Nsubjects];  //number of trials left for each subject after data omission

//Behavioral data
int<lower = 0> action[Nsubjects,Ntrials];        //index of which arm was pulled coded 1 to 4
int<lower = 0> reward[Nsubjects,Ntrials];            //outcome of bandit arm pull

}
transformed data{
int<lower = 1> Nparameters; //number of parameters
int<lower = 2> Narms;       //number of overall alternatives

Nparameters=2;
Narms      =4;
}

parameters {

//population level parameters
vector[Nparameters] mu;                    //vector with the population level mean for each model parameter
vector<lower=0>[Nparameters] tau;          //vector of random effects variance for each model parameter
cholesky_factor_corr[Nparameters] L_Omega; //lower triangle of a correlation matrix to be used for the random effect of the model parameters

//subject level parameters
vector[Nparameters] auxiliary_parameters[Nsubjects];
}

transformed parameters {

//population level
matrix[Nparameters,Nparameters] sigma_matrix;

//individuals level
real alpha[Nsubjects];
real beta[Nsubjects];

real  log_lik[Nsubjects,Ntrials];
vector<lower=0, upper=1>[Narms] Qcard;

//pre-assignment
sigma_matrix = diag_pre_multiply(tau, (L_Omega*L_Omega')); //L_Omega*L_Omega' gives us Omega (the correlation matrix)
sigma_matrix = diag_post_multiply(sigma_matrix, tau);      //diag(tau)*Omega*diag(tau) gives us sigma_matrix (the covariance matrix)

log_lik = rep_array(0.0, Nsubjects, Ntrials);

///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
for (subject in 1:Nsubjects){

alpha[subject]              = inv_logit(auxiliary_parameters[subject][1]);
beta[subject]               = exp(auxiliary_parameters[subject][2]);

//pre-assignment of Qvalues
for (a in 1:Narms)   Qcard[a]    = 0;

//trial by trial loop
for (trial in 1:Ntrials_per_subject[subject]){

//likelihood function (softmax)
log_lik[subject,trial]=log_softmax(Qcard*beta[subject])[action[subject,trial]];

//Qvalues update
Qcard[action[subject,trial]] += alpha[subject] * (reward[subject,trial] - Qcard[action[subject,trial]]);

}
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
}
model {

// population level priors (hyper-parameters)
mu  ~ normal(0, 5);             // mu is a vector 1xNparameters with the population mean (i.e., location) for each model parameter
tau ~ cauchy(0, 1);             //tau is the hyperparameters variance vector
L_Omega ~ lkj_corr_cholesky(2); //L_Omega is the Cholesky factor of the correlation matrix. Setting the lkj shape to 2 favors off-diagonals near zero

// indvidual level priors (subject parameters)
auxiliary_parameters ~ multi_normal(mu, sigma_matrix);

target += sum(log_lik);

}

``````

Hi,
when a model is slow, the first thing to check is whether you are seeing any problems with the model (bad fit to data, bad diagnostics, weird structures in the `pairs` plot, …), as models that behave badly are often very slow.

That’s a good first check, but it is often useful to also check coverage - e.g., do x% posterior intervals include the true values roughly x% of the time, for various x? (This is related to SBC - simulation-based calibration.)

If there are no problems with the model itself, there is a bunch of stuff you can try to gain speed-ups the “hard way”:

• since you already seem to have the Cholesky decomposition of `sigma_matrix`, using `multi_normal_cholesky` is preferred (`multi_normal` will decompose the matrix internally)
• You seem to be calculating `sigma_matrix` twice in the code
• You could use non-centered parametrization for `auxiliary_parameters` with something like (I didn’t test the code, but I hope the general idea is clear):
``````parameters {
matrix[Nparameters, Nsubjects] z;
}

transformed parameters {
vector[Nparameters] auxiliary_parameters[Nsubjects];
for (s in 1:Nsubjects) {
  auxiliary_parameters[s] = mu + diag_pre_multiply(tau, L_Omega) * z[, s];
}
}

model {
to_vector(z) ~ std_normal(); // the multi_normal prior on auxiliary_parameters is then dropped
}

``````
• Since the calculation decomposes neatly by participant, you may benefit from using `reduce_sum` to use multiple cores per chain.
• The `log_lik` matrix is likely quite big, so storing it may be expensive. Moving the calculation to the `model` block and incrementing `target` instead of storing it could help.
• The `cauchy(0,1)` prior is likely more heavy-tailed than you want - maybe there is domain knowledge to restrict the a-priori between-subject variances a bit more?
• In `log_softmax(Qcard*beta[subject])[action[subject,trial]]` you compute the whole softmax vector but then use just one element of it. A minor speedup is likely possible with (please double check I didn’t mess up the simplification):
``````vector[Narms] Qscaled = Qcard*beta[subject];
log_lik[subject,trial] = Qscaled[action[subject,trial]] - log_sum_exp(Qscaled);
``````
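
To make the last two points concrete, here is a rough, untested sketch of a `model` block that combines them with the `multi_normal_cholesky` suggestion (it assumes `alpha` and `beta` remain defined in `transformed parameters`, while `log_lik` and `Qcard` are dropped from there):
``````model {
  mu ~ normal(0, 5);
  tau ~ cauchy(0, 1); // consider a lighter-tailed prior here
  L_Omega ~ lkj_corr_cholesky(2);

  // avoids building sigma_matrix and decomposing it internally
  auxiliary_parameters ~ multi_normal_cholesky(mu, diag_pre_multiply(tau, L_Omega));

  for (subject in 1:Nsubjects) {
    vector[Narms] Qcard = rep_vector(0, Narms);
    for (trial in 1:Ntrials_per_subject[subject]) {
      vector[Narms] Qscaled = Qcard * beta[subject];
      // increment target directly instead of storing a Nsubjects x Ntrials array
      target += Qscaled[action[subject, trial]] - log_sum_exp(Qscaled);
      Qcard[action[subject, trial]]
        += alpha[subject] * (reward[subject, trial] - Qcard[action[subject, trial]]);
    }
  }
}
``````
If you still want `log_lik` for model comparison (e.g., with `loo`), you can recompute it in `generated quantities` instead, where it is evaluated only once per draw rather than at every leapfrog step.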

In all cases, it is useful to time your model a few times to see if you are actually making progress. Also, since version 2.26, Stan supports profiling, so you can see how costly individual sections of your code are. It definitely makes no sense to optimize the parts that are already fast: see e.g. the “Profiling Stan programs with CmdStanR” vignette for cmdstanr.
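
For example, wrapping sections of the `model` block in `profile()` statements (the section names here are arbitrary) lets you compare their costs in the profiling output:
``````model {
  profile("priors") {
    mu ~ normal(0, 5);
    tau ~ cauchy(0, 1);
    L_Omega ~ lkj_corr_cholesky(2);
    auxiliary_parameters ~ multi_normal(mu, sigma_matrix);
  }
  profile("likelihood") {
    target += sum(log_lik);
  }
}
``````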

If you believe the posterior is well approximated by a multivariate normal, you may also try the built-in ADVI, but you should first check that the results don’t change much between ADVI and MCMC for datasets where MCMC is feasible.

Hope at least some of this will help.


Martin,

This is extremely helpful - thank you for taking the time.

> Since the calculation decomposes neatly by participant, you may benefit from using `reduce_sum` and use multiple cores per chain.

How would you go about implementing something like this? If I understand correctly, `reduce_sum` will automatically partition the data (which might be problematic here, since every subject needs to be estimated across all trials). Do you think it would make sense to try `reduce_sum_static`, or did you have a different way in mind?

I promise to post an update, and I can already say that moving the `log_lik` computation made the process about twice as fast.

Thanks,
Nitzan.

I am on a cellphone, so I cannot Google very well, but there was a very nice tutorial for `reduce_sum`, I think by Sebastian Weber (@wds15), that should cover the basics. The main point is to split your data and parameters into blocks that don’t have big overlaps, to avoid transferring the same data/parameters to multiple threads. In your case, the log likelihood per participant can be computed independently of other participants, so it should fit the `reduce_sum` paradigm naturally.
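
A rough sketch of the idea (untested; `partial_sum`, `subject_index`, and the grainsize of 1 are illustrative choices, not tested code): slice over an array of subject indices, so each subject’s full trial sequence always stays within one block regardless of how `reduce_sum` partitions the work - this also answers the concern about automatic partitioning:
``````functions {
  // illustrative partial-sum function: computes the log likelihood for a slice of subjects
  real partial_sum(int[] subject_slice, int start, int end,
                   int[,] action, int[,] reward,
                   int[] Ntrials_per_subject,
                   real[] alpha, real[] beta, int Narms) {
    real lp = 0;
    for (i in 1:size(subject_slice)) {
      int s = subject_slice[i];
      vector[Narms] Qcard = rep_vector(0, Narms);
      for (trial in 1:Ntrials_per_subject[s]) {
        vector[Narms] Qscaled = Qcard * beta[s];
        lp += Qscaled[action[s, trial]] - log_sum_exp(Qscaled);
        Qcard[action[s, trial]] += alpha[s] * (reward[s, trial] - Qcard[action[s, trial]]);
      }
    }
    return lp;
  }
}
transformed data {
  int subject_index[Nsubjects];
  for (s in 1:Nsubjects) subject_index[s] = s;
}
model {
  // grainsize 1 lets Stan choose the partitioning automatically
  target += reduce_sum(partial_sum, subject_index, 1,
                       action, reward, Ntrials_per_subject,
                       alpha, beta, Narms);
}
``````
With cmdstanr you would then compile with `cpp_options = list(stan_threads = TRUE)` and pass `threads_per_chain` to `$sample()`.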

If you are stuck finding the instructions or following them, feel free to ask for clarifications.


To follow up, in case other users want to work with RL Q-learning code in Stan: hBayesDM has really fantastic implementations, and I have personally learned a lot from reading their RL code.
