I’m trying to fit a Gaussian mixture model, following the approach here:
data {
  int D;              // number of dimensions
  int K;              // number of gaussians
  int N;              // number of data
  vector[D] y[N];     // data
}
parameters {
  simplex[K] theta;                // mixing proportions
  ordered[D] mu[K];                // mixture component means
  cholesky_factor_corr[D] L[K];    // cholesky factor of covariance
}
model {
  real ps[K];
  for (k in 1:K) {
    mu[k] ~ normal(0, 3);
    L[k] ~ lkj_corr_cholesky(4);
  }
  for (n in 1:N) {
    for (k in 1:K) {
      // increment log probability of the gaussian
      ps[k] = log(theta[k]) + multi_normal_cholesky_lpdf(y[n] | mu[k], L[k]);
    }
    target += log_sum_exp(ps);
  }
}
I’ve been sanity checking on fake data as follows:
library(MASS)
library(rstan)
library(readr)
library(here)
n_per_clus <- 30
n_clus <- 3
d <- 4
sigma <- 0.1 * diag(d)
mu_1 <- rep(0, d)
mu_2 <- rep(3, d)
mu_3 <- rep(6, d)
clus_1 <- mvrnorm(n_per_clus, mu_1, sigma)
clus_2 <- mvrnorm(n_per_clus, mu_2, sigma)
clus_3 <- mvrnorm(n_per_clus, mu_3, sigma)

mixture_data <- list(
  y = rbind(clus_1, clus_2, clus_3),
  N = n_per_clus * n_clus,
  D = d,
  K = n_clus
)
gmm <- stan_model(here("stan", "gaussian_mixture_model.stan"))
fit <- sampling(
  gmm,
  data = mixture_data,
  chains = 3,
  cores = 2,
  iter = 1000,
  control = list(adapt_delta = 0.98)
)
print(fit)
which results in the following:
Inference for Stan model: 5211a6688a83a26504506fd103be7703.
3 chains, each with iter=1000; warmup=500; thin=1;
post-warmup draws per chain=500, total post-warmup draws=1500.
mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
mu[1,1] -0.16 0.00 0.10 -0.37 -0.22 -0.15 -0.09 -0.01 440 1.00
mu[1,2] 0.00 0.04 0.10 -0.19 -0.06 -0.01 0.05 0.24 6 1.16
mu[1,3] 0.07 0.03 0.09 -0.08 0.01 0.05 0.11 0.28 9 1.11
mu[1,4] 0.15 0.02 0.09 0.02 0.08 0.13 0.20 0.38 16 1.08
mu[2,1] -0.57 2.99 3.96 -7.02 -3.67 -1.60 4.40 4.64 2 2.47
mu[2,2] 0.94 2.12 2.86 -3.65 -1.45 0.29 4.42 4.66 2 2.25
mu[2,3] 2.10 1.47 2.22 -2.13 0.38 1.93 4.46 4.70 2 1.66
mu[2,4] 3.56 0.59 1.85 -0.45 2.36 4.29 4.61 7.04 10 1.10
mu[3,1] 1.96 2.93 3.76 -6.54 -1.49 4.38 4.51 4.70 2 3.25
mu[3,2] 2.74 2.05 2.67 -3.32 0.43 4.40 4.53 4.73 2 2.86
mu[3,3] 3.32 1.42 1.96 -1.46 2.13 4.42 4.55 4.76 2 2.11
mu[3,4] 4.07 0.58 1.35 0.33 4.25 4.51 4.63 6.09 6 1.19
Samples were drawn using NUTS(diag_e) at Sun Nov 25 19:10:44 2018.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).
The Rhat values here are fairly concerning, and I can see label switching in the trace plots. I asked about this in the MSW Slack and was advised to impose an ordering on the means, which (I believe?) is already in place because of the ordered[D] mu[K] line in the model code.
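In case that isn’t actually what was meant: if the suggestion was instead to order the components against each other (e.g. along the first coordinate of the means) rather than ordering the D coordinates within each component, I think the model would look roughly like this sketch, which I have not run (mu_first and mu_rest are just names I made up; the model block would stay the same):

parameters {
  simplex[K] theta;                // mixing proportions
  ordered[K] mu_first;             // first coordinate of each component mean, ordered across components
  matrix[K, D - 1] mu_rest;        // remaining coordinates, unconstrained
  cholesky_factor_corr[D] L[K];    // cholesky factor, as before
}
transformed parameters {
  vector[D] mu[K];
  for (k in 1:K) {
    mu[k][1] = mu_first[k];
    for (d in 2:D)
      mu[k][d] = mu_rest[k, d - 1];
  }
}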
The other piece of advice I got was to initialize with kmeans or similar, so I tried:
gmm2 <- stan_model(here("stan", "gaussian_mixture_model2.stan"))
fit2 <- sampling(
  gmm2,
  data = mixture_data,
  chains = 3,
  cores = 2,
  iter = 1000,
  control = list(adapt_delta = 0.98),
  init = list(
    list(mu = rbind(mu_1, mu_2, mu_3)),
    list(mu = rbind(mu_1, mu_2, mu_3)),
    list(mu = rbind(mu_1, mu_2, mu_3))
  )
)
where now the means aren’t declared as ordered, because that was throwing an error about the inits being the wrong type; I’ve also just initialized with the true means rather than an actual kmeans fit (a kmeans-based version is sketched after the output below). Anyway, I still end up with:
Inference for Stan model: gaussian_mixture_model2.
3 chains, each with iter=1000; warmup=500; thin=1;
post-warmup draws per chain=500, total post-warmup draws=1500.
mean se_mean sd 2.5% 25% 50% 75%
theta[1] 0.33 0.00 0.05 0.24 0.30 0.33 0.36
theta[2] 0.34 0.00 0.05 0.24 0.30 0.34 0.37
theta[3] 0.33 0.00 0.05 0.24 0.30 0.33 0.36
mu[1,1] -0.04 0.01 0.17 -0.39 -0.16 -0.04 0.08
mu[1,2] 0.03 0.01 0.18 -0.31 -0.10 0.03 0.15
mu[1,3] 0.00 0.01 0.18 -0.35 -0.13 0.00 0.12
mu[1,4] 0.07 0.01 0.18 -0.28 -0.05 0.07 0.19
mu[2,1] 3.03 0.01 0.20 2.66 2.90 3.03 3.15
mu[2,2] 3.06 0.01 0.19 2.69 2.93 3.05 3.18
mu[2,3] 2.95 0.01 0.20 2.56 2.82 2.95 3.08
mu[2,4] 3.07 0.01 0.19 2.71 2.94 3.07 3.19
mu[3,1] 6.06 0.03 0.19 5.68 5.92 6.06 6.19
mu[3,2] 5.93 0.03 0.19 5.56 5.80 5.93 6.07
mu[3,3] 5.95 0.03 0.19 5.56 5.82 5.95 6.08
mu[3,4] 5.97 0.03 0.19 5.57 5.84 5.96 6.10
L[1,1,1] 1.00 NaN 0.00 1.00 1.00 1.00 1.00
L[1,1,2] 0.00 NaN 0.00 0.00 0.00 0.00 0.00
L[1,1,3] 0.00 NaN 0.00 0.00 0.00 0.00 0.00
L[1,1,4] 0.00 NaN 0.00 0.00 0.00 0.00 0.00
L[1,2,1] -0.01 0.48 0.81 -0.91 -0.83 -0.52 0.85
L[1,2,2] 0.57 0.00 0.12 0.39 0.49 0.55 0.63
L[1,2,3] 0.00 NaN 0.00 0.00 0.00 0.00 0.00
L[1,2,4] 0.00 NaN 0.00 0.00 0.00 0.00 0.00
L[1,3,1] 0.24 0.24 0.77 -0.91 -0.80 0.75 0.84
L[1,3,2] -0.06 0.08 0.31 -0.58 -0.30 -0.13 0.22
L[1,3,3] 0.49 0.00 0.09 0.34 0.42 0.47 0.54
L[1,3,4] 0.00 NaN 0.00 0.00 0.00 0.00 0.00
L[1,4,1] -0.50 0.22 0.62 -0.92 -0.87 -0.81 -0.63
L[1,4,2] 0.24 0.05 0.24 -0.24 0.09 0.25 0.40
L[1,4,3] -0.25 0.04 0.14 -0.55 -0.33 -0.24 -0.15
L[1,4,4] 0.40 0.00 0.07 0.28 0.35 0.39 0.44
L[2,1,1] 1.00 NaN 0.00 1.00 1.00 1.00 1.00
L[2,1,2] 0.00 NaN 0.00 0.00 0.00 0.00 0.00
L[2,1,3] 0.00 NaN 0.00 0.00 0.00 0.00 0.00
L[2,1,4] 0.00 NaN 0.00 0.00 0.00 0.00 0.00
L[2,2,1] 0.80 0.00 0.09 0.59 0.78 0.83 0.86
L[2,2,2] 0.58 0.00 0.10 0.42 0.51 0.57 0.63
L[2,2,3] 0.00 NaN 0.00 0.00 0.00 0.00 0.00
L[2,2,4] 0.00 NaN 0.00 0.00 0.00 0.00 0.00
L[2,3,1] 0.82 0.00 0.10 0.62 0.79 0.84 0.87
L[2,3,2] 0.30 0.00 0.13 0.06 0.21 0.29 0.37
L[2,3,3] 0.46 0.00 0.08 0.33 0.40 0.45 0.50
L[2,3,4] 0.00 NaN 0.00 0.00 0.00 0.00 0.00
L[2,4,1] 0.86 0.00 0.06 0.70 0.84 0.87 0.90
L[2,4,2] 0.32 0.00 0.12 0.10 0.24 0.31 0.39
L[2,4,3] 0.06 0.00 0.10 -0.14 0.00 0.06 0.13
L[2,4,4] 0.35 0.00 0.06 0.25 0.31 0.34 0.38
L[3,1,1] 1.00 NaN 0.00 1.00 1.00 1.00 1.00
L[3,1,2] 0.00 NaN 0.00 0.00 0.00 0.00 0.00
L[3,1,3] 0.00 NaN 0.00 0.00 0.00 0.00 0.00
L[3,1,4] 0.00 NaN 0.00 0.00 0.00 0.00 0.00
L[3,2,1] 0.88 0.00 0.06 0.73 0.86 0.89 0.91
L[3,2,2] 0.47 0.00 0.09 0.33 0.41 0.46 0.51
L[3,2,3] 0.00 NaN 0.00 0.00 0.00 0.00 0.00
L[3,2,4] 0.00 NaN 0.00 0.00 0.00 0.00 0.00
L[3,3,1] -0.33 0.65 0.80 -0.95 -0.91 -0.87 0.76
L[3,3,2] -0.01 0.12 0.19 -0.32 -0.15 -0.04 0.11
L[3,3,3] 0.46 0.05 0.10 0.31 0.38 0.44 0.52
L[3,3,4] 0.00 NaN 0.00 0.00 0.00 0.00 0.00
L[3,4,1] -0.25 0.60 0.78 -0.92 -0.87 -0.80 0.75
L[3,4,2] -0.12 0.11 0.20 -0.48 -0.26 -0.14 0.01
L[3,4,3] 0.21 0.08 0.16 -0.12 0.11 0.20 0.31
L[3,4,4] 0.45 0.01 0.09 0.32 0.39 0.43 0.49
lp__ -389.50 1.86 5.25 -400.88 -393.05 -389.08 -385.65
97.5% n_eff Rhat
theta[1] 0.43 2059 1.00
theta[2] 0.44 1758 1.00
theta[3] 0.43 1973 1.00
mu[1,1] 0.29 816 1.00
mu[1,2] 0.36 800 1.00
mu[1,3] 0.34 832 1.00
mu[1,4] 0.43 872 1.00
mu[2,1] 3.44 728 1.00
mu[2,2] 3.45 814 1.00
mu[2,3] 3.35 824 1.00
mu[2,4] 3.47 803 1.00
mu[3,1] 6.41 43 1.05
mu[3,2] 6.31 54 1.04
mu[3,3] 6.30 35 1.05
mu[3,4] 6.32 48 1.04
L[1,1,1] 1.00 NaN NaN
L[1,1,2] 0.00 NaN NaN
L[1,1,3] 0.00 NaN NaN
L[1,1,4] 0.00 NaN NaN
L[1,2,1] 0.91 3 1.64
L[1,2,2] 0.88 645 1.00
L[1,2,3] 0.00 NaN NaN
L[1,2,4] 0.00 NaN NaN
L[1,3,1] 0.91 10 1.46
L[1,3,2] 0.54 14 1.20
L[1,3,3] 0.70 767 1.01
L[1,3,4] 0.00 NaN NaN
L[1,4,1] 0.83 8 1.71
L[1,4,2] 0.71 22 1.14
L[1,4,3] 0.01 15 1.07
L[1,4,4] 0.57 961 1.00
L[2,1,1] 1.00 NaN NaN
L[2,1,2] 0.00 NaN NaN
L[2,1,3] 0.00 NaN NaN
L[2,1,4] 0.00 NaN NaN
L[2,2,1] 0.91 714 1.00
L[2,2,2] 0.81 859 1.00
L[2,2,3] 0.00 NaN NaN
L[2,2,4] 0.00 NaN NaN
L[2,3,1] 0.92 516 1.01
L[2,3,2] 0.56 898 1.00
L[2,3,3] 0.65 843 1.01
L[2,3,4] 0.00 NaN NaN
L[2,4,1] 0.94 801 1.00
L[2,4,2] 0.57 859 1.00
L[2,4,3] 0.26 1185 1.00
L[2,4,4] 0.50 1498 1.00
L[3,1,1] 1.00 NaN NaN
L[3,1,2] 0.00 NaN NaN
L[3,1,3] 0.00 NaN NaN
L[3,1,4] 0.00 NaN NaN
L[3,2,1] 0.94 545 1.00
L[3,2,2] 0.68 797 1.00
L[3,2,3] 0.00 NaN NaN
L[3,2,4] 0.00 NaN NaN
L[3,3,1] 0.90 2 12.13
L[3,3,2] 0.42 3 1.44
L[3,3,3] 0.70 4 1.26
L[3,3,4] 0.00 NaN NaN
L[3,4,1] 0.88 2 3.08
L[3,4,2] 0.32 3 1.33
L[3,4,3] 0.54 4 1.26
L[3,4,4] 0.66 183 1.02
lp__ -380.48 8 1.17
Samples were drawn using NUTS(diag_e) at Sun Nov 25 19:20:57 2018.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).
which improves things, but still hasn’t really gotten this working.
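As an aside: the inits above just reuse the true simulation means rather than coming from kmeans. If I do go the kmeans route, I was picturing something like this untested sketch (km, kmeans_init, and fit3 are placeholder names):

# Untested sketch: data-driven inits from kmeans rather than the true simulation means
km <- kmeans(mixture_data$y, centers = mixture_data$K, nstart = 25)
kmeans_init <- replicate(3, list(mu = km$centers), simplify = FALSE)

fit3 <- sampling(
  gmm2,
  data = mixture_data,
  chains = 3,
  cores = 2,
  iter = 1000,
  control = list(adapt_delta = 0.98),
  init = kmeans_init
)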
Is there anything else I can do to make this sample more reliably? I’d really like to fit GMMs over a whole range of K (from about 5 to 75) and then use loo to determine the number of clusters, but if each individual fit is this finicky I’m not sure this is the right approach.
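For context on the loo step: I was assuming I’d add pointwise log-likelihoods in generated quantities, roughly along these lines (a sketch reusing the same likelihood terms as the model block), and then pull them out with loo::extract_log_lik():

generated quantities {
  vector[N] log_lik;  // pointwise log-likelihood, one entry per observation
  for (n in 1:N) {
    vector[K] lps;
    for (k in 1:K)
      lps[k] = log(theta[k]) + multi_normal_cholesky_lpdf(y[n] | mu[k], L[k]);
    log_lik[n] = log_sum_exp(lps);
  }
}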