(Apologies for the long post)
Goal: to model changes in human speech perception based on recently experienced speech input. More specifically, the program posted below aims to infer listeners’ prior beliefs about the cue distributions (means and covariance matrices) of 2 or more multivariate Gaussian categories (each representing a sound category, e.g., /p/ vs. /b/). This is done under the assumption that listeners (a) start with a set of prior beliefs about the means and covariance matrices of the categories (modeled as Normal-Inverse-Wishart priors; Murphy, 2012) and (b) update those beliefs based on the sufficient statistics of recently experienced input (‘exposure’).
What I need help with: The program seems to work but, curiously, when I provide additional optional information to the program (e.g., when I tell the model where the category means are, so that it only needs to infer the category covariance matrices), I mostly get divergent transitions. This is the case even though the additional information is ‘correct’, i.e., even when I use data that match the model assumptions and for which I know the ground truth. The few remaining posterior samples ‘make sense’, but I’m trying to understand whether there’s something I’m missing in the way I’m handling optional parameters in this program that causes those divergent transitions. It seems Stan is still sampling those parameters even when they are user-provided. Apologies if this is a naive question.
The input to the program is a combination of the following (sketched in R right after this list):
- sufficient statistics (number of observations, k-dimensional mean, and k×k sum-of-squares matrix) for each category in each exposure condition
- counts of subsequent categorization responses (category 1, 2, …) on a series of test tokens (k-dimensional vectors).
- Optionally, users can provide information about the prior mean and/or covariance matrix.
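For concreteness, here is a minimal sketch (in R) of the data list handed to Stan. All values are hypothetical placeholders, but the names and shapes match the data block of the model below:

# SPD placeholder for the uncentered sum-of-squares matrices, dims [M, L, K, K]:
x_ss <- array(0, dim = c(2, 1, 2, 2))
for (m in 1:2) x_ss[m, 1, , ] <- 100 * diag(2)

standata <- list(
  M = 2, L = 1, K = 2,                       # categories, groups, features
  N = matrix(100, nrow = 2, ncol = 1),       # observations per category and group
  x_mean = array(0, dim = c(2, 1, 2)),       # exposure means, dims [M, L, K]
  x_ss = x_ss,
  N_test = 1,
  x_test = matrix(0, nrow = 1, ncol = 2),    # test token locations, dims [N_test, K]
  y_test = array(1, dim = 1),                # group label per test trial
  z_test_counts = matrix(c(5, 5), nrow = 1), # response counts, dims [N_test, M]
  m_0_known = 0, S_0_known = 0,              # flags for the optional information
  m_0_data = array(0, dim = c(0, 0)),        # empty unless m_0_known == 1
  S_0_data = array(0, dim = c(0, 0, 0)),     # empty unless S_0_known == 1
  tau_scale = 0, L_omega_scale = 0           # 0 = fall back on the default prior scales
)

tau_scale and L_omega_scale are set to 0 here so that the model falls back on its default prior scales (see the model block below).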
The model aims to infer the shared prior beliefs that explain those test responses given the exposure statistics (under lots of simplifying assumptions). The belief-updating rules the model essentially inverts in order to infer the prior beliefs are the conjugate NIW updates from Murphy (2012, p. 134), where \mu and \Sigma are the category mean and covariance matrix, and m, S, \kappa, and \nu are the parameters of the Normal-Inverse-Wishart model:
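(Transcribed rather than screenshotted so it renders here; these are exactly the quantities computed in the transformed parameters block below.) For a category with N exposure observations, observed mean \bar{x}, and uncentered sum of squares \sum_i x_i x_i^T:

\kappa_n = \kappa_0 + N
\nu_n = \nu_0 + N
m_n = (\kappa_0 m_0 + N \bar{x}) / \kappa_n
S_n = S_0 + \sum_i x_i x_i^T + \kappa_0 m_0 m_0^T - \kappa_n m_n m_n^T

The resulting posterior predictive for a test token x is a multivariate Student-t density,

p(x | exposure) = t_{\nu_n - K + 1}( x | m_n, S_n (\kappa_n + 1) / (\kappa_n (\nu_n - K + 1)) ),

which is what the test responses are modeled with (plus a lapse rate).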
Here is the Stan code:
data {
int M; // number of categories
int L; // number of grouping levels (e.g. subjects)
int K; // number of features
matrix[M,L] N; // number of observations per category (m) and group (l)
vector[K] x_mean[M,L]; // means for each category (m) and group (l)
cov_matrix[K] x_ss[M,L]; // sum of uncentered squares matrix for each category (m) and group (l)
int N_test; // number of test trials
vector[K] x_test[N_test]; // locations of test trials
int y_test[N_test]; // group label of test trials
int z_test_counts[N_test,M]; // responses for test trials
int<lower=0, upper=1> m_0_known;
int<lower=0, upper=1> S_0_known;
vector[m_0_known ? K : 0] m_0_data[m_0_known ? M : 0]; // optional: user provided m_0 (prior mean of means)
cov_matrix[S_0_known ? K : 0] S_0_data[S_0_known ? M : 0]; // optional: user provided S_0 (prior scatter matrix of mean)
real<lower=0> tau_scale; // scale of cauchy prior for variances of m_0 (set to zero to ignore)
real<lower=0> L_omega_scale; // scale of LKJ prior for correlation of variance of m_0 (set to zero to ignore)
}
transformed data {
real sigma_kappanu;
/* Scale for the prior of kappa_0/nu_0. In order to deal with input that contains no observations
(in which case max(N) == 0), we set the minimum value for the SD to 10. */
sigma_kappanu = max(N) > 0 ? max(N) * 4 : 10;
}
parameters {
// these are all shared across groups (same prior beliefs):
real<lower=K> kappa_0; // prior pseudocount for category mu
real<lower=K + 1> nu_0; // prior pseudocount for category Sigma
vector[K] m_0_param[m_0_known ? 0 : M]; // prior mean of means
vector<lower=0>[K] m_0_tau; // prior variances of m_0
cholesky_factor_corr[K] m_0_L_omega; // prior correlations of variances of m_0 (in cholesky form)
vector<lower=0>[K] tau_0_param[S_0_known ? 0 : M]; // standard deviations of prior scatter matrix S_0
cholesky_factor_corr[K] L_omega_0_param[S_0_known ? 0 : M]; // correlation matrix of prior scatter matrix S_0 (in cholesky form)
real<lower=0, upper=1> lapse_rate;
}
transformed parameters {
vector[K] m_0[M]; // prior mean of means m_0
cov_matrix[K] S_0[M]; // prior scatter matrix S_0
// updated beliefs depend on input and group
real<lower=K> kappa_n[M,L]; // updated mean pseudocount
real<lower=K> nu_n[M,L]; // updated Sigma pseudocount
vector[K] m_n[M,L]; // updated expected mean
cov_matrix[K] S_n[M,L]; // updated expected scatter matrix
cov_matrix[K] t_scale[M,L]; // scale matrix of predictive t distribution
simplex[M] p_test_conj[N_test];
vector[M] log_p_test_conj[N_test];
if (m_0_known) {
m_0 = m_0_data;
} else {
m_0 = m_0_param;
}
if (S_0_known) {
S_0 = S_0_data;
}
// update NIW parameters according to the conjugate updating rules from
// Murphy (2012, p. 134)
for (cat in 1:M) {
if (!S_0_known) {
// Get S_0 from its components: correlation matrix and vector of standard deviations
S_0[cat] = quad_form_diag(multiply_lower_tri_self_transpose(L_omega_0_param[cat]), tau_0_param[cat]);
}
for (group in 1:L) {
if (N[cat,group] > 0 ) {
kappa_n[cat,group] = kappa_0 + N[cat,group];
nu_n[cat,group] = nu_0 + N[cat,group];
m_n[cat,group] = (kappa_0 * m_0[cat] + N[cat,group] * x_mean[cat,group]) /
kappa_n[cat,group];
S_n[cat,group] = S_0[cat] +
x_ss[cat,group] +
kappa_0 * m_0[cat] * m_0[cat]' -
kappa_n[cat,group] * m_n[cat,group] * m_n[cat,group]';
} else {
kappa_n[cat,group] = kappa_0;
nu_n[cat,group] = nu_0;
m_n[cat,group] = m_0[cat];
S_n[cat,group] = S_0[cat];
}
t_scale[cat,group] = S_n[cat,group] * (kappa_n[cat,group] + 1) /
(kappa_n[cat,group] * (nu_n[cat,group] - K + 1));
}
}
// compute category probabilities for each of the test stimuli
for (j in 1:N_test) {
int group;
group = y_test[j];
// calculate un-normalized log prob for each category
for (cat in 1:M) {
log_p_test_conj[j,cat] = multi_student_t_lpdf(x_test[j] |
nu_n[cat,group] - K + 1,
m_n[cat,group],
t_scale[cat,group]);
}
// normalize and store actual probs in simplex
p_test_conj[j] = exp(log_p_test_conj[j] - log_sum_exp(log_p_test_conj[j]));
}
}
model {
vector[M] lapsing_probs;
lapsing_probs = rep_vector(lapse_rate / M, M);
kappa_0 ~ normal(0, sigma_kappanu);
nu_0 ~ normal(0, sigma_kappanu);
/* Specifying prior for m_0:
- If no scale for variances (tau) of m_0 is user-specified use weakly regularizing
scale (5) for variances of mean.
- If no scale for LKJ prior over correlation matrix of m_0 is user-specified use
scale 1 to set uniform prior over correlation matrices. */
if (!m_0_known) {
m_0_tau ~ cauchy(0, tau_scale > 0 ? tau_scale : 5);
m_0_L_omega ~ lkj_corr_cholesky(L_omega_scale > 0 ? L_omega_scale : 1);
m_0_param ~ multi_normal_cholesky(rep_vector(0, K), diag_pre_multiply(m_0_tau, m_0_L_omega));
}
/* Specifying prior for components of S_0: */
if (!S_0_known) {
for (cat in 1:M) {
tau_0_param[cat] ~ cauchy(0, tau_scale > 0 ? tau_scale : 10);
L_omega_0_param[cat] ~ lkj_corr_cholesky(L_omega_scale > 0 ? L_omega_scale : 1);
}
}
for (i in 1:N_test) {
z_test_counts[i] ~ multinomial(p_test_conj[i] * (1-lapse_rate) + lapsing_probs);
}
}
generated quantities {
// declared at the top level of the block so that they are written to the output
// (declaring them inside the if-block would make them local and they would never be saved);
// they remain undefined when m_0 is user-provided
matrix[K,K] m_0_cor;
matrix[K,K] m_0_cov;
if (!m_0_known) {
m_0_cor = multiply_lower_tri_self_transpose(m_0_L_omega);
m_0_cov = quad_form_diag(m_0_cor, m_0_tau);
}
}
And here is an example input (with known ground truth; these data were generated in a way that meets all the assumptions of the model):
example-input.RData (5.3 KB)
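For reference, this is roughly how I fit the model (assuming the .RData file provides the data list under the name standata, and that the program above is saved as mvg_conj_sufficient_stats_lapse.stan):

library(rstan)

load("example-input.RData")  # assumed to provide the data list `standata`
fit <- stan(
  file = "mvg_conj_sufficient_stats_lapse.stan",  # the program above
  data = standata,
  chains = 4, iter = 4000, warmup = 2000
)
print(fit)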
If I run this model, it seems to converge (with large uncertainty about the parameters, which makes sense for this input). My question is about what happens when I also provide the model with m_0 (as in this alternative example input: example-input-with-m0.RData). For this example, I change m_0_known to 1 and change m_0_data to the correct means (the prior means the data were generated from):
[,1] [,2]
[1,] -1.0817 -0.06806
[2,] 0.0769 0.05767
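In R terms, that change amounts to something like the following (variable names as in the sketch above; check_hmc_diagnostics just summarizes the divergences):

standata$m_0_known <- 1
standata$m_0_data <- rbind(c(-1.0817, -0.06806),   # rows = categories (M),
                           c( 0.0769,  0.05767))   # columns = features (K)
fit_m0 <- stan(
  file = "mvg_conj_sufficient_stats_lapse.stan",
  data = standata,
  chains = 4, iter = 4000, warmup = 2000
)
rstan::check_hmc_diagnostics(fit_m0)  # reports divergent transitions, treedepth, E-BFMI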
Then the model fit results in lots of divergent transitions. The output still makes sense (it approximates the ground truth the data were generated from; e.g., the data were generated with kappa = nu = 4), but there are very few effective samples:
Inference for Stan model: mvg_conj_sufficient_stats_lapse.
4 chains, each with iter=4000; warmup=2000; thin=1;
post-warmup draws per chain=2000, total post-warmup draws=8000.
mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
kappa_0 5.120e+00 0.09 0.74 3.870e+00 4.600e+00 5.030e+00 5.600e+00 6.710e+00 66 1.04
nu_0 3.232e+01 3.95 27.41 3.110e+00 1.301e+01 2.536e+01 4.345e+01 1.043e+02 48 1.07
m_0_tau[1] 9.141e+307 NaN Inf 4.077e+306 4.374e+307 9.291e+307 1.374e+308 1.763e+308 NaN NaN
m_0_tau[2] 8.914e+307 NaN Inf 3.342e+306 4.641e+307 8.902e+307 1.324e+308 1.755e+308 NaN NaN
m_0_L_omega[1,1] 1.000e+00 NaN 0.00 1.000e+00 1.000e+00 1.000e+00 1.000e+00 1.000e+00 NaN NaN
m_0_L_omega[1,2] 0.000e+00 NaN 0.00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 NaN NaN
m_0_L_omega[2,1] 1.300e-01 0.06 0.58 -9.200e-01 -3.700e-01 2.000e-01 6.300e-01 9.900e-01 91 1.03
m_0_L_omega[2,2] 7.700e-01 0.02 0.24 1.600e-01 6.200e-01 8.600e-01 9.600e-01 1.000e+00 99 1.02
tau_0_param[1,1] 5.610e+00 0.17 0.86 4.710e+00 5.030e+00 5.430e+00 5.910e+00 8.130e+00 24 1.15
tau_0_param[1,2] 9.400e+00 2.57 10.79 9.800e-01 3.410e+00 6.210e+00 1.036e+01 4.795e+01 18 1.33
tau_0_param[2,1] 3.310e+00 0.21 1.18 1.570e+00 2.490e+00 3.180e+00 3.870e+00 6.240e+00 32 1.11
tau_0_param[2,2] 9.850e+00 2.50 10.52 2.570e+00 4.340e+00 6.180e+00 1.035e+01 4.702e+01 18 1.33
L_omega_0_param[1,1,1] 1.000e+00 NaN 0.00 1.000e+00 1.000e+00 1.000e+00 1.000e+00 1.000e+00 NaN NaN
L_omega_0_param[1,1,2] 0.000e+00 NaN 0.00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 NaN NaN
L_omega_0_param[1,2,1] -4.200e-01 0.02 0.16 -8.200e-01 -5.000e-01 -3.900e-01 -3.100e-01 -2.000e-01 61 1.05
L_omega_0_param[1,2,2] 8.900e-01 0.01 0.10 5.700e-01 8.700e-01 9.200e-01 9.500e-01 9.800e-01 66 1.05
L_omega_0_param[2,1,1] 1.000e+00 NaN 0.00 1.000e+00 1.000e+00 1.000e+00 1.000e+00 1.000e+00 NaN NaN
L_omega_0_param[2,1,2] 0.000e+00 NaN 0.00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 NaN NaN
L_omega_0_param[2,2,1] 3.800e-01 0.09 0.46 -6.700e-01 7.000e-02 4.800e-01 7.800e-01 9.700e-01 24 1.20
L_omega_0_param[2,2,2] 7.700e-01 0.04 0.22 2.500e-01 6.200e-01 8.300e-01 9.700e-01 1.000e+00 40 1.04
lapse_rate 0.000e+00 0.00 0.00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 300 1.00
m_0[1,1] -1.080e+00 0.00 0.00 -1.080e+00 -1.080e+00 -1.080e+00 -1.080e+00 -1.080e+00 2 1.00
m_0[1,2] -7.000e-02 0.00 0.00 -7.000e-02 -7.000e-02 -7.000e-02 -7.000e-02 -7.000e-02 2 1.00
m_0[2,1] 8.000e-02 0.00 0.00 8.000e-02 8.000e-02 8.000e-02 8.000e-02 8.000e-02 2 1.00
m_0[2,2] 6.000e-02 0.00 0.00 6.000e-02 6.000e-02 6.000e-02 6.000e-02 6.000e-02 2 1.00
S_0[1,1,1] 3.226e+01 2.30 11.55 2.220e+01 2.527e+01 2.944e+01 3.498e+01 6.616e+01 25 1.15
S_0[1,1,2] -2.924e+01 14.32 65.90 -2.576e+02 -1.875e+01 -1.070e+01 -7.020e+00 -3.040e+00 21 1.25
S_0[1,2,1] -2.924e+01 14.32 65.90 -2.576e+02 -1.875e+01 -1.070e+01 -7.020e+00 -3.040e+00 21 1.25
S_0[1,2,2] 2.048e+02 126.85 585.89 9.600e-01 1.165e+01 3.860e+01 1.074e+02 2.299e+03 21 1.24
S_0[2,1,1] 1.232e+01 1.65 10.08 2.480e+00 6.190e+00 1.010e+01 1.498e+01 3.897e+01 37 1.09
S_0[2,1,2] -6.370e+00 11.06 52.79 -1.711e+02 2.140e+00 7.560e+00 9.560e+00 1.310e+01 23 1.22
S_0[2,2,1] -6.370e+00 11.06 52.79 -1.711e+02 2.140e+00 7.560e+00 9.560e+00 1.310e+01 23 1.22
S_0[2,2,2] 2.077e+02 127.96 597.28 6.590e+00 1.885e+01 3.823e+01 1.070e+02 2.211e+03 22 1.23
[...]
I’m probably either doing something stupid in my code or missing something fundamental about how the sampling works. Any help would be much appreciated!