Hierarchical learning model fitting issues

We are interested in fitting a learning model in a motor adaptation experiment. This is a state-space model that assumes two states contribute to learning: a ‘fast’ process that responds strongly to error (i.e. high learning rate) but retains information poorly (i.e. low retention), and a ‘slow’ process that responds weakly to error (i.e. low learning rate) but retains information well (i.e. high retention). Each state therefore has two parameters: a retention factor A and a learning rate B (a_f and b_f for the fast state, a_s and b_s for the slow state in the code below).

We verified that this model converges nicely for a single subject’s dataset (192 observations). However, when we model these parameters hierarchically for a large group of subjects (more than 100), convergence problems arise.

One complication is that we need to ensure, subject by subject, that the learning rate is greater for the ‘fast’ process than for the ‘slow’ process, and that the retention factor is greater for the ‘slow’ process than for the ‘fast’ process. Therefore, we do not model the ‘fast’ learning rate and ‘slow’ retention factor directly, but as non-negative offsets to the ‘slow’ learning rate and ‘fast’ retention factor, respectively.
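
The offsets guarantee the ordering because the full model below maps each parameter to the 0-1 scale with Phi_approx, which is strictly increasing: adding a non-negative shift on the unconstrained scale can only raise the value on the 0-1 scale. A minimal Python sketch of that mapping, assuming Stan's documented definition Phi_approx(x) = inv_logit(0.07056 x^3 + 1.5976 x); the helper names phi_approx and fast_and_slow are hypothetical:

```python
import math


def phi_approx(x):
    # Stan's documented Phi_approx: inv_logit(0.07056 * x^3 + 1.5976 * x)
    return 1.0 / (1.0 + math.exp(-(0.07056 * x ** 3 + 1.5976 * x)))


def fast_and_slow(base, shift):
    """Return (Phi_approx(base), Phi_approx(base + shift)) for a non-negative shift."""
    if shift < 0:
        raise ValueError("shift must be non-negative")
    return phi_approx(base), phi_approx(base + shift)


# Any shift >= 0 preserves the ordering on the 0-1 scale:
# a_s = Phi_approx(a_f + a_s_shift) >= a_f_scaled, and likewise b_f_scaled >= b_s_scaled.
for base in (-3.0, -1.3, 0.0, 0.5, 2.0):
    for shift in (0.0, 0.1, 1.05, 1.3, 4.0):
        lower, upper = fast_and_slow(base, shift)
        assert 0.0 < lower <= upper < 1.0
```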

The basic model is as follows (full Stan code at the bottom):

for (t in 2:T) {
    // 'fast' state estimate
    x_f[t] = a_f * x_f[t-1] + b_f * error[t-1];
    // 'slow' state estimate
    x_s[t] = a_s * x_s[t-1] + b_s * error[t-1];
    // combined state estimate
    x[t] = x_f[t] + x_s[t];
    // trajectory error
    error[t] = perturbation[t] - x[t];
    // likelihood statement
    y[t] ~ normal(error[t], sigma_e);
}
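
To see what this recursion does for one subject, here is a minimal noise-free Python sketch of the same update. The function name simulate_two_state, the perturbation schedule and the parameter values (taken from the prior-mean comments in the full model) are illustrative assumptions; as in the full Stan code, the first trial starts with zero state and zero error.

```python
def simulate_two_state(perturbation, a_f, b_f, a_s, b_s):
    """Noise-free expected trajectory error for each trial of one subject."""
    T = len(perturbation)
    x_f = [0.0] * T    # fast state
    x_s = [0.0] * T    # slow state
    error = [0.0] * T  # trial 1: zero state and zero error, as in the full model
    for t in range(1, T):
        x_f[t] = a_f * x_f[t - 1] + b_f * error[t - 1]
        x_s[t] = a_s * x_s[t - 1] + b_s * error[t - 1]
        error[t] = perturbation[t] - (x_f[t] + x_s[t])
    return error


# Hypothetical schedule: 10 baseline trials, then a sustained 30-unit perturbation.
perturbation = [0] * 10 + [30] * 50
error = simulate_two_state(perturbation, a_f=0.5, b_f=0.4, a_s=0.9, b_s=0.1)
print(round(error[10], 1), round(error[-1], 1))
```

Under a sustained perturbation p, the error settles at p / (1 + b_f / (1 - a_f) + b_s / (1 - a_s)) rather than zero, because both states leak; with these values that is 30 / 2.8, about 10.7.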

The hierarchical model (see below) is failing to converge: there are many divergent transitions, large R-hat values and very small effective sample sizes, and the traceplots look flat, with virtually no mixing between chains. We are already using the non-centred parameterisation, and we are using informative priors for the group-level means (based on a previously published study using the same model).

We would greatly appreciate any advice for improving our model.

data {
  int<lower=0> N; // total number of data points
  int<lower=0> J; // total number of subjects
  array[N] int<lower=1, upper=J> subj; // subject indicator for each trial
  array[N] real y; // observed trajectory error for each trial
  array[N] int p; // experimental manipulation (visual perturbation) for each trial
  array[N] int<lower=0, upper=1> first_trial; // 1 if this is the subject's first trial; rows must be sorted by subject, then trial
}
parameters {
  // group level mean parameters
  real mu_a_f; // retention factor for fast process
  real<lower=0> mu_a_s_shift; // retention factor for slow process, modeled as positive offset to mu_a_f
  real mu_b_s; // learning rate for slow process
  real<lower=0> mu_b_f_shift; // learning rate for fast process, modeled as positive offset to mu_b_s
  // group level standard deviation parameters
  real<lower=0> sigma_a_f;
  real<lower=0> sigma_a_s_shift;
  real<lower=0> sigma_b_s;
  real<lower=0> sigma_b_f_shift;
  // individual subject level parameters
  // non-centered parameterisation: 'raw' subject parameters will be scaled and shifted by the group-level parameters
  array[J] real a_f_raw;
  array[J] real a_s_shift_raw;
  array[J] real b_s_raw;
  array[J] real b_f_shift_raw;
  array[J] real<lower=0> sigma_e; // residual SD
}
transformed parameters {
  array[J] real a_f;
  array[J] real<lower=0> a_s_shift;
  array[J] real b_s;
  array[J] real<lower=0> b_f_shift;
  // final subject-level parameters, scaled to 0-1 range
  array[J] real<lower=0, upper=1> a_f_scaled;
  array[J] real<lower=0, upper=1> a_s_scaled;
  array[J] real<lower=0, upper=1> b_s_scaled;
  array[J] real<lower=0, upper=1> b_f_scaled;
  for (j in 1:J) {
    // implies: a_f ~ normal(mu_a_f, sigma_a_f)
    a_f[j] = mu_a_f + sigma_a_f * a_f_raw[j];
    // implies: a_s_shift ~ normal(mu_a_s_shift, sigma_a_s_shift)
    a_s_shift[j] = mu_a_s_shift + sigma_a_s_shift * a_s_shift_raw[j];
    // implies: b_s ~ normal(mu_b_s, sigma_b_s)
    b_s[j] = mu_b_s + sigma_b_s * b_s_raw[j];
    // implies: b_f_shift ~ normal(mu_b_f_shift, sigma_b_f_shift)
    b_f_shift[j] = mu_b_f_shift + sigma_b_f_shift * b_f_shift_raw[j];
    // transform to 0-1 range using the standard normal cumulative distribution function
    a_f_scaled[j] = Phi_approx(a_f[j]);
    a_s_scaled[j] = Phi_approx(a_f[j] + a_s_shift[j]);
    b_s_scaled[j] = Phi_approx(b_s[j]);
    b_f_scaled[j] = Phi_approx(b_s[j] + b_f_shift[j]);
  }
}
model {
  // initialise local variables
  array[N] real x_f; // state estimate for fast process
  array[N] real x_s; // state estimate for slow process
  array[N] real x; // combined state estimate
  array[N] real err; // expected value of trajectory error
  // set values for first trial to zero
  for (n in 1:N){
    if(first_trial[n] == 1){
      x_f[n] = 0;
      x_s[n] = 0;
      x[n] = 0;
      err[n] = 0;
    }
  }
  // priors for group-level mean parameters
  // means based on Trewartha et al. (2014) Journal of Neuroscience
  mu_a_f ~ normal(0, 0.5); // Phi_approx(0) = 0.5
  mu_a_s_shift ~ normal(1.3, 0.5); // Phi_approx(0 + 1.3) = 0.9 
  mu_b_s ~ normal(-1.3, 0.5); // Phi_approx(-1.3) = 0.1
  mu_b_f_shift ~ normal(1.05, 0.5); // Phi_approx(-1.3 + 1.05) = 0.4
  // priors for group-level variance parameters
  sigma_a_f ~ cauchy(0, 1);
  sigma_a_s_shift ~ cauchy(0, 1);
  sigma_b_s ~ cauchy(0, 1);
  sigma_b_f_shift ~ cauchy(0, 1);
  // standard normal priors for 'raw' subject-level parameters
  a_f_raw ~ normal(0, 1);
  a_s_shift_raw ~ normal(0, 1);
  b_s_raw ~ normal(0, 1);
  b_f_shift_raw ~ normal(0, 1);
  sigma_e ~ cauchy(0, 5);
  // model the observed trajectory errors using the parameters
  for (n in 1:N){
    if(first_trial[n] == 0){
      x_f[n] = a_f_scaled[subj[n]] * x_f[n - 1] + b_f_scaled[subj[n]] * err[n - 1];
      x_s[n] = a_s_scaled[subj[n]] * x_s[n - 1] + b_s_scaled[subj[n]] * err[n - 1];
      x[n] = x_f[n] + x_s[n];
      err[n] = p[n] - x[n];
      // likelihood statement
      y[n] ~ normal(err[n], sigma_e[subj[n]]);
    }
  }
}
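
The comments on the mean priors can be checked numerically. A minimal Python sketch, assuming Stan's documented definition of Phi_approx; it maps each prior mean to the 0-1 scale exactly as the comments describe:

```python
import math


def phi_approx(x):
    # Stan's documented Phi_approx: inv_logit(0.07056 * x^3 + 1.5976 * x)
    return 1.0 / (1.0 + math.exp(-(0.07056 * x ** 3 + 1.5976 * x)))


# Prior means from the model block, mapped to the 0-1 scale.
prior_means = {
    "a_f": phi_approx(0.0),
    "a_s": phi_approx(0.0 + 1.3),
    "b_s": phi_approx(-1.3),
    "b_f": phi_approx(-1.3 + 1.05),
}
for name, value in prior_means.items():
    print(f"{name}: {value:.3f}")
```

These are transforms of the prior means, not prior means of the transformed parameters. For the exact normal CDF, E[Phi(X)] with X ~ normal(m, s) equals Phi(m / sqrt(1 + s^2)), so once the prior SD and subject-level spread are included, the implied 0-1 values sit closer to 0.5 than the comments suggest.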

Nowadays there’s an easier way to write a non-centered parameterization: the affine transform (offset/multiplier).

parameters {
  real mu_a_f;
  real<lower=0> sigma_a_f;
  real<offset=mu_a_f,multiplier=sigma_a_f> a_f;
}
model {
  a_f ~ normal(mu_a_f, sigma_a_f);
}
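
This is equivalent to the manual a_f_raw version because Stan samples on the unconstrained scale z = (a_f - mu_a_f) / sigma_a_f and, as the Stan reference manual describes for offset/multiplier, adds the log-Jacobian log(sigma_a_f) to the target. A minimal Python check that the two log densities agree in z (the function names are hypothetical):

```python
import math


def normal_lpdf(x, mu, sigma):
    # log density of normal(mu, sigma) at x
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def affine_target(z, mu, sigma):
    # offset/multiplier version: a_f = mu + sigma * z, plus log|d a_f / d z| = log(sigma)
    a_f = mu + sigma * z
    return normal_lpdf(a_f, mu, sigma) + math.log(sigma)


def manual_target(z):
    # original version: a_f_raw ~ normal(0, 1)
    return normal_lpdf(z, 0.0, 1.0)


for mu, sigma in [(0.0, 1.0), (-1.3, 0.2), (1.05, 3.0)]:
    for z in (-2.0, -0.5, 0.0, 1.7):
        assert math.isclose(affine_target(z, mu, sigma), manual_target(z), abs_tol=1e-9)
```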

Those Cauchy priors are very wide. Since you expect to see hierarchical structure (rather than a_f_scaled values all being either zero or one at random), the prior for sigma_a_f should constrain it a lot more, maybe exponential(5) or something similar.
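
A quick prior predictive check makes the point concrete. The Python sketch below draws mu_a_f ~ normal(0, 0.5), a group SD from the prior under test, and one subject's a_f, then reports how often a_f_scaled lands outside [0.01, 0.99]. The threshold, draw count and seed are arbitrary choices; Phi_approx follows Stan's documented formula, written with a numerically stable logistic.

```python
import math
import random


def inv_logit(u):
    # numerically stable logistic function
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)


def phi_approx(x):
    return inv_logit(0.07056 * x ** 3 + 1.5976 * x)


def extreme_fraction(draw_sigma, n_draws=20000, seed=1):
    """Prior predictive share of a_f_scaled outside [0.01, 0.99]."""
    rng = random.Random(seed)
    extreme = 0
    for _ in range(n_draws):
        mu = rng.gauss(0.0, 0.5)             # mu_a_f ~ normal(0, 0.5)
        sigma = draw_sigma(rng)              # sigma_a_f prior under test
        a_f = mu + sigma * rng.gauss(0.0, 1.0)
        p = phi_approx(a_f)
        if p < 0.01 or p > 0.99:
            extreme += 1
    return extreme / n_draws


def half_cauchy(rng):
    # half-Cauchy(0, 1) via the inverse CDF
    return abs(math.tan(math.pi * (rng.random() - 0.5)))


def exponential_5(rng):
    # exponential with rate 5 (mean 0.2), matching Stan's exponential(5)
    return rng.expovariate(5.0)


print("half-Cauchy(0, 1):", extreme_fraction(half_cauchy))
print("exponential(5):   ", extreme_fraction(exponential_5))
```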

Thanks for the tip about non-centered parameterisation! And yes I can see how having tighter priors on the group-level variance parameters might help here. I actually used even wider priors previously (something like cauchy(0, 5)), but after doing some prior predictive checks realised that wasn’t sensible.

After implementing your suggestions @nhuurre, I can confirm that there are no longer convergence issues! However, a different fitting issue has now popped up.

If I understand correctly, when using the affine transform to specify a non-centered parameterisation, you cannot also specify hard lower or upper bounds. For example, the following code yields an error because of the lower=0 on a_f:

parameters {
  real mu_a_f;
  real<lower=0> sigma_a_f;
  real<offset=mu_a_f,multiplier=sigma_a_f, lower=0> a_f;
}
model {
  a_f ~ normal(mu_a_f, sigma_a_f);
}

My problem is that I have subject-level parameters, a_s_shift and b_f_shift, that are meant to be non-negative offsets to other parameters. For a_s_shift the estimates look very reasonable, but for b_f_shift the estimates are centered on zero for many subjects, or even slightly negative for a handful of subjects. The posterior of the group-level mean mu_b_f_shift, which was constrained to be non-negative, indeed has a lot of mass close to zero.

Should I re-write my Stan code so that the subject-level b_f_shift parameters are constrained to be non-negative as well (e.g. in the transformed parameters block)? Or should I take this result as an indication that for the B parameter, there isn’t much differentiation between the ‘slow’ process b_s and the ‘fast’ process b_f?
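
Some context for that question. In the original model, the <lower=0> declarations on a_s_shift and b_f_shift in the transformed parameters block only check the constraint: a draw that violates it is rejected, not transformed. With offset/multiplier, the shift has no constraint at all, so negative values are allowed. One common workaround, offered here as an assumption rather than something proposed in this thread, is to non-center log(shift) instead, e.g. b_f_shift[j] = exp(mu_log_b_f_shift + sigma_log_b_f_shift * b_f_shift_raw[j]) with hypothetical group-level names. A minimal Python sketch comparing the two under the same raw draws:

```python
import math
import random


def shift_affine(mu, sigma, raw):
    # current approach: offset/multiplier (or mu + sigma * raw); sign is unconstrained
    return mu + sigma * raw


def shift_log_scale(mu_log, sigma_log, raw):
    # hypothetical alternative: non-center log(shift), so shift = exp(...) > 0
    return math.exp(mu_log + sigma_log * raw)


rng = random.Random(2)
raws = [rng.gauss(0.0, 1.0) for _ in range(10000)]
affine = [shift_affine(1.05, 1.0, r) for r in raws]
log_scale = [shift_log_scale(math.log(1.05), 1.0, r) for r in raws]

print("negative shifts, affine:   ", sum(s < 0 for s in affine) / len(raws))
print("negative shifts, log scale:", sum(s < 0 for s in log_scale) / len(raws))
```

On the log scale, the group-level location becomes the median shift, and a lognormal cannot put mass at exactly zero. If many subjects truly have b_f close to b_s, the log-scale mean will drift toward large negative values, which may be hard to sample in its own right.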

Non-centering constrained parameters is tricky, even if you don’t use offset/multiplier syntax.

A b_f_shift centered on zero is a bit concerning, because the model looks like it could get b_f and b_s mixed up. Maybe try fitting simulated data to make sure the model can recover a nonzero b_f_shift when one is present?
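
Following that suggestion, here is a minimal Python sketch that simulates a dataset in the layout of the Stan data block above (N, J, subj, y, p, first_trial) with known, clearly positive b_f_shift values. The perturbation schedule, group-level values, sigma_e and the lognormal draws for the shifts are all hypothetical choices. The data dictionary can be passed to whichever Stan interface you use, and the posterior for b_f_shift can then be compared against truth.

```python
import math
import random


def phi_approx(x):
    # Stan's documented Phi_approx: inv_logit(0.07056 * x^3 + 1.5976 * x)
    return 1.0 / (1.0 + math.exp(-(0.07056 * x ** 3 + 1.5976 * x)))


def simulate_dataset(J, perturbation, sigma_e=2.0, seed=0):
    """Simulate long-format data for J subjects with known (hypothetical) parameters."""
    rng = random.Random(seed)
    data = {"subj": [], "y": [], "p": [], "first_trial": []}
    truth = []
    for j in range(1, J + 1):  # Stan indices are 1-based
        # hypothetical true values on the model's pre-Phi_approx scale
        a_f_base = rng.gauss(0.0, 0.3)
        a_s_shift = rng.lognormvariate(math.log(1.3), 0.2)   # always > 0
        b_s_base = rng.gauss(-1.3, 0.3)
        b_f_shift = rng.lognormvariate(math.log(1.05), 0.2)  # always > 0: recovery should find this
        a_f, a_s = phi_approx(a_f_base), phi_approx(a_f_base + a_s_shift)
        b_s, b_f = phi_approx(b_s_base), phi_approx(b_s_base + b_f_shift)
        truth.append({"a_f": a_f, "a_s": a_s, "b_s": b_s, "b_f": b_f, "b_f_shift": b_f_shift})

        x_f = x_s = err = 0.0  # first trial: zero state and zero expected error
        for t, pert in enumerate(perturbation):
            if t > 0:
                # both states update from the previous trial's error
                x_f = a_f * x_f + b_f * err
                x_s = a_s * x_s + b_s * err
                err = pert - (x_f + x_s)
            data["subj"].append(j)
            data["p"].append(pert)
            data["first_trial"].append(1 if t == 0 else 0)
            data["y"].append(rng.gauss(err, sigma_e))
    data["N"] = len(data["y"])
    data["J"] = J
    return data, truth


# Hypothetical 192-trial schedule: baseline, sustained perturbation, washout.
perturbation = [0] * 20 + [30] * 100 + [0] * 72
data, truth = simulate_dataset(J=120, perturbation=perturbation)
print(data["N"], data["J"])
```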