I am fitting a multivariate Gaussian mixture model, but my Rhat values are all over the place. My data-generation code is:
# Draw one point uniformly from the simplex of length T, i.e. a
# Dirichlet(1, ..., 1) sample: T i.i.d. Exp(1) variates (generated as
# -log(U) with U ~ Uniform(0, 1)) divided by their total.
runif_simplex <- function(T) {
  u <- runif(T)
  gaps <- -log(u)
  total <- sum(gaps)
  gaps / total
}
# Simulate data from a Gaussian latent class (mixture) model with unit,
# uncorrelated within-class noise.
#
# NC      number of latent classes.
# NY      number of observed variables per observation.
# N       number of observations (0 is allowed and yields empty results).
# spacing distance between consecutive class means on every variable
#         (defaults to 10, the previously hard-coded value).
#
# Returns a list with
#   y     N x NY numeric matrix of observations,
#   z     length-N integer vector of true class labels,
#   theta list(pi1 = mixing proportions, mu = NC x NY matrix of class means).
MV_LCA <- function(NC, NY, N, spacing = 10) {
  pi1 <- runif_simplex(NC)
  z <- sample(x = seq_len(NC), size = N, replace = TRUE, prob = pi1)

  # Class i has mean (i - 1) * spacing on every variable; matrix() recycles
  # the length-NC vector down each of the NY columns.
  mu <- matrix((seq_len(NC) - 1) * spacing, nrow = NC, ncol = NY)

  # Identity covariance means independent N(0, 1) noise per coordinate, so a
  # single vectorised rnorm() replaces the per-row mvtnorm::rmvnorm() loop.
  # drop = FALSE keeps mu[z, ] a matrix when N or NY is 1.
  y <- mu[z, , drop = FALSE] + matrix(rnorm(N * NY), nrow = N, ncol = NY)

  list(y = y, z = z, theta = list(pi1 = pi1, mu = mu))
}
# Simulate 300 observations of 2 variables from a 3-class mixture.
mx_dt <- MV_LCA(NC = 3, NY = 2, N = 300)
and my Stan code, which follows the approach in "Trouble with Gaussian Mixture Model", is:
data {
  int<lower=1> N;                 // number of observations
  int<lower=1> K;                 // number of mixture components (classes)
  int<lower=1> NY;                // number of observed variables
  array[N] vector[NY] Y;          // observations, one vector per row
}
parameters {
  // Mixing logits for components 1..K-1; component K's logit is fixed at 0
  // in theta below. softmax() is unchanged by adding a constant to every
  // logit, so with K free logits one direction is pinned down only by the
  // prior.
  vector[K - 1] theta_raw;
  // Component means stored variable-major: mu[j] holds the K component means
  // of variable j, constrained increasing to remove label switching.
  array[NY] ordered[K] mu;
  // Cholesky factor of each component's *correlation* matrix.
  array[K] cholesky_factor_corr[NY] L;
  // Per-component, per-variable standard deviations.
  array[K] vector<lower=0>[NY] sigma;
  // Per-component centre of the prior on the means.
  ordered[K] p;
}
transformed parameters {
  vector[K] theta = append_row(theta_raw, 0);   // identified mixing logits
  simplex[K] theta_n = softmax(theta);          // mixing proportions
  // mu transposed to component-major, so mu_ord[k] is component k's mean.
  array[K] vector[NY] mu_ord;
  for (k in 1:K) {
    for (j in 1:NY) {
      mu_ord[k, j] = mu[j, k];
    }
  }
}
model {
  // Parameter-only quantities, computed once rather than N * K times.
  vector[K] log_theta = log_softmax(theta);
  // diag(sigma[k]) * L[k]: Cholesky factor of component k's covariance.
  array[K] matrix[NY, NY] L_Sigma;
  for (k in 1:K) {
    L_Sigma[k] = diag_pre_multiply(sigma[k], L[k]);
  }

  // Priors.
  p ~ normal(0, 10);
  theta_raw ~ normal(0, 0.5);
  for (k in 1:K) {
    mu[, k] ~ normal(p[k], 2);
    L[k] ~ lkj_corr_cholesky(5);
    sigma[k] ~ std_normal();
  }

  // Likelihood with each observation's discrete class label marginalised out.
  // NOTE(review): the ordering constraint removes label switching but does
  // not stop different chains from settling in different local modes (e.g.
  // one wide component spanning two clusters). If Rhat stays high, compare
  // per-chain estimates and consider initialising chains near a clustering
  // of Y.
  for (n in 1:N) {
    vector[K] lps = log_theta;
    for (k in 1:K) {
      lps[k] += multi_normal_cholesky_lpdf(Y[n] | mu_ord[k], L_Sigma[k]);
    }
    target += log_sum_exp(lps);
  }
}
However, the posteriors for my class means are multimodal:
mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff
mu_ord[1,1] 1.2144058 0.83088225 1.0523290 -0.4652320000 -0.09427355 1.6865650 2.0437300 2.5614545 1.604073
mu_ord[1,2] 1.2260778 0.87530376 1.1111393 -0.7784914000 -0.07988217 1.7363450 2.0955275 2.6033510 1.611460
mu_ord[2,1] 6.3999033 3.61243599 4.4329833 -0.0680316275 0.22516650 9.2400950 9.6786150 10.0627025 1.505885
mu_ord[2,2] 6.5643907 3.64208352 4.4661419 0.0008965489 0.36662125 9.5848750 9.8141450 10.0677250 1.503714
mu_ord[3,1] 10.6129343 0.21754509 0.3689531 9.9882272500 10.32810000 10.5761000 10.8790250 11.3884275 2.876363
mu_ord[3,2] 10.4423360 0.30040886 0.4382617 9.8386040000 10.08040000 10.3140500 10.7916750 11.3554900 2.128343
theta[1] 0.6841677 0.48223769 0.6912677 -0.8783362500 0.26576675 0.8893390 1.1770050 1.6401905 2.054803
theta[2] -0.4230893 0.21747393 0.5467207 -1.4303990000 -0.80814150 -0.4467285 -0.0575592 0.6647487 6.319994
theta[3] -0.2323758 0.28154095 0.5220417 -1.3042827500 -0.58801025 -0.2036370 0.1531850 0.6909437 3.438169
p[1] 1.0129331 1.17940556 1.9545832 -2.8189635000 -0.38049675 1.0939900 2.4099475 4.6850590 2.746515
p[2] 6.2369546 3.09811775 3.9242490 -0.6909558750 1.65356500 8.0895250 9.3052600 10.9257425 1.604417
p[3] 10.6907368 0.02252125 1.2824686 8.2127330000 9.77637500 10.6847000 11.5349000 13.2768075 3242.713004
theta_n[1] 0.5680305 0.15664537 0.2026572 0.1312347750 0.39535300 0.6814930 0.7120465 0.7531341 1.673744
theta_n[2] 0.1977537 0.07119998 0.1151653 0.0603734925 0.11565975 0.1628290 0.2424110 0.4821426 2.616277
theta_n[3] 0.2342157 0.08849630 0.1151140 0.0691795750 0.14296950 0.1962120 0.3683120 0.4229827 1.692023
sigma[1,1] 3.4256922 1.50596224 1.8525152 0.5825078750 0.92417625 4.5851950 4.7853125 5.1148943 1.513197
sigma[1,2] 3.5140795 1.45370440 1.7889001 0.7225185750 1.10882500 4.6314350 4.8275225 5.1692545 1.514328
sigma[2,1] 0.7824114 0.14611578 0.2813299 0.2090872750 0.57900825 0.8062010 0.9796770 1.2836407 3.707128
sigma[2,2] 1.0090473 0.01169272 0.2041508 0.5988232250 0.88697600 0.9999235 1.1324675 1.4275905 304.839400
sigma[3,1] 1.5538882 0.78269777 0.9714749 0.5572269000 0.81358025 0.9933925 2.7960625 3.1485080 1.540547
sigma[3,2] 1.5859154 0.81938585 1.0188719 0.2968618000 0.84501925 1.0123400 2.8822900 3.2524710 1.546188
Rhat
mu_ord[1,1] 3.708499
mu_ord[1,2] 3.610858
mu_ord[2,1] 16.047415
mu_ord[2,2] 20.755349
mu_ord[3,1] 1.406798
mu_ord[3,2] 1.760585
theta[1] 1.849773
theta[2] 1.153572
theta[3] 1.317029
p[1] 1.438734
p[2] 3.619240
p[3] 1.000188
theta_n[1] 3.000197
theta_n[2] 1.512294
theta_n[3] 2.820832
sigma[1,1] 10.255679
sigma[1,2] 9.844162
sigma[2,1] 1.295773
sigma[2,2] 1.022352
sigma[3,1] 5.794911
sigma[3,2] 5.487649