Chains differ in categorical model

This is an attempt to improve my model after suggestions from @Bob_Carpenter, but I hope this post is self-contained.

I’ve made a series of simple models and I’m generating synthetic data in Stan.

Multilogit

I took the multi-logit example from the documentation and wrote a matching generator. Based on the results, the prior seems sufficient to constrain the estimates, so I don’t need the K-1 trick for identifiability.

Generator:

data {
  int<lower=2> K;
  int<lower=0> N;
  int<lower=1> D;
}
generated quantities {
  matrix[N, D] x;
  for (n in 1:N) {
      for (d in 1:D) {
          x[n,d] = normal_rng(0,1);
      }
  }
  matrix[D, K] beta;
  for (d in 1:D) {
      for (k in 1:K) {
          beta[d,k] = normal_rng(0,5);
      }
  }
  matrix[N, K] x_beta = x * beta;
  array[N] int<lower=1,upper=K> y;
  for (n in 1:N) y[n] = categorical_rng(softmax(x_beta[n]'));
}

Inference:

// https://mc-stan.org/docs/2_28/stan-users-guide/multi-logit.html
data {
  int<lower=2> K;
  int<lower=0> N;
  int<lower=1> D;
  array[N] int<lower=1,upper=K> y;
  matrix[N, D] x;
}
parameters {
  matrix[D, K] beta;
}
model {
  matrix[N, K] x_beta = x * beta;

  to_vector(beta) ~ normal(0, 5);

  for (n in 1:N) {
    y[n] ~ categorical(softmax(x_beta[n]'));
  }
}

So it seems the prior works well in this case.
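The identifiability issue that the K-1 trick addresses comes from the softmax itself: adding the same constant to every logit leaves the probabilities unchanged, so without a prior any common shift of the coefficient columns fits equally well. A minimal numpy check of this invariance (the values here are arbitrary, not from the model):

```python
import numpy as np

def softmax(v):
    e = np.exp(v - v.max())  # subtract max for numerical stability
    return e / e.sum()

v = np.array([1.0, -2.0, 0.5])  # arbitrary logits
c = 3.7                          # any constant shift
print(np.allclose(softmax(v), softmax(v + c)))  # True
```

The normal(0, 5) prior breaks this tie by pulling the common shift toward zero, which is why the full-K parameterization still samples fine here.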

Adding another variable

I now add another variable, u, which is unobserved, so I expect the problem to be much more difficult.

data {
  int<lower=2> K;
  int<lower=0> N;
  int<lower=1> D;
}
generated quantities {
  matrix[N, D] x;
  matrix[N, D] u;
  for (n in 1:N) {
      for (d in 1:D) {
          x[n,d] = normal_rng(0,1);
          u[n,d] = normal_rng(0,1);
      }
  }
  matrix[D, K] beta;
  matrix[D, K] gamma;
  for (d in 1:D) {
      for (k in 1:K) {
          beta[d,k] = normal_rng(0,5);
          gamma[d,k] = normal_rng(0,5);
      }
  }
  matrix[N, K] x_beta = x * beta;
  matrix[N, K] u_gamma = u * gamma;
  array[N] int<lower=1,upper=K> y;
  for (n in 1:N) y[n] = categorical_rng(softmax(x_beta[n]' + u_gamma[n]'));
}

This produces 2×3 matrices for beta and gamma; the beta=>gamma pairs might look like:

 3.06798=>2.66563   3.33377=>-1.79673   1.62864=>-1.48669
 3.41788=>0.824499   -8.223=>3.01853   -4.39284=>2.05907

I tried sampling with and without the K-1 trick described in multi-logit identifiability. In the graphs below, the colors identify the chains; I added jitter to make the draws more visible. The black dot is the true value of each parameter generated by Stan.

With K-1 trick:

data {
  int<lower=2> K;
  int<lower=0> N;
  int<lower=1> D;
  array[N] int<lower=1,upper=K> y;
  matrix[N, D] x;
}
transformed data {
  vector[D] zeros = rep_vector(0, D);
}
parameters {
  matrix[D, K-1] beta_raw;
  matrix[D, K-1] gamma_raw;
  matrix[N, D] u;
}
transformed parameters {
  matrix[D, K] beta = append_col(beta_raw, zeros);
  matrix[D, K] gamma = append_col(gamma_raw, zeros);
}
model {
  matrix[N, K] x_beta = x * beta;
  matrix[N, K] u_gamma = u * gamma;

  to_vector(u) ~ normal(0,1);
  to_vector(beta) ~ normal(0, 5);
  to_vector(gamma) ~ normal(0, 5);

  for (n in 1:N) {
    y[n] ~ categorical(softmax(x_beta[n]' + u_gamma[n]'));
  }
}

Without K-1 trick:

data {
  int<lower=2> K;
  int<lower=0> N;
  int<lower=1> D;
  array[N] int<lower=1,upper=K> y;
  matrix[N, D] x;
}
parameters {
  matrix[D, K] beta;
  matrix[D, K] gamma;
  matrix[N, D] u;
}
model {
  matrix[N, K] x_beta = x * beta;
  matrix[N, K] u_gamma = u * gamma;

  to_vector(u) ~ normal(0,1);
  to_vector(beta) ~ normal(0, 5);
  to_vector(gamma) ~ normal(0, 5);

  for (n in 1:N) {
    y[n] ~ categorical(softmax(x_beta[n]' + u_gamma[n]'));
  }
}

In the model without the K-1 trick, the posterior distributions look reasonable, but it is visually apparent that one of the chains (#1) differs from the others, and the per-chain means of gamma confirm the difference:

2×3×4 Array{Float64, 3}:
[:, :, 1] =
 -2.59055  1.61997   0.706378
  3.61502  1.82376  -5.2643

[:, :, 2] =
  5.16026    0.0104177  -5.63252
 -0.674936  -3.08028     3.58979

[:, :, 3] =
  5.16026    0.0104177  -5.63252
 -0.674936  -3.08028     3.58979

[:, :, 4] =
  5.16026    0.0104177  -5.63252
 -0.674936  -3.08028     3.58979
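Chain means this far apart should show up in R_hat, which compares between-chain to within-chain variance. A minimal sketch of the classic (non-rank-normalized) R_hat formula, run on made-up draws that mimic one shifted chain (the numbers below are illustrative, not the actual samples):

```python
import numpy as np

rng = np.random.default_rng(0)
# 4 chains x 1000 draws; chain 0 is centered elsewhere, mimicking the stuck chain
chains = np.stack(
    [rng.normal(-2.6 if c == 0 else 5.2, 1.0, 1000) for c in range(4)]
)

def rhat(chains):
    m, n = chains.shape
    B = n * chains.mean(axis=1).var(ddof=1)  # between-chain variance
    W = chains.var(axis=1, ddof=1).mean()    # within-chain variance
    var_hat = (n - 1) / n * W + B / n        # pooled variance estimate
    return np.sqrt(var_hat / W)

print(round(float(rhat(chains)), 2))  # well above the usual 1.01 threshold
```

This matches the table below: the gamma parameters, whose chains disagree, are exactly the rows with R_hat up to 1.5.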

From what I remember, it’s bad for the chains to be this different, and the R_hat diagnostic is meant to flag exactly this. stansummary prints no explicit warning, although on closer inspection the gamma rows do show elevated R_hat and very small N_Eff:

% stansummary /tmp/jl_A5VVFT/train*.csv
Inference for Stan model: train_model
4 chains: each with iter=(1000,1000,1000,1000); warmup=(0,0,0,0); thin=(1,1,1,1); 4000 iterations saved.

Warmup took (72, 72, 72, 72) seconds, 4.8 minutes total
Sampling took (74, 74, 74, 74) seconds, 4.9 minutes total

                    Mean     MCSE   StdDev     5%       50%    95%    N_Eff  N_Eff/s    R_hat

lp__            -1.2e+03      5.0       64  -1310  -1.2e+03  -1108      162  5.5e-01      1.0
accept_stat__       0.90  3.8e-03  1.1e-01   0.68      0.93   1.00  8.2e+02  2.8e+00  1.0e+00
stepsize__         0.027  1.7e-03  2.4e-03  0.026     0.026  0.032  2.0e+00  6.8e-03  2.3e+13
treedepth__          7.0      nan  1.2e-13    7.0       7.0    7.0      nan      nan      nan
n_leapfrog__         127      nan  1.8e-12    127       127    127      nan      nan      nan
divergent__         0.00      nan  0.0e+00   0.00      0.00   0.00      nan      nan      nan
energy__            2199  5.3e+00  7.1e+01   2100      2192   2323  1.8e+02  6.0e-01  1.0e+00

beta[1,1]        3.9e+00    0.084      3.1   -1.1   3.8e+00    9.0     1365  4.6e+00      1.0
beta[1,2]        2.7e+00    0.077      3.0   -2.1   2.7e+00    7.8     1563  5.3e+00      1.0
beta[1,3]       -6.5e+00     0.11      3.4    -12  -6.5e+00  -0.88      920  3.1e+00      1.0
beta[2,1]        2.7e+00    0.091      3.1   -2.4   2.5e+00    7.8     1170  4.0e+00      1.0
beta[2,2]       -2.3e+00    0.092      3.0   -7.4  -2.3e+00    2.6     1059  3.6e+00     1.00
beta[2,3]       -1.1e+00    0.088      3.0   -6.1  -1.1e+00    3.8     1174  4.0e+00     1.00
gamma[1,1]       3.2e+00      2.5      4.9   -5.6   3.8e+00     10      3.8  1.3e-02      1.5
gamma[1,2]       4.1e-01     0.36      3.4   -5.3   4.5e-01    6.0       88  3.0e-01      1.0
gamma[1,3]      -4.0e+00      2.0      4.7    -11  -4.4e+00    4.8      5.3  1.8e-02      1.4
gamma[2,1]       4.0e-01      1.4      4.4   -6.6   1.8e-01    8.0       10  3.5e-02      1.2
gamma[2,2]      -1.9e+00      1.6      3.8   -7.9  -1.9e+00    4.6      5.7  1.9e-02      1.2
gamma[2,3]       1.4e+00      2.9      5.8   -8.9   2.0e+00     10      4.1  1.4e-02      1.4
// u[...] not shown

Questions

  • Is there a benefit to using the K-1 technique if I’m already using a prior? Based on these results it doesn’t seem to help here.
  • Is it bad for the chains to be different visually and in their means? Why are the chains so different? How can I fix it?

Are both beta and gamma truly centered over the same value? In the simulation both are centered at 0. There’s an identifiability issue: the model could just as easily find a single combined effect with mean 0 and an sd sqrt(2) times larger, which is what you get by adding the two random variables together. If you simulate with distinct means that are far enough apart in sd terms, what does the inference look like?
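To illustrate the sqrt(2) point: the sum of two independent standard normals is itself normal with sd sqrt(2), so data that informs only the sum cannot separate the two pieces. A quick simulation (standard normals for illustration; the generator above actually uses sd-5 coefficients):

```python
import numpy as np

rng = np.random.default_rng(1)
beta = rng.normal(0, 1, 100_000)   # two independent standard normals
gamma = rng.normal(0, 1, 100_000)
# the sum is indistinguishable from a single normal(0, sqrt(2)) effect
print(np.std(beta + gamma))  # ≈ 1.414
```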

Here are the results of trying that, but it doesn’t seem to fix the issue.

If I change the parameter generation to

          beta[d,k] = normal_rng(0,1);
          gamma[d,k] = normal_rng(3,1);

then I get

> β.=>γ
2×3 Matrix{Pair{Float64, Float64}}:
   -0.118=>3.58811  -0.922827=>3.27653  -0.0236587=>2.09324
 -1.68189=>4.20688    -1.0528=>3.3796     -0.91014=>2.07714

and change the priors to

  to_vector(beta) ~ normal(0, 2);
  to_vector(gamma) ~ normal(0, 2);

R_hat is high for some variables here:

Inference for Stan model: train_model
4 chains: each with iter=(1000,1000,1000,1000); warmup=(0,0,0,0); thin=(1,1,1,1); 4000 iterations saved.

Warmup took (44, 43, 46, 46) seconds, 3.0 minutes total
Sampling took (37, 37, 70, 71) seconds, 3.6 minutes total

                    Mean     MCSE   StdDev     5%       50%    95%    N_Eff  N_Eff/s    R_hat

lp__            -1.4e+03       12      129  -1692  -1.4e+03  -1270      113     0.52      1.0
accept_stat__       0.89  5.8e-03  1.2e-01   0.64      0.93   1.00  4.2e+02  1.9e+00  1.0e+00
stepsize__         0.050  5.4e-03  7.6e-03  0.043     0.058  0.058  2.0e+00  9.3e-03  5.2e+13
treedepth__          6.4  1.1e-02  5.0e-01    6.0       6.0    7.0  2.0e+03  9.3e+00  2.3e+00
n_leapfrog__          92  7.1e-01  3.2e+01     63        63    127  2.0e+03  9.3e+00  2.7e+00
divergent__         0.00      nan  0.0e+00   0.00      0.00   0.00      nan      nan      nan
energy__            2455  1.2e+01  1.3e+02   2264      2437   2702  1.2e+02  5.5e-01  1.0e+00

beta[1,1]        6.1e-01    0.027      1.2   -1.3   6.0e-01    2.5     1869      8.7     1.00
beta[1,2]       -1.3e+00    0.034      1.2   -3.3  -1.3e+00   0.65     1259      5.8      1.0
beta[1,3]        6.8e-01    0.027      1.2   -1.3   6.8e-01    2.6     1808      8.4     1.00
beta[2,1]       -7.9e-01    0.029      1.2   -2.7  -8.2e-01    1.2     1585      7.4      1.0
beta[2,2]        1.9e-01    0.025      1.1   -1.7   1.5e-01    2.1     1993      9.2      1.0
beta[2,3]        6.1e-01    0.027      1.2   -1.3   5.8e-01    2.6     1786      8.3      1.0
gamma[1,1]       1.0e+00      1.0      2.2   -2.8   1.2e+00    4.3      4.3    0.020      1.5
gamma[1,2]       7.8e-01     0.34      1.4   -1.5   8.0e-01    3.1       17    0.081      1.1
gamma[1,3]      -1.9e+00     0.73      2.0   -4.9  -2.0e+00    1.6      7.4    0.034      1.3
gamma[2,1]       1.5e+00     0.47      1.8   -1.4   1.5e+00    4.3       14    0.064      1.1
gamma[2,2]      -6.7e-01     0.62      1.6   -3.4  -7.1e-01    2.1      7.0    0.032      1.2
gamma[2,3]      -7.1e-01      1.2      2.3   -4.3  -7.7e-01    3.0      3.7    0.017      1.5

I’m not sure if that’s what you meant.

The output I want is a posterior that expresses uncertainty about whether the betas or the gammas are big, with the two negatively correlated. How can I write a better model for this?

I’m not sure this model is going to be identified in a nice way. I think you’ll be resigned to diffuse estimates and/or to adding external information, either other data or stronger priors. There are many ways the parameters can explain this data. For example, with a product of two unknowns, increasing one can be compensated by decreasing the other, but which one should move? Either can, and both are consistent with this model! Take a deeper look at the parameters: there may be strong correlations and some multimodality, indicating different modes that explain the data equally well.
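The product trade-off can be seen directly: only u * gamma enters the likelihood, so rescaling u up and gamma down by the same factor leaves every category probability unchanged, and only the normal(0, 1) prior on u weakly pins the scale. A small numpy check (arbitrary shapes and seed, just to show the invariance):

```python
import numpy as np

rng = np.random.default_rng(2)
u = rng.normal(0, 1, (5, 2))      # N x D latent covariates
gamma = rng.normal(0, 5, (2, 3))  # D x K coefficients
c = 2.0
# (c*u) @ (gamma/c) == u @ gamma: the likelihood cannot tell these apart,
# so the data alone never resolves the scale split between u and gamma
print(np.allclose((c * u) @ (gamma / c), u @ gamma))  # True
```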