Divergent transitions when providing optional parameter (when trying to infer multivariate Gaussians)

(Apologies for the long post)

Goal: to model changes in human speech perception based on recently experienced speech input. More specifically, the program I posted below aims to infer listeners’ prior beliefs about the cue distributions (means and covariance matrices) corresponding to 2 or more multivariate Gaussian categories (each corresponding to a sound category, e.g., /p/ vs. /b/). This is done under the assumption that listeners a) start with a set of prior beliefs about the mean and covariance matrices of the categories (modeled as Normal-Inverse-Wishart priors; Murphy 2012) and b) update those beliefs based on the sufficient statistics of recently experienced input (‘exposure’).

What I need help with: The program seems to work but, curiously, when I provide additional optional information to the program (e.g., when I tell the model where the category means are, so that it only needs to infer the category covariance matrices), I mostly get divergent transitions. This is the case even though the additional information is ‘correct’, i.e., when I use data that match the model assumptions and for which I know the ground truth. The few remaining posterior samples ‘make sense’, but I’m trying to understand whether there’s something I’m missing in the way I’m handling optional parameters in this program that causes those divergent transitions. It seems Stan is still sampling those parameters even when they are user-provided. Apologies if this is a naive question.

The input to the program is a combination of:

  1. sufficient statistics (number of observations, k-dimensional mean, and k×k sum-of-squares matrix) for each category in each exposure condition
  2. counts of subsequent categorization responses (category 1, 2, …) on a series of test tokens (k-dimensional vectors).
  3. Optionally, users can directly provide the prior means (m_0) and/or prior scatter matrices (S_0).

The model aims to infer the shared prior beliefs that explain those test responses given the exposure statistics (under lots of simplifying assumptions). Here’s a screenshot of the belief-updating rule the model essentially inverts in order to infer the prior beliefs (Murphy, 2012: p. 134), where \mu and \Sigma are the category mean and covariance matrix, and m, S, \kappa, and \nu are the parameters of the Normal-Inverse-Wishart model:
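In case the screenshot doesn’t come through, these are the updating rules that the transformed parameters block below implements (transcribed to match the code, with N observations per category and group, sample mean \bar{x}, and sum of uncentered squares \sum_i x_i x_i^\top):

\begin{aligned}
\kappa_n &= \kappa_0 + N \\
\nu_n &= \nu_0 + N \\
m_n &= \frac{\kappa_0 m_0 + N \bar{x}}{\kappa_n} \\
S_n &= S_0 + \sum_{i=1}^{N} x_i x_i^\top + \kappa_0\, m_0 m_0^\top - \kappa_n\, m_n m_n^\top
\end{aligned}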

Here is the stan code:

data {
  int M;                        // number of categories
  int L;                        // number of grouping levels (e.g. subjects)
  int K;                        // number of features

  matrix[M,L] N;                // number of observations per category (m) and group (l)
  vector[K] x_mean[M,L];        // means for each category (m) and group (l)
  cov_matrix[K] x_ss[M,L];      // sum of uncentered squares matrix for each category (m) and group (l)

  int N_test;                   // number of test trials
  vector[K] x_test[N_test];     // locations of test trials
  int y_test[N_test];           // group label of test trials
  int z_test_counts[N_test,M];  // responses for test trials

  int<lower=0, upper=1> m_0_known;
  int<lower=0, upper=1> S_0_known;
  vector[m_0_known ? K : 0] m_0_data[m_0_known ? M : 0];        // optional: user provided m_0 (prior mean of means)
  cov_matrix[S_0_known ? K : 0] S_0_data[S_0_known ? M : 0];    // optional: user provided S_0 (prior scatter matrix of mean)

  real<lower=0> tau_scale;      // scale of cauchy prior for variances of m_0 (set to zero to ignore)
  real<lower=0> L_omega_scale;  // scale of LKJ prior for correlation of variance of m_0 (set to zero to ignore)
}

transformed data {
  real sigma_kappanu;

  /* Scale for the prior of kappa_0/nu_0. In order to deal with input that does not contain
     observations (in which case max(N) == 0), we set the minimum value for the SD to 10. */
  sigma_kappanu = max(N) > 0 ? max(N) * 4 : 10;
}

parameters {
  // these are all shared across groups (same prior beliefs):
  real<lower=K> kappa_0;                  // prior pseudocount for category mu
  real<lower=K + 1> nu_0;                 // prior pseudocount for category Sigma

  vector[K] m_0_param[m_0_known ? 0 : M]; // prior mean of means
  vector<lower=0>[K] m_0_tau;             // prior variances of m_0 
  cholesky_factor_corr[K] m_0_L_omega;    // prior correlations of variances of m_0 (in cholesky form) 

  vector<lower=0>[K] tau_0_param[S_0_known ? 0 : M];          // standard deviations of prior scatter matrix S_0
  cholesky_factor_corr[K] L_omega_0_param[S_0_known ? 0 : M]; // correlation matrix of prior scatter matrix S_0 (in cholesky form)

  real<lower=0, upper=1> lapse_rate;
}

transformed parameters {
  vector[K] m_0[M];                    // prior mean of means m_0
  cov_matrix[K] S_0[M];                // prior scatter matrix S_0

  // updated beliefs depend on input and group
  real<lower=K> kappa_n[M,L];          // updated mean pseudocount
  real<lower=K> nu_n[M,L];             // updated covariance pseudocount
  vector[K] m_n[M,L];                  // updated expected mean
  cov_matrix[K] S_n[M,L];              // updated expected scatter matrix
  cov_matrix[K] t_scale[M,L];          // scale matrix of predictive t distribution

  simplex[M] p_test_conj[N_test];
  vector[M] log_p_test_conj[N_test];

  if (m_0_known) {
    m_0 = m_0_data;
  } else {
    m_0 = m_0_param;
  }
  if (S_0_known) {
    S_0 = S_0_data;
  }

  // update NIW parameters according to the conjugate updating rules from
  // Murphy (2007, p. 136)
  for (cat in 1:M) {
    if (!S_0_known) {
      // Get S_0 from its components: correlation matrix and vector of standard deviations
      S_0[cat] = quad_form_diag(multiply_lower_tri_self_transpose(L_omega_0_param[cat]), tau_0_param[cat]);
    }
    for (group in 1:L) {
      if (N[cat,group] > 0 ) {
        kappa_n[cat,group] = kappa_0 + N[cat,group];
        nu_n[cat,group] = nu_0 + N[cat,group];
        m_n[cat,group] = (kappa_0 * m_0[cat] + N[cat,group] * x_mean[cat,group]) /
                        kappa_n[cat,group];
        S_n[cat,group] = S_0[cat] +
                        x_ss[cat,group] +
                        kappa_0 * m_0[cat] * m_0[cat]' -
                        kappa_n[cat,group] * m_n[cat,group] * m_n[cat,group]';
      } else {
        kappa_n[cat,group] = kappa_0;
        nu_n[cat,group] = nu_0;
        m_n[cat,group] = m_0[cat];
        S_n[cat,group] = S_0[cat];
      }

      t_scale[cat,group] = S_n[cat,group] * (kappa_n[cat,group] + 1) /
                                              (kappa_n[cat,group] * (nu_n[cat,group] - K + 1));
    }
  }

  // compute category probabilities for each of the test stimuli
  for (j in 1:N_test) {
    int group;
    group = y_test[j];
    // calculate un-normalized log prob for each category
    for (cat in 1:M) {
      log_p_test_conj[j,cat] = multi_student_t_lpdf(x_test[j] |
                                              nu_n[cat,group] - K + 1,
                                              m_n[cat,group],
                                              t_scale[cat,group]);
    }
    // normalize and store actual probs in simplex
    p_test_conj[j] = exp(log_p_test_conj[j] - log_sum_exp(log_p_test_conj[j]));
  }
}

model {
  vector[M] lapsing_probs;

  lapsing_probs = rep_vector(lapse_rate / M, M);

  kappa_0 ~ normal(0, sigma_kappanu);
  nu_0 ~ normal(0, sigma_kappanu);

  /* Specifying prior for m_0:
     - If no scale for variances (tau) of m_0 is user-specified use weakly regularizing
       scale (5) for variances of mean.
     - If no scale for LKJ prior over correlation matrix of m_0 is user-specified use
       scale 1 to set uniform prior over correlation matrices. */
  if (!m_0_known) {
    m_0_tau ~ cauchy(0, tau_scale > 0 ? tau_scale : 5);
    m_0_L_omega ~ lkj_corr_cholesky(L_omega_scale > 0 ? L_omega_scale : 1);
    m_0_param ~ multi_normal_cholesky(rep_vector(0, K), diag_pre_multiply(m_0_tau, m_0_L_omega));
  }

  /* Specifying prior for components of S_0: */
  if (!S_0_known) {
      for (cat in 1:M) {
        tau_0_param[cat] ~ cauchy(0, tau_scale > 0 ? tau_scale : 10);
        L_omega_0_param[cat] ~ lkj_corr_cholesky(L_omega_scale > 0 ? L_omega_scale : 1);
      }
  }

  for (i in 1:N_test) {
    z_test_counts[i] ~ multinomial(p_test_conj[i] * (1-lapse_rate) + lapsing_probs);
  }

}

generated quantities {
  if (!m_0_known) {
    matrix[K,K] m_0_cor;
    matrix[K,K] m_0_cov;

    m_0_cor = multiply_lower_tri_self_transpose(m_0_L_omega);
    m_0_cov = quad_form_diag(m_0_cor, m_0_tau);
  }
}
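To convince myself the conjugate updating in the transformed parameters block is right, here is a small NumPy sketch (independent of Stan; the prior values are arbitrary placeholders) that checks the sum-of-uncentered-squares form used in the code against the equivalent centered-scatter form of the same rule in Murphy (2012):

```python
import numpy as np

rng = np.random.default_rng(0)
K, N = 2, 50
kappa_0, nu_0 = 4.0, 4.0
m_0 = np.array([-1.08, -0.07])
S_0 = np.eye(K)                      # placeholder prior scatter matrix
x = rng.normal(size=(N, K))          # simulated observations for one category/group

x_mean = x.mean(axis=0)
x_ss = x.T @ x                       # sum of uncentered squares, as in the Stan data block

# Update rules as implemented in the transformed parameters block
kappa_n = kappa_0 + N
nu_n = nu_0 + N
m_n = (kappa_0 * m_0 + N * x_mean) / kappa_n
S_n = S_0 + x_ss + kappa_0 * np.outer(m_0, m_0) - kappa_n * np.outer(m_n, m_n)

# Equivalent textbook form using the centered scatter matrix
C = (x - x_mean).T @ (x - x_mean)
S_n_alt = S_0 + C + (kappa_0 * N / kappa_n) * np.outer(x_mean - m_0, x_mean - m_0)

assert np.allclose(S_n, S_n_alt)     # both forms agree
```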

And here is an example input (with known ground truth; these data were generated in a way that meets all the assumptions of the model):

example-input.RData (5.3 KB)

If I run this model, it seems to converge (with large uncertainty about the parameters, which makes sense for this input). My question is about what happens when I also provide the model with m_0 (as in this alternative example input: example-input-with-m0.RData). For this example, I changed m_0_known to 1 and set m_0_data to the correct means (the prior means the data were generated from):

        [,1]     [,2]
[1,] -1.0817 -0.06806
[2,]  0.0769  0.05767

Then the model fit results in lots of divergent transitions. The output still makes sense (it approximates the ground truth the data were generated from; e.g., the data were generated with kappa = nu = 4), but there are very few effective samples:

Inference for Stan model: mvg_conj_sufficient_stats_lapse.
4 chains, each with iter=4000; warmup=2000; thin=1; 
post-warmup draws per chain=2000, total post-warmup draws=8000.

                              mean se_mean     sd        2.5%         25%         50%         75%       97.5% n_eff Rhat
kappa_0                  5.120e+00    0.09   0.74   3.870e+00   4.600e+00   5.030e+00   5.600e+00   6.710e+00    66 1.04
nu_0                     3.232e+01    3.95  27.41   3.110e+00   1.301e+01   2.536e+01   4.345e+01   1.043e+02    48 1.07
m_0_tau[1]              9.141e+307     NaN    Inf  4.077e+306  4.374e+307  9.291e+307  1.374e+308  1.763e+308   NaN  NaN
m_0_tau[2]              8.914e+307     NaN    Inf  3.342e+306  4.641e+307  8.902e+307  1.324e+308  1.755e+308   NaN  NaN
m_0_L_omega[1,1]         1.000e+00     NaN   0.00   1.000e+00   1.000e+00   1.000e+00   1.000e+00   1.000e+00   NaN  NaN
m_0_L_omega[1,2]         0.000e+00     NaN   0.00   0.000e+00   0.000e+00   0.000e+00   0.000e+00   0.000e+00   NaN  NaN
m_0_L_omega[2,1]         1.300e-01    0.06   0.58  -9.200e-01  -3.700e-01   2.000e-01   6.300e-01   9.900e-01    91 1.03
m_0_L_omega[2,2]         7.700e-01    0.02   0.24   1.600e-01   6.200e-01   8.600e-01   9.600e-01   1.000e+00    99 1.02
tau_0_param[1,1]         5.610e+00    0.17   0.86   4.710e+00   5.030e+00   5.430e+00   5.910e+00   8.130e+00    24 1.15
tau_0_param[1,2]         9.400e+00    2.57  10.79   9.800e-01   3.410e+00   6.210e+00   1.036e+01   4.795e+01    18 1.33
tau_0_param[2,1]         3.310e+00    0.21   1.18   1.570e+00   2.490e+00   3.180e+00   3.870e+00   6.240e+00    32 1.11
tau_0_param[2,2]         9.850e+00    2.50  10.52   2.570e+00   4.340e+00   6.180e+00   1.035e+01   4.702e+01    18 1.33
L_omega_0_param[1,1,1]   1.000e+00     NaN   0.00   1.000e+00   1.000e+00   1.000e+00   1.000e+00   1.000e+00   NaN  NaN
L_omega_0_param[1,1,2]   0.000e+00     NaN   0.00   0.000e+00   0.000e+00   0.000e+00   0.000e+00   0.000e+00   NaN  NaN
L_omega_0_param[1,2,1]  -4.200e-01    0.02   0.16  -8.200e-01  -5.000e-01  -3.900e-01  -3.100e-01  -2.000e-01    61 1.05
L_omega_0_param[1,2,2]   8.900e-01    0.01   0.10   5.700e-01   8.700e-01   9.200e-01   9.500e-01   9.800e-01    66 1.05
L_omega_0_param[2,1,1]   1.000e+00     NaN   0.00   1.000e+00   1.000e+00   1.000e+00   1.000e+00   1.000e+00   NaN  NaN
L_omega_0_param[2,1,2]   0.000e+00     NaN   0.00   0.000e+00   0.000e+00   0.000e+00   0.000e+00   0.000e+00   NaN  NaN
L_omega_0_param[2,2,1]   3.800e-01    0.09   0.46  -6.700e-01   7.000e-02   4.800e-01   7.800e-01   9.700e-01    24 1.20
L_omega_0_param[2,2,2]   7.700e-01    0.04   0.22   2.500e-01   6.200e-01   8.300e-01   9.700e-01   1.000e+00    40 1.04
lapse_rate               0.000e+00    0.00   0.00   0.000e+00   0.000e+00   0.000e+00   0.000e+00   0.000e+00   300 1.00
m_0[1,1]                -1.080e+00    0.00   0.00  -1.080e+00  -1.080e+00  -1.080e+00  -1.080e+00  -1.080e+00     2 1.00
m_0[1,2]                -7.000e-02    0.00   0.00  -7.000e-02  -7.000e-02  -7.000e-02  -7.000e-02  -7.000e-02     2 1.00
m_0[2,1]                 8.000e-02    0.00   0.00   8.000e-02   8.000e-02   8.000e-02   8.000e-02   8.000e-02     2 1.00
m_0[2,2]                 6.000e-02    0.00   0.00   6.000e-02   6.000e-02   6.000e-02   6.000e-02   6.000e-02     2 1.00
S_0[1,1,1]               3.226e+01    2.30  11.55   2.220e+01   2.527e+01   2.944e+01   3.498e+01   6.616e+01    25 1.15
S_0[1,1,2]              -2.924e+01   14.32  65.90  -2.576e+02  -1.875e+01  -1.070e+01  -7.020e+00  -3.040e+00    21 1.25
S_0[1,2,1]              -2.924e+01   14.32  65.90  -2.576e+02  -1.875e+01  -1.070e+01  -7.020e+00  -3.040e+00    21 1.25
S_0[1,2,2]               2.048e+02  126.85 585.89   9.600e-01   1.165e+01   3.860e+01   1.074e+02   2.299e+03    21 1.24
S_0[2,1,1]               1.232e+01    1.65  10.08   2.480e+00   6.190e+00   1.010e+01   1.498e+01   3.897e+01    37 1.09
S_0[2,1,2]              -6.370e+00   11.06  52.79  -1.711e+02   2.140e+00   7.560e+00   9.560e+00   1.310e+01    23 1.22
S_0[2,2,1]              -6.370e+00   11.06  52.79  -1.711e+02   2.140e+00   7.560e+00   9.560e+00   1.310e+01    23 1.22
S_0[2,2,2]               2.077e+02  127.96 597.28   6.590e+00   1.885e+01   3.823e+01   1.070e+02   2.211e+03    22 1.23
[...]

I’m probably either doing something stupid in my code or missing something fundamental about how the sampling works. Any help would be much appreciated!


Much better than too short! It’s very difficult to debug models without a detailed description.

That sounds unusual indeed. I would expect additional information to help narrow the posterior into a well-behaved region. That’s not impossible, though: the narrower posterior could have a more difficult shape.
That’s not the case here, however; looking through your code, I see a serious bug.
Let’s see:

First, I notice that m_0_param is absent when m_0_known is true (sounds right), but m_0_tau is always present? That’s a bit unexpected… what role does m_0_tau play in the model?

And here we see that the whole prior specification is inside the if: m_0_tau gets a prior only if m_0_known is false. When m_0_known is true, the posterior for m_0_tau is just the default improper flat prior. That’s not good. Without an upper bound, the flat distribution basically says that for any real number x, m_0_tau is likely to be larger than x.

And here’s the result: the sampler produces extremely large numbers.
The largest number with a floating-point representation is around 2\times10^{308}. Every time the sampler tries to exceed that limit you get a divergent transition.
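You can see this limit directly in the m_0_tau rows of the output above, which hover around 10^{307}–10^{308}. A quick check of the double-precision ceiling (a small Python sketch, just to illustrate the overflow behavior):

```python
import numpy as np

fmax = np.finfo(np.float64).max      # largest representable double
print(fmax)                          # 1.7976931348623157e+308

# With a flat prior, nothing stops the sampler from drifting toward this
# ceiling; once a proposal overflows, intermediate quantities become inf
# and the transition is flagged as divergent.
overflowed = fmax * 1.1
print(np.isinf(overflowed))          # True
```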


Dear Niko, apologies for just seeing this now! The notification went to spam. Thank you for taking the time to go through this!

[EDIT: I had originally switched 0 and K in the newly added code]

Your point makes perfect sense. Both m_0_tau and m_0_L_omega, which jointly describe the priors about the distribution of any of the m_0s through m_0_param ~ multi_normal_cholesky(rep_vector(0, K), diag_pre_multiply(m_0_tau, m_0_L_omega)), should only be defined if m_0_known = 0. I have changed those parts of the code to:

  vector[K] m_0_param[m_0_known ? 0 : M];                 // prior mean of means
  vector<lower=0>[m_0_known ? 0 : K] m_0_tau;             // prior variances of m_0 
  cholesky_factor_corr[m_0_known ? 0 : K] m_0_L_omega;    // prior correlations of variances of m_0 (in cholesky form)

I’m now running a few tests with the new code, and will post an update here once that is done. Thank you!

Works like a charm now. Here’s an example run for m_0_known = 1:

                           mean se_mean     sd     2.5%      25%      50%      75%    97.5% n_eff Rhat
kappa_0                    5.12    0.02   0.71     3.84     4.60     5.09     5.60     6.60  1185 1.00
nu_0                      31.15    0.75  21.67     4.16    14.62    26.96    42.76    82.58   828 1.00
tau_0_param[1,1]           5.53    0.03   0.68     4.75     5.10     5.41     5.82     6.89   656 1.01
tau_0_param[1,2]           7.74    0.40   8.62     0.99     2.89     5.10     9.29    29.84   466 1.00
tau_0_param[2,1]           3.25    0.03   0.90     1.84     2.60     3.17     3.78     5.14   807 1.01
tau_0_param[2,2]           8.25    0.38   8.26     2.81     4.03     5.37     9.22    29.33   480 1.00
L_omega_0_param[1,1,1]     1.00     NaN   0.00     1.00     1.00     1.00     1.00     1.00   NaN  NaN
L_omega_0_param[1,1,2]     0.00     NaN   0.00     0.00     0.00     0.00     0.00     0.00   NaN  NaN
L_omega_0_param[1,2,1]    -0.43    0.01   0.16    -0.89    -0.49    -0.38    -0.32    -0.21   531 1.01
L_omega_0_param[1,2,2]     0.88    0.01   0.12     0.47     0.87     0.92     0.95     0.98   403 1.01
L_omega_0_param[2,1,1]     1.00     NaN   0.00     1.00     1.00     1.00     1.00     1.00   NaN  NaN
L_omega_0_param[2,1,2]     0.00     NaN   0.00     0.00     0.00     0.00     0.00     0.00   NaN  NaN
L_omega_0_param[2,2,1]     0.44    0.02   0.43    -0.55     0.16     0.54     0.78     0.98   486 1.01
L_omega_0_param[2,2,2]     0.76    0.01   0.23     0.22     0.62     0.82     0.95     1.00   728 1.00
lapse_rate                 0.00    0.00   0.00     0.00     0.00     0.00     0.00     0.00  2235 1.00
m_0[1,1]                  -1.08    0.00   0.00    -1.08    -1.08    -1.08    -1.08    -1.08     2 1.00
m_0[1,2]                  -0.07    0.00   0.00    -0.07    -0.07    -0.07    -0.07    -0.07     2 1.00
m_0[2,1]                   0.08    0.00   0.00     0.08     0.08     0.08     0.08     0.08     2 1.00
m_0[2,2]                   0.06    0.00   0.00     0.06     0.06     0.06     0.06     0.06     2 1.00
S_0[1,1,1]                31.07    0.39  10.05    22.60    25.97    29.27    33.87    47.47   662 1.01
S_0[1,1,2]               -21.51    2.79  68.61  -112.88   -16.24    -9.20    -6.50    -3.32   603 1.00
S_0[1,2,1]               -21.51    2.79  68.61  -112.88   -16.24    -9.20    -6.50    -3.32   603 1.00
S_0[1,2,2]               134.22   23.43 581.02     0.99     8.33    26.00    86.34   890.39   615 1.00
S_0[2,1,1]                11.39    0.25   6.95     3.38     6.74    10.08    14.31    26.45   799 1.01
S_0[2,1,2]                 0.27    1.81  46.02   -57.38     4.07     8.17     9.93    12.78   644 1.00
S_0[2,2,1]                 0.27    1.81  46.02   -57.38     4.07     8.17     9.93    12.78   644 1.00
S_0[2,2,2]               136.28   24.17 599.41     7.91    16.21    28.89    85.01   860.01   615 1.00

Thank you, @nhuurre.