Mixture Models

I am fitting a multivariate mixture model but my Rhats are all over the place. My data-generation code is


# Draw mixing proportions uniformly from the simplex
# (normalized standard exponentials, i.e. a Dirichlet(1, ..., 1) draw)
runif_simplex <- function(T) {
  x <- -log(runif(T))
  x / sum(x)
}

# Simulate N observations from an NC-class, NY-dimensional Gaussian mixture
# with class means at 0, 10, 20, ... on every dimension and unit covariance
MV_LCA <- function(NC, NY, N) {
  pi1 <- runif_simplex(NC)

  # class memberships
  z <- sample(x = 1:NC, size = N, replace = TRUE, prob = pi1)

  # class means: row i is rep(10 * (i - 1), NY)
  mu <- matrix(nrow = NC, ncol = NY)
  a <- rep(0, NY)
  for (i in 1:NC) {
    mu[i, ] <- a
    a <- a + 10
  }

  # observations
  y <- matrix(nrow = N, ncol = NY)
  for (n in 1:N) {
    y[n, ] <- mvtnorm::rmvnorm(1, mean = mu[z[n], ], sigma = diag(1, nrow = NY))
  }

  list(y = y, z = z, theta = list(pi1 = pi1, mu = mu))
}

mx_dt <- MV_LCA(3, 2, 300)

and my Stan code, which follows the “Trouble with Gaussian Mixture Model” topic, is

data {
  int N;             // number of observations
  int K;             // number of classes
  int NY;            // number of observed variables
  vector[NY] Y[N];   // data
}

parameters {
  vector[K] theta;                // mixing proportions
  ordered[K] mu[NY];              // mixture component means
  cholesky_factor_corr[NY] L[K];  // cholesky factor of covariance
  vector<lower = 0>[NY] sigma[K];

  ordered[K] p;
}

transformed parameters {
  vector[NY] mu_ord[K];
  simplex[K] theta_n;

  theta_n = softmax(theta);

  for (j in 1:NY) {
    for (k in 1:K) {
      mu_ord[k, j] = mu[j, k];
    }
  }
}

model {
  vector[K] ps;

  p ~ normal(0, 10);
  theta ~ normal(0, .5);

  for (k in 1:K) {
    mu[, k] ~ normal(p[k], 2);
    L[k] ~ lkj_corr_cholesky(5);
    sigma[k] ~ std_normal();
  }

  for (n in 1:N) {
    for (k in 1:K) {
      ps[k] = log(theta_n[k]) + multi_normal_cholesky_lpdf(Y[n] | mu_ord[k], diag_pre_multiply(sigma[k], L[k])); // increment log probability of the gaussian
    }
    target += log_sum_exp(ps);
  }
}

However, my class means are multimodal:

                  mean    se_mean        sd          2.5%         25%        50%        75%      97.5%       n_eff      Rhat
mu_ord[1,1]  1.2144058 0.83088225 1.0523290 -0.4652320000 -0.09427355  1.6865650  2.0437300  2.5614545    1.604073  3.708499
mu_ord[1,2]  1.2260778 0.87530376 1.1111393 -0.7784914000 -0.07988217  1.7363450  2.0955275  2.6033510    1.611460  3.610858
mu_ord[2,1]  6.3999033 3.61243599 4.4329833 -0.0680316275  0.22516650  9.2400950  9.6786150 10.0627025    1.505885 16.047415
mu_ord[2,2]  6.5643907 3.64208352 4.4661419  0.0008965489  0.36662125  9.5848750  9.8141450 10.0677250    1.503714 20.755349
mu_ord[3,1] 10.6129343 0.21754509 0.3689531  9.9882272500 10.32810000 10.5761000 10.8790250 11.3884275    2.876363  1.406798
mu_ord[3,2] 10.4423360 0.30040886 0.4382617  9.8386040000 10.08040000 10.3140500 10.7916750 11.3554900    2.128343  1.760585
theta[1]     0.6841677 0.48223769 0.6912677 -0.8783362500  0.26576675  0.8893390  1.1770050  1.6401905    2.054803  1.849773
theta[2]    -0.4230893 0.21747393 0.5467207 -1.4303990000 -0.80814150 -0.4467285 -0.0575592  0.6647487    6.319994  1.153572
theta[3]    -0.2323758 0.28154095 0.5220417 -1.3042827500 -0.58801025 -0.2036370  0.1531850  0.6909437    3.438169  1.317029
p[1]         1.0129331 1.17940556 1.9545832 -2.8189635000 -0.38049675  1.0939900  2.4099475  4.6850590    2.746515  1.438734
p[2]         6.2369546 3.09811775 3.9242490 -0.6909558750  1.65356500  8.0895250  9.3052600 10.9257425    1.604417  3.619240
p[3]        10.6907368 0.02252125 1.2824686  8.2127330000  9.77637500 10.6847000 11.5349000 13.2768075 3242.713004  1.000188
theta_n[1]   0.5680305 0.15664537 0.2026572  0.1312347750  0.39535300  0.6814930  0.7120465  0.7531341    1.673744  3.000197
theta_n[2]   0.1977537 0.07119998 0.1151653  0.0603734925  0.11565975  0.1628290  0.2424110  0.4821426    2.616277  1.512294
theta_n[3]   0.2342157 0.08849630 0.1151140  0.0691795750  0.14296950  0.1962120  0.3683120  0.4229827    1.692023  2.820832
sigma[1,1]   3.4256922 1.50596224 1.8525152  0.5825078750  0.92417625  4.5851950  4.7853125  5.1148943    1.513197 10.255679
sigma[1,2]   3.5140795 1.45370440 1.7889001  0.7225185750  1.10882500  4.6314350  4.8275225  5.1692545    1.514328  9.844162
sigma[2,1]   0.7824114 0.14611578 0.2813299  0.2090872750  0.57900825  0.8062010  0.9796770  1.2836407    3.707128  1.295773
sigma[2,2]   1.0090473 0.01169272 0.2041508  0.5988232250  0.88697600  0.9999235  1.1324675  1.4275905  304.839400  1.022352
sigma[3,1]   1.5538882 0.78269777 0.9714749  0.5572269000  0.81358025  0.9933925  2.7960625  3.1485080    1.540547  5.794911
sigma[3,2]   1.5859154 0.81938585 1.0188719  0.2968618000  0.84501925  1.0123400  2.8822900  3.2524710    1.546188  5.487649

I didn’t set a seed (oops), but here’s what I get: no divergences and no warnings. Since your simulation assumed the means were >= 0, I added positive_ordered constraints. I also removed the softmax and just used a simplex, putting a Dirichlet prior on theta.

data {
 int N; //number of observations
 int K; //number of classes
 int NY; //number of observed variables
 vector[NY] Y[N]; //data
}
parameters {
 simplex[K] theta; //mixing proportions
 positive_ordered[K] mu[NY]; //mixture component means
 cholesky_factor_corr[NY] L[K]; //cholesky factor of covariance
 vector<lower=0>[NY] sigma[K];
 
 positive_ordered[K] p;
}
transformed parameters {
  vector[NY] mu_ord[K];
  for (j in 1:NY) mu_ord[:, j] = to_array_1d(mu[j, :]);
}
model {
  p ~ normal(0,10);
  theta ~ dirichlet([1,1,1]');
  
  for(k in 1:K){
    mu[, k] ~ normal(p[k], 2);
    L[k] ~ lkj_corr_cholesky(5);
    sigma[k] ~ std_normal();
  }

 {
    vector[K] ps;
    for (n in 1:N){
      for (k in 1:K){
        ps[k] = log(theta[k]) + multi_normal_cholesky_lpdf(Y[n] | to_vector(mu_ord[k]), diag_pre_multiply(sigma[k], L[k])); //increment log probability of the gaussian
      }
      target += log_sum_exp(ps);
    }
  }
}

The true mixing proportions for this run were

$theta
$theta$pi1
[1] 0.2375372 0.2119226 0.5505402

library(rstan)
options(mc.cores = parallel::detectCores())
mod <- stan_model(file = "mix_forum_help.stan", model_name = "mixture")

data <- list(
  N = 300,
  K = 3,
  NY = 2,
  Y = mx_dt$y
)

f <- sampling(mod,
              chains = 2,
              control = list(adapt_delta = 0.9, max_treedepth = 10),
              iter = 600,
              data = data)
Inference for Stan model: mixture.
2 chains, each with iter=600; warmup=300; thin=1; 
post-warmup draws per chain=300, total post-warmup draws=600.

                mean se_mean   sd     2.5%      25%      50%      75%    97.5% n_eff Rhat
theta[1]        0.23    0.00 0.03     0.18     0.21     0.23     0.25     0.28  1366 1.00
theta[2]        0.23    0.00 0.03     0.18     0.21     0.23     0.25     0.28  1260 1.00
theta[3]        0.54    0.00 0.03     0.48     0.52     0.54     0.56     0.60  1327 1.00
mu[1,1]         0.14    0.00 0.08     0.01     0.07     0.13     0.20     0.31   567 1.00
mu[1,2]         9.83    0.00 0.14     9.54     9.73     9.84     9.92    10.08   843 1.00
mu[1,3]        19.99    0.00 0.08    19.83    19.93    19.98    20.04    20.15   656 1.00
mu[2,1]         0.08    0.00 0.06     0.00     0.03     0.06     0.12     0.24   820 1.01
mu[2,2]        10.09    0.01 0.14     9.80    10.00    10.09    10.17    10.38   476 1.00
mu[2,3]        19.90    0.00 0.07    19.77    19.85    19.90    19.94    20.05   650 1.00
L[1,1,1]        1.00     NaN 0.00     1.00     1.00     1.00     1.00     1.00   NaN  NaN
L[1,1,2]        0.00     NaN 0.00     0.00     0.00     0.00     0.00     0.00   NaN  NaN
L[1,2,1]        0.06    0.00 0.12    -0.18    -0.01     0.07     0.14     0.30   927 1.00
L[1,2,2]        0.99    0.00 0.01     0.96     0.99     1.00     1.00     1.00   375 1.00
L[2,1,1]        1.00     NaN 0.00     1.00     1.00     1.00     1.00     1.00   NaN  NaN
L[2,1,2]        0.00     NaN 0.00     0.00     0.00     0.00     0.00     0.00   NaN  NaN
L[2,2,1]        0.15    0.00 0.12    -0.07     0.06     0.15     0.23     0.38   840 1.00
L[2,2,2]        0.98    0.00 0.02     0.92     0.97     0.99     1.00     1.00   525 1.00
L[3,1,1]        1.00     NaN 0.00     1.00     1.00     1.00     1.00     1.00   NaN  NaN
L[3,1,2]        0.00     NaN 0.00     0.00     0.00     0.00     0.00     0.00   NaN  NaN
L[3,2,1]        0.02    0.00 0.08    -0.14    -0.03     0.03     0.07     0.18  1053 1.00
L[3,2,2]        1.00    0.00 0.00     0.98     1.00     1.00     1.00     1.00   291 1.00
sigma[1,1]      0.85    0.00 0.08     0.71     0.79     0.84     0.89     1.04  1037 1.00
sigma[1,2]      0.99    0.00 0.08     0.84     0.93     0.99     1.05     1.16  1057 1.00
sigma[2,1]      1.16    0.00 0.10     0.97     1.08     1.15     1.22     1.38  1431 1.00
sigma[2,2]      1.13    0.00 0.10     0.95     1.06     1.13     1.20     1.35   697 1.01
sigma[3,1]      1.06    0.00 0.06     0.95     1.02     1.06     1.10     1.18   773 1.00
sigma[3,2]      0.95    0.00 0.06     0.84     0.91     0.95     0.99     1.06   963 1.00
p[1]            1.19    0.03 0.87     0.05     0.49     1.04     1.68     3.22   888 1.00
p[2]            9.90    0.05 1.40     7.24     8.92     9.92    10.88    12.68   772 1.00
p[3]           19.56    0.05 1.35    16.83    18.63    19.58    20.45    22.13   837 1.00
mu_ord[1,1]     0.14    0.00 0.08     0.01     0.07     0.13     0.20     0.31   567 1.00
mu_ord[1,2]     0.08    0.00 0.06     0.00     0.03     0.06     0.12     0.24   820 1.01
mu_ord[2,1]     9.83    0.00 0.14     9.54     9.73     9.84     9.92    10.08   843 1.00
mu_ord[2,2]    10.09    0.01 0.14     9.80    10.00    10.09    10.17    10.38   476 1.00
mu_ord[3,1]    19.99    0.00 0.08    19.83    19.93    19.98    20.04    20.15   656 1.00
mu_ord[3,2]    19.90    0.00 0.07    19.77    19.85    19.90    19.94    20.05   650 1.00
lp__        -1162.25    0.23 3.27 -1169.50 -1164.20 -1161.92 -1160.08 -1156.78   206 1.00

Samples were drawn using NUTS(diag_e) at Mon Aug 31 10:16:59 2020.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

Thank you very much for helping. After running the code multiple times, I am having difficulty replicating the results with three chains, but everything works fine with two chains. For example:

Inference for Stan model: MVN_Mixture-202008311844-1-8447ff.
3 chains, each with iter=1000; warmup=500; thin=1; 
post-warmup draws per chain=500, total post-warmup draws=1500.

             mean se_mean   sd  2.5%   25%   50%   75% 97.5% n_eff  Rhat
mu_ord[1,1]  0.08    0.00 0.06  0.01  0.04  0.07  0.12  0.21  1299  1.00
mu_ord[1,2]  0.06    0.00 0.04  0.00  0.02  0.05  0.08  0.16  1900  1.00
mu_ord[2,1]  6.45    3.65 4.47  0.06  0.17  9.49  9.65  9.92     2 33.32
mu_ord[2,2]  6.68    3.80 4.66  0.03  0.13  9.84 10.03 10.30     2 32.00
mu_ord[3,1] 20.20    0.00 0.12 19.96 20.12 20.20 20.28 20.44  2024  1.00
mu_ord[3,2] 20.35    0.00 0.13 20.10 20.26 20.36 20.44 20.62  1850  1.00
theta[1]     0.55    0.23 0.29  0.11  0.16  0.73  0.76  0.79     2 12.78
theta[2]     0.30    0.23 0.29  0.07  0.09  0.11  0.68  0.75     2 14.00
theta[3]     0.15    0.00 0.02  0.12  0.14  0.15  0.17  0.19  1643  1.00
p[1]         0.99    0.11 0.84  0.02  0.33  0.78  1.43  3.08    57  1.03
p[2]         6.94    3.10 3.95  0.50  2.15  8.68 10.02 12.18     2  3.44
p[3]        19.83    0.03 1.42 16.94 18.85 19.89 20.81 22.51  1830  1.00
sigma[1,1]   2.38    1.57 1.93  0.94  1.00  1.05  4.88  5.52     2 12.19
sigma[1,2]   2.43    1.66 2.04  0.92  0.99  1.03  5.08  5.78     2 12.40
sigma[2,1]   0.99    0.00 0.12  0.77  0.92  0.99  1.05  1.27  1476  1.00
sigma[2,2]   1.02    0.03 0.13  0.82  0.94  0.99  1.08  1.34    14  1.07
sigma[3,1]   0.84    0.00 0.09  0.68  0.77  0.83  0.90  1.03  1764  1.00
sigma[3,2]   0.87    0.00 0.09  0.71  0.80  0.86  0.93  1.08  1407  1.00


> mx_dt$theta
$pi1
[1] 0.7612361 0.1023583 0.1364056

$mu
     [,1] [,2]
[1,]    0    0
[2,]   10   10
[3,]   20   20

You may have seen this already but check it out if you haven’t. https://mc-stan.org/users/documentation/case-studies/identifying_mixture_models.html

The issue is that the second mu is degenerate and the sampler doesn’t know whether it should be 0 or 10. Try a stronger prior on p that is different for all k’s, something like normals with means 0, 5, and 15 on the respective p[k] and a smaller sd.
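
As a sketch of that suggestion (assuming K = 3; the centers 0, 5, 15 and the sd of 2 are illustrative values to tune, not part of the original model), the single prior on p would become component-specific:

// hypothetical replacement for "p ~ normal(0,10);" in the model block,
// giving each component location its own, more informative prior
p[1] ~ normal(0, 2);
p[2] ~ normal(5, 2);
p[3] ~ normal(15, 2);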

I wasn’t satisfied with that, because you often don’t have a clue where the prior should be, and who wants to spend their time fiddling with it?

So, I looked around and found the paper “On a Class of Repulsive Mixture Models”, published in TEST (an official journal of the Spanish Society of Statistics and Operations Research) on July 22, 2020. They do something pretty interesting and I happened to be able to implement it in Stan. We can keep the same model but add an additional “repulsive” function to repel means that are “close” to each other.

We need a function that maps two points into [0, 1), returning 0 when they coincide and approaching 1 as they move apart. Then we multiply our target density by that number. The necessary modifications are:

functions {
  real potential(real x, real y) {
    return (1 - exp(-squared_distance(x, y)));
  }
...
model {
...
 target += log_sum_exp(ps + log(potential(p[1], p[2]) *  potential(p[2], p[3]) * potential(p[1], p[3])));
}
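
For reference, the potential used here is f(x, y) = 1 - exp(-||x - y||^2): it equals 0 when two location parameters coincide and approaches 1 as they separate, so adding its log to target leaves well-separated components essentially untouched while heavily penalizing components that collapse onto each other.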

With 3 chains the output converges and the Rhats are approximately 1.

Inference for Stan model: mixture.
3 chains, each with iter=600; warmup=300; thin=1; 
post-warmup draws per chain=300, total post-warmup draws=900.

               mean se_mean   sd    2.5%     25%     50%     75%   97.5% n_eff Rhat
theta[1]       0.71    0.00 0.02    0.66    0.69    0.71    0.72    0.75  1458 1.00
theta[2]       0.13    0.00 0.02    0.10    0.11    0.13    0.14    0.17  1103 1.00
theta[3]       0.16    0.00 0.02    0.13    0.15    0.16    0.18    0.21  1489 1.00
mu[1,1]        0.10    0.00 0.06    0.01    0.05    0.10    0.14    0.23   730 1.00
mu[1,2]       10.05    0.01 0.17    9.71    9.94   10.05   10.17   10.38   951 1.00
mu[1,3]       19.89    0.00 0.16   19.59   19.78   19.89   19.99   20.21  1002 1.00
mu[2,1]        0.09    0.00 0.05    0.01    0.05    0.08    0.13    0.20  1022 1.00
mu[2,2]        9.88    0.01 0.15    9.60    9.77    9.88    9.99   10.20   791 1.00
mu[2,3]       19.84    0.00 0.15   19.56   19.75   19.84   19.93   20.15  1112 1.00
L[1,1,1]       1.00     NaN 0.00    1.00    1.00    1.00    1.00    1.00   NaN  NaN
L[1,1,2]       0.00     NaN 0.00    0.00    0.00    0.00    0.00    0.00   NaN  NaN
L[1,2,1]      -0.04    0.00 0.06   -0.16   -0.08   -0.04    0.01    0.09  1215 1.00
L[1,2,2]       1.00    0.00 0.00    0.99    1.00    1.00    1.00    1.00   654 1.00
L[2,1,1]       1.00     NaN 0.00    1.00    1.00    1.00    1.00    1.00   NaN  NaN
L[2,1,2]       0.00     NaN 0.00    0.00    0.00    0.00    0.00    0.00   NaN  NaN
L[2,2,1]       0.11    0.01 0.14   -0.15    0.01    0.11    0.21    0.37   725 1.01
L[2,2,2]       0.98    0.00 0.02    0.93    0.98    0.99    1.00    1.00   587 1.00
L[3,1,1]       1.00     NaN 0.00    1.00    1.00    1.00    1.00    1.00   NaN  NaN
L[3,1,2]       0.00     NaN 0.00    0.00    0.00    0.00    0.00    0.00   NaN  NaN
L[3,2,1]      -0.01    0.00 0.13   -0.26   -0.10   -0.02    0.07    0.23  1233 1.00
L[3,2,2]       0.99    0.00 0.01    0.96    0.99    1.00    1.00    1.00   491 1.01
sigma[1,1]     1.08    0.00 0.05    0.99    1.05    1.08    1.12    1.19  1123 1.00
sigma[1,2]     0.98    0.00 0.05    0.89    0.95    0.98    1.02    1.09   905 1.00
sigma[2,1]     1.03    0.00 0.12    0.82    0.94    1.02    1.10    1.29  1070 1.00
sigma[2,2]     1.02    0.00 0.12    0.82    0.93    1.01    1.09    1.29   987 1.00
sigma[3,1]     1.05    0.00 0.11    0.86    0.97    1.05    1.12    1.29  1323 1.00
sigma[3,2]     1.04    0.00 0.11    0.85    0.97    1.04    1.10    1.28  1469 1.00
p[1]           1.18    0.03 0.86    0.05    0.50    0.99    1.67    3.27   697 1.00
p[2]           9.79    0.04 1.33    7.37    8.87    9.78   10.71   12.46   914 1.00
p[3]          19.50    0.04 1.32   16.78   18.63   19.52   20.35   22.03   864 1.00
mu_ord[1,1]    0.10    0.00 0.06    0.01    0.05    0.10    0.14    0.23   730 1.00
mu_ord[1,2]    0.09    0.00 0.05    0.01    0.05    0.08    0.13    0.20  1022 1.00
mu_ord[2,1]   10.05    0.01 0.17    9.71    9.94   10.05   10.17   10.38   951 1.00
mu_ord[2,2]    9.88    0.01 0.15    9.60    9.77    9.88    9.99   10.20   791 1.00
mu_ord[3,1]   19.89    0.00 0.16   19.59   19.78   19.89   19.99   20.21  1002 1.00
mu_ord[3,2]   19.84    0.00 0.15   19.56   19.75   19.84   19.93   20.15  1112 1.00
lp__        -782.75    0.18 3.33 -790.37 -784.72 -782.36 -780.50 -777.19   351 1.00

Samples were drawn using NUTS(diag_e) at Tue Sep  1 10:28:25 2020.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

Full Stan code:

functions {
  real potential(real x, real y) {
    return (1 - exp(-squared_distance(x, y)));
  }
}
data {
 int N; //number of observations
 int K; //number of classes
 int NY; //number of observed variables
 vector[NY] Y[N]; //data
}
parameters {
 simplex[K] theta; //mixing proportions
 positive_ordered[K] mu[NY]; //mixture component means
 cholesky_factor_corr[NY] L[K]; //cholesky factor of covariance
 vector<lower=0>[NY] sigma[K];
 
 positive_ordered[K] p;
}
transformed parameters {
  vector[NY] mu_ord[K];
  for (j in 1:NY) mu_ord[:, j] = to_array_1d(mu[j, :]);
}
model {
  p ~ normal(0,10);
  theta ~ dirichlet([1,1,1]');
 
  for(k in 1:K){
    mu[, k] ~ normal(p[k], 2);
    L[k] ~ lkj_corr_cholesky(5);
    sigma[k] ~ std_normal();
  }
 {
    vector[K] ps;
    for (n in 1:N){
      for (k in 1:K){
        ps[k] = log(theta[k]) + multi_normal_cholesky_lpdf(Y[n] | to_vector(mu_ord[k]), diag_pre_multiply(sigma[k], L[k])); //increment log probability of the gaussian
      }
      target += log_sum_exp(ps + log(potential(p[1], p[2]) *  potential(p[2], p[3]) * potential(p[1], p[3])));
    }
  }
 
}

Anyway, this seems pretty promising for mixture models. I wonder if @betanalpha or any of the Stan devs (@Bob_Carpenter) have any thoughts?


Furthermore, if I take the case study in that link and modify it to

functions {
  real potential_lpdf(vector y, vector a, real s) {
    real f = normal_lpdf(y | a, s);
    real r = log(1 - exp(-0.5 * sqrt(dot_self(a))));
   
    return f + r;
  }
}
data {
 int<lower = 0> N;
 vector[N] y;
}

parameters {
  ordered[2] mu;
  real<lower=0> sigma[2];
  vector[2] mu_hyper;
  real<lower=0, upper=1> theta;
}

model {
 sigma ~ normal(0, 2);
 mu_hyper ~ normal(0, 5);
 mu ~ potential(mu_hyper, 2);
 theta ~ beta(5, 5);

 for (n in 1:N)
   target += log_mix(theta,
                     normal_lpdf(y[n] | mu[1], sigma[1]),
                     normal_lpdf(y[n] | mu[2], sigma[2]));
 }
library(rstan)
options(mc.cores = parallel::detectCores())
N <- 1000
mu <- c(-0.75, 0.75);
sigma <- c(1, 1);
lambda <- 0.4
z <- rbinom(N, 1, lambda) + 1;
y <- rnorm(N, mu[z], sigma[z]);

#rstan_options(auto_write = TRUE)

stan_rdump(c("N", "y"), file="mix.data.R")

input_data <- read_rdump("mix.data.R")

singular_fit <- stan(file='..\\mix_betancourt.stan', data=input_data,
                     chains = 1, iter = 11000, warmup = 1000, 
                     seed = 483892929, refresh = 2000)

singular_fit

I get good n_eff:

Inference for Stan model: mix_betancourt.
3 chains, each with iter=600; warmup=200; thin=1; 
post-warmup draws per chain=400, total post-warmup draws=1200.

                mean se_mean   sd     2.5%      25%      50%      75%    97.5% n_eff Rhat
mu[1]          -0.76    0.03 0.32    -1.46    -0.98    -0.75    -0.53    -0.19   102 1.01
mu[2]           0.35    0.02 0.30    -0.14     0.12     0.33     0.55     1.00   152 1.01
sigma[1]        1.09    0.01 0.13     0.82     1.00     1.09     1.18     1.34   111 1.01
sigma[2]        1.19    0.01 0.11     0.96     1.12     1.19     1.27     1.39   158 1.01
mu_hyper[1]    -0.19    0.04 1.05    -2.14    -0.96    -0.23     0.59     1.82   675 1.01
mu_hyper[2]     0.09    0.05 1.02    -1.80    -0.71     0.07     0.90     1.96   489 1.00
theta           0.48    0.01 0.17     0.17     0.36     0.48     0.60     0.78   342 1.00
lp__        -1688.17    0.18 2.33 -1693.63 -1689.48 -1687.76 -1686.46 -1684.95   164 1.01

Samples were drawn using NUTS(diag_e) at Tue Sep 01 19:11:18 2020.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

The half-bow is still there

Compared to the case study fit with the same R call:

Inference for Stan model: mix_betancourt_case.
3 chains, each with iter=600; warmup=200; thin=1; 
post-warmup draws per chain=400, total post-warmup draws=1200.

             mean se_mean   sd     2.5%      25%      50%      75%    97.5% n_eff Rhat
mu[1]       -0.94    0.05 0.32    -1.49    -1.16    -0.97    -0.73    -0.26    48 1.07
mu[2]        0.47    0.02 0.24    -0.03     0.32     0.48     0.62     0.91   149 1.00
sigma[1]     0.98    0.02 0.14     0.73     0.88     0.96     1.07     1.25    45 1.07
sigma[2]     1.02    0.01 0.11     0.85     0.95     1.01     1.09     1.26   124 1.01
theta        0.45    0.01 0.14     0.20     0.35     0.45     0.54     0.74   157 1.04
lp__     -1637.22    0.22 2.12 -1642.02 -1638.66 -1636.65 -1635.58 -1634.36    90 1.04

Samples were drawn using NUTS(diag_e) at Tue Sep 01 16:42:35 2020.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

I’ve been playing around more and it’s interesting. If I do this in the case study

data {
 int<lower = 0> N;
 vector[N] y;
}

parameters {
  ordered[2] mu;
  real<lower=0> sigma[2];
  real<lower=0> tau;
  real<lower=0, upper=1> theta;
}

model {
 sigma ~ normal(0, 2);
 //mu_hyper ~ std_normal();
 mu ~ normal(0, 2);
 theta ~ beta(5, 5);
 tau ~ gamma(2, 2);

 for (n in 1:N)
   target += log_mix(theta,
                     normal_lpdf(y[n] | mu[1], sigma[1]) + log(1 - exp(- squared_distance(mu[1], mu[2]) * tau)),
                     normal_lpdf(y[n] | mu[2], sigma[2]) + log(1 - exp(- squared_distance(mu[1], mu[2]) * tau)));
 }

with output (no divergences or low E-BFMI warnings)

Inference for Stan model: mix_betancourt.
3 chains, each with iter=600; warmup=200; thin=1; 
post-warmup draws per chain=400, total post-warmup draws=1200.

             mean se_mean   sd     2.5%      25%      50%      75%    97.5% n_eff Rhat
mu[1]       -0.83    0.02 0.33    -1.58    -1.03    -0.79    -0.57    -0.34   268 1.02
mu[2]        0.69    0.02 0.37     0.07     0.42     0.66     0.93     1.46   226 1.03
sigma[1]     1.07    0.00 0.08     0.89     1.02     1.08     1.13     1.22   396 1.01
sigma[2]     1.02    0.01 0.11     0.77     0.96     1.03     1.09     1.19   286 1.01
tau          3.57    0.04 0.92     2.12     2.90     3.49     4.09     5.70   521 1.00
theta        0.55    0.01 0.21     0.15     0.39     0.55     0.72     0.89   246 1.03
lp__     -1669.33    0.09 1.76 -1673.37 -1670.26 -1669.04 -1668.04 -1666.81   353 1.00

Samples were drawn using NUTS(diag_e) at Tue Sep 01 19:45:05 2020.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

I do get a repulsive effect, as seen in the plot.

Thank you very much. This works very well.

Repulsive priors can certainly help, but you have to be careful with them, especially if you want to scale to mixtures in higher-dimensional spaces. If we require that the repulsion is exchangeable between the mixture components then the natural approach is known as a determinantal point process prior, which essentially reduces to the log determinant of a Gaussian process-like covariance matrix. There is a reasonably large literature on these priors, although the computational studies are pretty weak overall. Note also that the repulsion typically has to be complemented with some attractive prior to ensure that none of the mixture components fly off to infinity.

Although they are sensitive to the specific hyperparameter tuning, repulsive priors definitely prevent the collapse of multiple components on top of each other, and the resulting degeneracies, as you see. The problem, and the reason why I haven’t written about my experiments with them in any detail, is that resolving the collapse problem only reveals another serious pathology.

Once the components are ordered and a repulsive prior is added, Gaussian mixture models will still be plagued by an inflation degeneracy where the variance of one, or more, components inflates so that the component densities cover multiple true components. The “unused” density between the true components offers little penalty to the inflation and, because of the repulsive prior, the unused components can’t force their way closer to the inflated component and push it away from both true components. This occurs even when using the true number of components!

To demonstrate, consider the following example. First we generate data from four true normal components.

library(MASS)
library(rstan)

mu <- c(-5, -1, 4, 7)
sigma <- c(1, 0.5, 0.25, 0.3);

lambdas <- c(0.2, 0.4, 0.3, 0.1)
K <- length(lambdas)

set.seed(689934)

N <- 1000

z <- rmultinom(N, 1, lambdas);
z <- sapply(1:N, function(n) which(z != 0, arr.ind = T)[[n]])
y <- t(sapply(1:N, function(n) rnorm(1, mu[z[n]], sigma[z[n]])))[1,]

stan_rdump(c("K", "N", "y"), file="mix.data.R")

Fitting with a naive Gaussian mixture model demonstrates label switching.

data {
  int<lower=1> K;
  int<lower=1> N;
  real y[N];
}

parameters {
  real mu[K];
  real<lower=0> sigma[K];
  simplex[K] lambda;
}

model {
  // Prior model
  mu ~ normal(0, 5);
  sigma ~ normal(0, 1);
  lambda ~ dirichlet(rep_vector(3, K));
  
  // Observational model
  for (n in 1:N) {
    real comp_lpdf[K];
    for (k in 1:K) {
      comp_lpdf[k] = log(lambda[k]) + normal_lpdf(y[n] | mu[k], sigma[k]);
    }
    target += log_sum_exp(comp_lpdf);  
  }
}

Ordering the component locations resolves the label switching but then reveals the collapse problem.

data {
  int<lower=1> K;
  int<lower=1> N;
  real y[N];
}

parameters {
  ordered[K] mu;
  real<lower=0> sigma[K];
  simplex[K] lambda;
}

model {
  // Prior model
  mu ~ normal(0, 5);
  sigma ~ normal(0, 1);
  lambda ~ dirichlet(rep_vector(3, K));
  
  // Observational model
  for (n in 1:N) {
    real comp_lpdf[K];
    for (k in 1:K) {
      comp_lpdf[k] = log(lambda[k]) + normal_lpdf(y[n] | mu[k], sigma[k]);
    }
    target += log_sum_exp(comp_lpdf);  
  }
}

Finally, adding a repulsive prior – in this case tuned to the true component behaviors! – resolves the collapse but manifests the inflation problem.

functions {
  real repulsive_lpdf(vector mu, real rho) {
    int K = num_elements(mu);
    matrix[K, K] S;
    matrix[K, K] L;
    real log_det = 0;

    for (k1 in 1:K)
      for (k2 in 1:K)
        S[k1, k2] = exp(- square(mu[k1] - mu[k2]) / square(rho));
    L = cholesky_decompose(S);

    for (k in 1:K)
      log_det = log_det + 2 * log(L[k, k]);

    return log_det;
  }
}

data {
  int<lower=1> K;
  int<lower=1> N;
  real y[N];
}

parameters {
  ordered[K] mu;
  real<lower=0> sigma[K];
  simplex[K] lambda;
}

model {
  // Prior model
  mu ~ normal(0, 5);
  mu ~ repulsive(1);
  sigma ~ normal(0, 1);
  lambda ~ dirichlet(rep_vector(3, K));
  
  // Observational model
  for (n in 1:N) {
    real comp_lpdf[K];
    for (k in 1:K) {
      comp_lpdf[k] = log(lambda[k]) + normal_lpdf(y[n] | mu[k], sigma[k]);
    }
    target += log_sum_exp(comp_lpdf);  
  }
}

I don’t have any plots handy, but scatter plots of various combinations of the component means and scales, separated by Markov chain, will display the problems clearly. You should also be able to see instability issues with respect to the seed, and hence the initial values for the Markov chains.

When only using two components this problem may not be as serious, but I’ve never been able to shake it with the four-component example. Combined with the fact that clustering is well known to be extremely degenerate for any finite data set, I haven’t spent too much time trying to find exceptional niches where things might work okay.

I don’t agree with Larry Wasserman on much, but the fundamental pathology of Gaussian mixture models, https://normaldeviate.wordpress.com/2012/08/04/mixture-models-the-twilight-zone-of-statistics/, is one of the things we agree on!


Thanks for the detailed response!

I was able to get the mixture to work without the over-inflation by changing the code a bit. You’ll see I changed the specification of the repulsive distribution: I’m only adding the pairwise comparisons to the log probability, which means I have to build the full matrix for S. I’ve let rho vary by pairwise comparison, but it seems to settle around 0.5 for all the pairs.

But this breaks down if I run 3 chains. I’m guessing this is the sensitivity to initial values.

Full code and results for 1 chain look good:

functions {
  real repulsive_lpdf(vector mu, vector rho) {
    int K = num_elements(mu);
    matrix[K, K] S = diag_matrix(rep_vector(1, K));
    matrix[K, K] L;
    int c = 0;

    for (k1 in 1:(K - 1))
      for (k2 in (k1 + 1):K){
        c += 1;
        S[k1, k2] = log(1 - exp(- squared_distance(mu[k1], mu[k2]) / rho[c]));
        S[k2, k1] = S[k1, k2];
      }
    L = cholesky_decompose(S);

    return 2 * sum(log(diagonal(L)));
  }
}

data {
  int<lower=1> K;
  int<lower=1> N;
  real y[N];
}

parameters {
  ordered[K] mu;
  vector<lower=0>[choose(K, 2)] rho;
  real<lower=0> sigma[K];
  simplex[K] lambda;
}

model {
  // Prior model
  mu ~ normal(0, 5);
  sigma ~ std_normal();
  lambda ~ dirichlet(rep_vector(3, K));
  rho ~ gamma(1, 2);
  mu ~ repulsive(rho);
  
  // Observational model
  for (n in 1:N) {
    real comp_lpdf[K];
    for (k in 1:K) {
      comp_lpdf[k] = log(lambda[k]) + normal_lpdf(y[n] | mu[k], sigma[k]);
    }
    target += log_sum_exp(comp_lpdf);  
  }
}

Inference for Stan model: beta_alpha_sean.
1 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=1000.

              mean se_mean   sd     2.5%      25%      50%      75%    97.5% n_eff Rhat
mu[1]        -5.05    0.00 0.07    -5.17    -5.09    -5.05    -5.00    -4.91   643    1
mu[2]        -0.98    0.00 0.03    -1.04    -1.00    -0.98    -0.97    -0.93  1489    1
mu[3]         4.04    0.00 0.01     4.01     4.03     4.04     4.05     4.06   999    1
mu[4]         6.99    0.00 0.03     6.93     6.97     6.99     7.01     7.05  1197    1
rho[1]        0.49    0.01 0.49     0.01     0.13     0.34     0.71     1.81  1449    1
rho[2]        0.49    0.01 0.46     0.01     0.15     0.35     0.68     1.66  1519    1
rho[3]        0.51    0.01 0.51     0.01     0.14     0.36     0.73     1.94  1437    1
rho[4]        0.48    0.01 0.48     0.01     0.13     0.31     0.67     1.77  1778    1
rho[5]        0.49    0.01 0.49     0.02     0.14     0.33     0.71     1.71  1711    1
rho[6]        0.51    0.01 0.50     0.01     0.15     0.36     0.70     1.94  1387    1
sigma[1]      0.95    0.00 0.05     0.85     0.91     0.94     0.98     1.05  1588    1
sigma[2]      0.55    0.00 0.02     0.51     0.53     0.54     0.56     0.59  1125    1
sigma[3]      0.24    0.00 0.01     0.22     0.24     0.24     0.25     0.26  1380    1
sigma[4]      0.32    0.00 0.02     0.27     0.30     0.32     0.33     0.37  1394    1
lambda[1]     0.18    0.00 0.01     0.16     0.18     0.18     0.19     0.21  1313    1
lambda[2]     0.40    0.00 0.02     0.37     0.38     0.40     0.41     0.43  1524    1
lambda[3]     0.32    0.00 0.01     0.29     0.31     0.32     0.33     0.35  1653    1
lambda[4]     0.10    0.00 0.01     0.08     0.09     0.10     0.11     0.12  1520    1
lp__      -1899.03    0.17 3.04 -1905.97 -1900.81 -1898.75 -1896.88 -1893.98   314    1

Samples were drawn using NUTS(diag_e) at Mon Sep  7 08:00:32 2020.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).
singular_fit <- stan(file='beta_alpha_sean.stan', data=input_data,
                     chains = 1, iter = 2000, warmup = 1000, 
                     seed = 483892929, refresh = 500)
singular_fit

Here’s the plot


params1 <- as.data.frame(extract(singular_fit, permuted=FALSE)[,1,])

c_light_trans <- c("#DCBCBCBF")
c_light_highlight_trans <- c("#C79999BF")
c_mid_trans <- c("#B97C7CBF")
c_mid_highlight_trans <- c("#A25050BF")
c_dark_trans <- c("#8F2727BF")
c_dark_highlight_trans <- c("#7C0000BF")

par(mar = c(4, 4, 0.5, 0.5))
plot(params1$"mu[1]", params1$"mu[2]", col=c_dark_highlight_trans, pch=16, cex=0.8,
     xlab="pair_1", xlim=c(-6, 8), ylab="pair_2", ylim=c(-2, 8))
points(params1$"mu[1]", params1$"mu[3]", col="blue", pch=16, cex=0.8)

points(params1$"mu[1]", params1$"mu[4]", col="green", pch=16, cex=0.8)
points(params1$"mu[2]", params1$"mu[3]", col="orange", pch=16, cex=0.8)
points(params1$"mu[2]", params1$"mu[4]", col="purple", pch=16, cex=0.8)
points(params1$"mu[3]", params1$"mu[4]", col=c_light_trans, pch=16, cex=0.8)

Running 3 chains, you’ll see the separation break down in chains 2 and 3.



I’m mildly apprehensive about the extra log, but the form of log(1 - exp(-distance)) does have the advantage of saturating at larger distances, which might explain why it avoids pushing the components apart too much relative to the typical repulsive priors that I’ve seen in the literature.
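
To spell that out: writing d for the squared distance between two component locations, log(1 - exp(-d)) goes to -infinity as d -> 0 and to 0 as d -> infinity, so the penalty bites only when components nearly coincide and contributes essentially nothing once they are well separated.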

Really the overall challenge here is coming up with robust prior modeling methods that are easy to tune based only on domain expertise, as opposed to tuning arbitrarily until the fits look okay (which is how a lot of clustering methods are executed in practice). In other words, can the rho be chosen based only on domain expertise? Are there multiple values of rho that yield fits without any failing diagnostics but different posterior behaviors, and if so, what different assumptions do those different values of rho encode?
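
One way to probe that, as a rough sketch rather than anything from the fits above, is to look at the prior in isolation: take the squared-exponential, log-determinant repulsive density from the earlier model with a single rho passed in as data, drop the observational model, and see what separations between the mu[k] each candidate rho implies before any data are involved.

functions {
  // same repulsive density as in the earlier model: log determinant of a
  // squared-exponential kernel over the component locations
  real repulsive_lpdf(vector mu, real rho) {
    int K = num_elements(mu);
    matrix[K, K] S;
    matrix[K, K] L;
    for (k1 in 1:K)
      for (k2 in 1:K)
        S[k1, k2] = exp(- square(mu[k1] - mu[k2]) / square(rho));
    L = cholesky_decompose(S);
    return 2 * sum(log(diagonal(L)));
  }
}
data {
  int<lower=1> K;
  real<lower=0> rho;  // candidate length scale, fixed for this run
}
parameters {
  ordered[K] mu;
}
model {
  // prior only -- no observational model
  mu ~ normal(0, 5);
  mu ~ repulsive(rho);
}

Comparing draws of mu[k+1] - mu[k] across a handful of rho values at least turns “what does this rho assume?” into a prior predictive question that domain expertise about plausible component separation can answer.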

The fact that you get different behavior with 3 chains suggests that there’s probably still some significant degeneracy going on, which isn’t surprising.