We are interested in fitting a learning model in a motor adaptation experiment. This is a state-space model that assumes two states contribute to learning: a ‘fast’ process that responds strongly to error (i.e. high learning rate) but retains information poorly (i.e. low retention), and a ‘slow’ process that responds weakly to error (i.e. low learning rate) but retains information well (i.e. high retention). Thus, for each state there are two parameters, retention factor A and learning rate B.
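For reference, here are the per-trial update equations (these just restate the updates in the code below), where e(t) is the trajectory error and p(t) the perturbation:

$$
\begin{aligned}
x_f(t) &= A_f \, x_f(t-1) + B_f \, e(t-1) \\
x_s(t) &= A_s \, x_s(t-1) + B_s \, e(t-1) \\
x(t) &= x_f(t) + x_s(t) \\
e(t) &= p(t) - x(t), \qquad y(t) \sim \mathrm{Normal}\big(e(t), \sigma_e\big)
\end{aligned}
$$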
We verified that this model converges nicely for a single subject's dataset (192 observations). However, when we model these parameters hierarchically for a large group of subjects (N > 100), convergence problems arise.
One complication is that we need to ensure, on a subject-by-subject basis, that the learning rate is greater for the ‘fast’ process than for the ‘slow’ process, and that the retention factor is greater for the ‘slow’ process than for the ‘fast’ process. Therefore, we are not modeling the ‘fast’ learning rate and ‘slow’ retention factor directly, but instead as non-negative offsets to the ‘slow’ learning rate and ‘fast’ retention factor, respectively.
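Concretely (this restates the transform used in the transformed parameters block below): we work on the probit scale with non-negative offsets $\delta_a, \delta_b \ge 0$, and the monotonicity of the standard normal CDF $\Phi$ then guarantees the ordering for every subject:

$$
\begin{aligned}
A_f &= \Phi(\alpha_f), & A_s &= \Phi(\alpha_f + \delta_a) \ge A_f, \\
B_s &= \Phi(\beta_s), & B_f &= \Phi(\beta_s + \delta_b) \ge B_s.
\end{aligned}
$$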
The basic model is as follows (full Stan code at the bottom):
for (t in 2:T) {
// 'fast' state estimate
x_f[t] = a_f * x_f[t-1] + b_f * error[t-1];
// 'slow' state estimate
x_s[t] = a_s * x_s[t-1] + b_s * error[t-1];
// combined state estimate
x[t] = x_f[t] + x_s[t];
// trajectory error
error[t] = perturbation[t] - x[t];
// likelihood statement
y[t] ~ normal(error[t], sigma_e);
}
The hierarchical model (see below) fails to converge: there are many divergent transitions, large R-hat values, and very small effective sample sizes, and the traceplots look flat, with virtually no mixing between chains. We are already using the non-centred parameterisation, and we are using informative priors for the group-level means (based on a previously published study that used the same model).
We would greatly appreciate any advice for improving our model.
data {
int<lower=0> N; // total number of data points
int<lower=0> J; // total number of subjects
int<lower=1, upper=J> subj[N]; // subject indicator for each trial
real y[N]; // observed trajectory error for each trial
int p[N]; // experimental manipulation (visual perturbation) for each trial
int<lower=0, upper=1> first_trial[N]; // logical vector: is this the first trial for this subject?
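// NOTE: the model block reads x_f[n-1], x_s[n-1], and err[n-1], so the data are
// assumed to be sorted with each subject's trials consecutive and in chronological
// order; first_trial marks the start of each subject's series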
}
parameters {
// group level mean parameters
real mu_a_f; // retention factor for fast process
real<lower=0> mu_a_s_shift; // retention factor for slow process, modeled as non-negative offset to mu_a_f
real mu_b_s; // learning rate for slow process
real<lower=0> mu_b_f_shift; // learning rate for fast process, modeled as non-negative offset to mu_b_s
// group level variance parameters
real<lower=0> sigma_a_f;
real<lower=0> sigma_a_s_shift;
real<lower=0> sigma_b_s;
real<lower=0> sigma_b_f_shift;
// individual subject level parameters
// non-centred parameterisation: 'raw' subject parameters are scaled and shifted by the group-level parameters
real a_f_raw[J];
real a_s_shift_raw[J];
real b_s_raw[J];
real b_f_shift_raw[J];
real<lower=0> sigma_e[J]; // residual SD
}
transformed parameters {
real a_f[J];
real<lower=0> a_s_shift[J];
real b_s[J];
real<lower=0> b_f_shift[J];
// final subject-level parameters, scaled to 0-1 range
real<lower=0, upper=1> a_f_scaled[J];
real<lower=0, upper=1> a_s_scaled[J];
real<lower=0, upper=1> b_s_scaled[J];
real<lower=0, upper=1> b_f_scaled[J];
for (j in 1:J) {
// implies: a_f ~ normal(mu_a_f, sigma_a_f)
a_f[j] = mu_a_f + sigma_a_f * a_f_raw[j];
// implies: a_s_shift ~ normal(mu_a_s_shift, sigma_a_s_shift)
a_s_shift[j] = mu_a_s_shift + sigma_a_s_shift * a_s_shift_raw[j];
// implies: b_s ~ normal(mu_b_s, sigma_b_s)
b_s[j] = mu_b_s + sigma_b_s * b_s_raw[j];
// implies: b_f_shift ~ normal(mu_b_f_shift, sigma_b_f_shift)
b_f_shift[j] = mu_b_f_shift + sigma_b_f_shift * b_f_shift_raw[j];
// transform to 0-1 range using the standard normal cumulative distribution function
a_f_scaled[j] = Phi_approx(a_f[j]);
a_s_scaled[j] = Phi_approx(a_f[j] + a_s_shift[j]);
b_s_scaled[j] = Phi_approx(b_s[j]);
b_f_scaled[j] = Phi_approx(b_s[j] + b_f_shift[j]);
}
}
model {
// initialise local variables
real x_f[N]; // state estimate for fast process
real x_s[N]; // state estimate for slow process
real x[N]; // combined state estimate
real err[N]; // expected value of trajectory error
// set values for first trial to zero
for (n in 1:N){
if(first_trial[n] == 1){
x_f[n] = 0;
x_s[n] = 0;
x[n] = 0;
err[n] = 0;
}
}
// priors for group-level mean parameters
// means based on Trewartha et al. (2014) Journal of Neuroscience
mu_a_f ~ normal(0, 0.5); // Phi_approx(0) = 0.5
mu_a_s_shift ~ normal(1.3, 0.5); // Phi_approx(0 + 1.3) = 0.9
mu_b_s ~ normal(-1.3, 0.5); // Phi_approx(-1.3) = 0.1
mu_b_f_shift ~ normal(1.05, 0.5); // Phi_approx(-1.3 + 1.05) = 0.4
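// note: because mu_a_s_shift and mu_b_f_shift are declared <lower=0>, Stan
// implicitly truncates these normal priors at zero (truncated-normal priors on the offsets)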
// priors for group-level variance parameters
sigma_a_f ~ cauchy(0, 1);
sigma_a_s_shift ~ cauchy(0, 1);
sigma_b_s ~ cauchy(0, 1);
sigma_b_f_shift ~ cauchy(0, 1);
// standard normal priors for 'raw' subject-level parameters
a_f_raw ~ normal(0, 1);
a_s_shift_raw ~ normal(0, 1);
b_s_raw ~ normal(0, 1);
b_f_shift_raw ~ normal(0, 1);
sigma_e ~ cauchy(0, 5);
// model the observed trajectory errors using the parameters
for (n in 1:N){
if(first_trial[n] == 0){
x_f[n] = a_f_scaled[subj[n]] * x_f[n - 1] + b_f_scaled[subj[n]] * err[n - 1];
x_s[n] = a_s_scaled[subj[n]] * x_s[n - 1] + b_s_scaled[subj[n]] * err[n - 1];
x[n] = x_f[n] + x_s[n];
err[n] = p[n] - x[n];
// likelihood statement
y[n] ~ normal(err[n], sigma_e[subj[n]]);
}
}
}