Model takes many hours to fit and chains don't converge

Hi! I am new to Stan and I'm trying to fit a Q-learning model for a group of 29 subjects, each of whom completed 5 sessions of 60 trials. On each trial, subjects were presented with one of three states (st = {1,2,3}), each offering two choices (c = {0,1}); upon making a choice, they either received a reward or not (r = {0,1}). I am able to fit the simplest version of this model, a standard Q-learning model with one learning rate per subject (Stan code not shown), quickly and without difficulty. However, I run into trouble when I try to fit a slightly more complex model (Stan code below) with three learning rates (alphaD, alphaHG, alphaLG).

I am using macOS Catalina 10.15.3 and Python/PyStan. Here's the Stan code I'm using.

functions {
  // Steep logistic gate: approximately 1 when Q > 0 and approximately 0
  // when Q < 0. Written with inv_logit to avoid overflow in exp(10 * Q).
  real softsign(real Q) {
    return inv_logit(10 * Q);
  }
}
data {
  int NS;  // number of subjects
  int NSS; // number of sessions per subject
  int NT;  // number of trials per session
  int st[NS,NSS,NT]; // state, in {1,2,3}
  int c[NS,NSS,NT];  // choice, in {0,1}; -1 = no response
  int r[NS,NSS,NT];  // reward, in {0,1}; -1 = no response
}
parameters {
  // Group-level means
  real betam;
  real alphaDm;
  real alphaHGm;
  real alphaLGm;
  // Group-level standard deviations
  real<lower=0> betasd;
  real<lower=0> alphaDsd;
  real<lower=0> alphaHGsd;
  real<lower=0> alphaLGsd;
  // Per-subject parameters, on the unconstrained scale
  real betas[NS];
  real alphaDs[NS];
  real alphaHGs[NS];
  real alphaLGs[NS];
}
model {
  // Hyperpriors (the sd parameters get half-normals via their <lower=0> bounds)
  betam ~ normal(0,1);
  alphaDm ~ normal(0,1);
  alphaHGm ~ normal(0,1);
  alphaLGm ~ normal(0,1);
  betasd ~ normal(0,1);
  alphaDsd ~ normal(0,1);
  alphaHGsd ~ normal(0,1);
  alphaLGsd ~ normal(0,1);
  // Subject-level priors, vectorized (equivalent to sampling inside the loop)
  betas ~ normal(betam,betasd);
  alphaDs ~ normal(alphaDm,alphaDsd);
  alphaHGs ~ normal(alphaHGm,alphaHGsd);
  alphaLGs ~ normal(alphaLGm,alphaLGsd);
  for (s in 1:NS) { // Loop over subjects
    // Map unconstrained subject parameters to learning rates in (0,1)
    real alphaD = Phi_approx(alphaDs[s]);
    real alphaHG = Phi_approx(alphaHGs[s]);
    real alphaLG = Phi_approx(alphaLGs[s]);
    real Q[3,2] = rep_array(0.0, 3, 2); // Q-values start at zero for each subject
    for (ss in 1:NSS) { // Loop over sessions
      for (t in 1:NT) { // Loop over trials
        if (c[s,ss,t] >= 0) { // Only fit trials where the subject responded
          int ch = c[s,ss,t] + 1; // choice as a 1-based index
          // Choice rule: softmax over the two options of the current state
          c[s,ss,t] ~ bernoulli_logit(betas[s] * (Q[st[s,ss,t], 2] - Q[st[s,ss,t], 1]));
          // Q-learning: only state 1 triggers learning, with gated
          // generalization to states 2 and 3 (the updates are sequential,
          // so later lines see the already-updated values)
          if (st[s,ss,t] == 1) {
            Q[1, ch] += alphaD * (r[s,ss,t] - Q[1, ch]);
            if (r[s,ss,t] == 1) {
              // After a reward, the higher of Q[2]/Q[3] updates with
              // alphaLG and the lower with alphaHG
              Q[3, ch] += softsign(Q[3,ch] - Q[2,ch]) * alphaLG * (r[s,ss,t] - Q[3,ch]);
              Q[3, ch] += softsign(Q[2,ch] - Q[3,ch]) * alphaHG * (r[s,ss,t] - Q[3,ch]);
              Q[2, ch] += softsign(Q[3,ch] - Q[2,ch]) * alphaHG * (r[s,ss,t] - Q[2,ch]);
              Q[2, ch] += softsign(Q[2,ch] - Q[3,ch]) * alphaLG * (r[s,ss,t] - Q[2,ch]);
            }
            else if (r[s,ss,t] == 0) {
              // After no reward, the rate assignment is reversed
              Q[3, ch] += softsign(Q[3,ch] - Q[2,ch]) * alphaHG * (r[s,ss,t] - Q[3,ch]);
              Q[3, ch] += softsign(Q[2,ch] - Q[3,ch]) * alphaLG * (r[s,ss,t] - Q[3,ch]);
              Q[2, ch] += softsign(Q[3,ch] - Q[2,ch]) * alphaLG * (r[s,ss,t] - Q[2,ch]);
              Q[2, ch] += softsign(Q[2,ch] - Q[3,ch]) * alphaHG * (r[s,ss,t] - Q[2,ch]);
            }
          }
        }
      }
    }
  }
}

Here’s the type of data that I’m trying to fit, as well as the command that I use to fit the model.

stan_data = {
 'NS': 29,  # number of subjects
 'NSS': 5,  # number of sessions
 'NT': 60,  # number of trials
 'st': array([[[3, 3, 1, ..., 3, 3, 2],   # state at trial t
         [1, 3, 2, ..., 2, 1, 3],
         [3, 2, 2, ..., 2, 3, 2],
         [3, 2, 2, ..., 3, 2, 3],
         [3, 2, 3, ..., 1, 3, 3]],

        ...,

        [[3, 3, 2, ..., 2, 3, 3],
         [3, 3, 2, ..., 3, 2, 3],
         [1, 3, 2, ..., 3, 3, 2],
         [3, 3, 2, ..., 3, 2, 1],
         [2, 3, 1, ..., 1, 1, 3]]]),
 'c': array([[[ 0,  0,  0, ...,  0,  1,  1],   # choice at trial t (-1 = no response)
         [ 1,  1,  1, ...,  1,  0,  0],
         [ 0,  0,  0, ...,  1,  0,  1],
         [ 0,  0,  0, ...,  0,  0,  1],
         [ 0,  0,  0, ...,  1,  1,  1]],

        ...,

        [[ 0,  0, -1, ...,  0,  1,  0],
         [-1,  0,  0, ...,  0,  0,  0],
         [-1,  0,  1, ...,  0,  1,  0],
         [ 1,  1,  1, ...,  0,  0,  1],
         [ 1,  1,  1, ...,  1,  0,  1]]]),
 'r': array([[[ 1,  1,  1, ...,  0,  0,  1],   # reward at trial t (-1 = no response)
          ...],

        [[ 0,  0,  0, ...,  1,  0,  0],
         [ 1, -1,  0, ...,  1,  0,  1],
         [ 0,  0,  1, ...,  1,  1,  0],
         [ 1,  1,  1, ...,  0,  1,  1],
         [ 0, -1,  1, ...,  1,  0,  0]],

        ...,

        [[ 0,  1, -1, ...,  1,  0,  1],
         [-1,  1,  0, ...,  1,  1,  1],
         [-1,  0,  1, ...,  0,  1,  1],
         [ 0,  1,  0, ...,  1,  1,  0],
         [ 0,  0,  1, ...,  0,  1,  1]]])}


# Fit data and print results
fit_replayG = sm_replayG.sampling(data=stan_data, iter=1000, warmup=250, chains=4)
print(fit_replayG)

However, fitting this model takes about 6 hours (as opposed to about 5 minutes for the slightly simpler model), and the chains don't converge (most Rhat values are greater than 2).

In addition to the convergence problems, I sometimes get the following warning while fitting:

Exception: normal_lpdf: Scale parameter is 0, but must be > 0!  (in 'replayG.stan' at line 45)

Hi, just to give some generic advice here: I suggest simulating fake data from your model, then fitting the model and seeing whether you can recover the parameters. Since it's taking a long time to run, I suggest running your 4 parallel chains for just 100 warmup and 100 saved iterations, with max treedepth set to 5, just to get things started, because you don't want to be waiting for hours every time you debug the model. That's what it was like when I took a computer science class in 1977: we had to write our code on punch cards and then wait hours for it to get run through the computer.
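In PyStan 2, that quick-debugging configuration would look something like the call below (a sketch, reusing the `sm_replayG` model and `stan_data` dict from the original post):

# Short debugging run: few iterations and shallow trees, so each
# fit-and-inspect cycle takes minutes instead of hours
fit_debug = sm_replayG.sampling(data=stan_data,
                                iter=200,    # 100 warmup + 100 saved draws
                                warmup=100,
                                chains=4,
                                control={'max_treedepth': 5})
print(fit_debug)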

When you do the simulations, keep your variances small at first.
Try to generate data for which you can easily tell whether it was generated correctly from the model, even if that means combining copies of the same data over and over with small bits of noise added to each.
As Tom Tang likes to say, keep it simple AND stupid.
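For what it's worth, here's a minimal sketch of what simulating fake data from this model could look like in Python. It mirrors the update rules in the Stan code above; the reward schedule `p_reward` and the parameter values at the bottom are made up for illustration, and in a real recovery study you would draw the per-subject parameters from the group-level priors:

import numpy as np

def softsign(q):
    # Steep sigmoid, matching the Stan helper: ~1 if q > 0, ~0 if q < 0
    return 1.0 / (1.0 + np.exp(-10.0 * q))

def simulate_subject(beta, alphaD, alphaHG, alphaLG,
                     n_sessions=5, n_trials=60, rng=None):
    """Simulate one subject's (st, c, r) arrays from the model above."""
    if rng is None:
        rng = np.random.default_rng()
    p_reward = np.array([[0.2, 0.8],   # hypothetical reward probability
                         [0.8, 0.2],   # for each (state, choice) pair
                         [0.5, 0.5]])
    st = np.zeros((n_sessions, n_trials), dtype=int)
    c = np.zeros((n_sessions, n_trials), dtype=int)
    r = np.zeros((n_sessions, n_trials), dtype=int)
    Q = np.zeros((3, 2))  # carried across sessions, as in the Stan code
    for ss in range(n_sessions):
        for t in range(n_trials):
            s = rng.integers(3)  # state, 0-based here (Stan state = s + 1)
            # Choice rule: P(c=1) = logistic(beta * (Q[s,1] - Q[s,0]))
            p1 = 1.0 / (1.0 + np.exp(-beta * (Q[s, 1] - Q[s, 0])))
            ch = int(rng.random() < p1)
            rw = int(rng.random() < p_reward[s, ch])
            st[ss, t], c[ss, t], r[ss, t] = s + 1, ch, rw
            if s == 0:  # only state 1 triggers learning
                Q[0, ch] += alphaD * (rw - Q[0, ch])
                # Sequential gated updates, same order as the Stan code
                if rw == 1:
                    Q[2, ch] += softsign(Q[2, ch] - Q[1, ch]) * alphaLG * (rw - Q[2, ch])
                    Q[2, ch] += softsign(Q[1, ch] - Q[2, ch]) * alphaHG * (rw - Q[2, ch])
                    Q[1, ch] += softsign(Q[2, ch] - Q[1, ch]) * alphaHG * (rw - Q[1, ch])
                    Q[1, ch] += softsign(Q[1, ch] - Q[2, ch]) * alphaLG * (rw - Q[1, ch])
                else:
                    Q[2, ch] += softsign(Q[2, ch] - Q[1, ch]) * alphaHG * (rw - Q[2, ch])
                    Q[2, ch] += softsign(Q[1, ch] - Q[2, ch]) * alphaLG * (rw - Q[2, ch])
                    Q[1, ch] += softsign(Q[2, ch] - Q[1, ch]) * alphaLG * (rw - Q[1, ch])
                    Q[1, ch] += softsign(Q[1, ch] - Q[2, ch]) * alphaHG * (rw - Q[1, ch])
    return st, c, r

# Build a fake dataset with known parameters and no between-subject variance
rng = np.random.default_rng(1)
sims = [simulate_subject(beta=3.0, alphaD=0.3, alphaHG=0.5, alphaLG=0.1, rng=rng)
        for _ in range(29)]
sim_data = {'NS': 29, 'NSS': 5, 'NT': 60,
            'st': np.stack([s for s, _, _ in sims]),
            'c':  np.stack([ch for _, ch, _ in sims]),
            'r':  np.stack([rw for _, _, rw in sims])}

Fitting the Stan model to `sim_data` and checking whether the posterior recovers beta and the three learning rates should show fairly quickly whether the learning rates are even identifiable from this task.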
