Coping with varying curvature

I am fitting a rather complicated model but managed to narrow the issue to something that is easy to demonstrate in a small, self-contained example.

Consider a mixture of two normals in \mathbb{R}^3: one at the origin with unit variances, the other at [1,1,1] with a much narrower covariance. Here is the Stan code:

data {
  int<lower=1> K;               // dimension
  real<lower=0, upper=1> alpha; // mixture weight
  vector[K] mu1;
  vector[K] mu2;
  matrix[K,K] Sigma1;
  matrix[K,K] Sigma2;
}
parameters {
  vector[K] y;                  // same dimension as mu1 and mu2
}
model {
  target += log_sum_exp(log(alpha) + multi_normal_lpdf(y | mu1, Sigma1),
                        log(1-alpha) + multi_normal_lpdf(y | mu2, Sigma2));
}

Here is the setup in R:

library("rstan")
options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)

Sigma1 <- matrix(c(1, 0, 0, 0, 1, 0, 0, 0, 1), 3)
Sigma2 <- matrix(c(0.22251677640490497, -0.10171944195482832, 0.047549588109993136,
                   -0.10171944195482832, 0.16903679793887302, -0.06377848041623378,
                   0.047549588109993136, -0.06377848041623378, 0.08844642565622202), 3)
Sigma2_scale <- 1/16
data <- list(K = 3, alpha = 0.3,
             mu1 = c(0, 0, 0), Sigma1 = Sigma1,
             mu2 = c(1, 1, 1), Sigma2 = Sigma2 * Sigma2_scale)
fit <- stan(file = "normal-mixture.stan", data = data)
get_adaptation_info(fit)

If I set Sigma2_scale <- 1, it fits fine, but with the value above I get these warnings:

1: There were 8 divergent transitions after warmup. Increasing adapt_delta above 0.8 may help. See
http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup 
2: Examine the pairs() plot to diagnose sampling problems
 
3: The largest R-hat is 1.29, indicating chains have not mixed.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#r-hat 
4: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#bulk-ess 
5: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#tail-ess 

I suspect the issue is that the curvature differs greatly between regions of the posterior.
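To put a rough number on that difference in scale, here is a minimal sketch in base R that compares the standard deviations along the principal axes of the two components. It reuses Sigma2 and Sigma2_scale from the setup above; sd1, sd2 and scale_ratio are names introduced only for this illustration.

```r
# Covariances as defined in the R setup above.
Sigma1 <- diag(3)
Sigma2 <- matrix(c(0.22251677640490497, -0.10171944195482832, 0.047549588109993136,
                   -0.10171944195482832, 0.16903679793887302, -0.06377848041623378,
                   0.047549588109993136, -0.06377848041623378, 0.08844642565622202), 3)
Sigma2_scale <- 1/16

# Standard deviations along the principal axes of each component.
sd1 <- sqrt(eigen(Sigma1, symmetric = TRUE)$values)
sd2 <- sqrt(eigen(Sigma2 * Sigma2_scale, symmetric = TRUE)$values)

# Ratio of the widest scale in component 1 to the narrowest in component 2.
scale_ratio <- max(sd1) / min(sd2)
print(sd2)
print(scale_ratio)
```

A single step size has to work for both of these scales, which is one way to read the warnings above.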

I fixed the original model by allowing more noise (I think it was misspecified), but I am curious if there is anything I can do to improve the MCMC performance of this model in Stan as it is.

Try adding the control = list(metric = "dense_e") option for this.

Thanks, but

fit <- stan(file = "normal-mixture.stan", data = data, control = list(metric = "dense_e"))

doesn’t help,

> fit
Inference for Stan model: normal-mixture.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

      mean se_mean   sd  2.5%   25%   50%   75% 97.5% n_eff Rhat
y[1]  0.24    0.22 0.98 -1.82 -0.42  0.37  0.99  1.94    19 1.08
y[2]  0.21    0.23 0.98 -1.90 -0.47  0.34  0.99  1.81    19 1.09
y[3]  0.22    0.24 0.97 -1.83 -0.47  0.36  0.99  1.85    16 1.09
lp__ -3.77    2.12 3.56 -8.48 -5.78 -4.81 -4.13  3.91     3 2.02

Samples were drawn using NUTS(dense_e) at Wed Aug 21 19:12:30 2019.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

Maybe my understanding is incorrect, but I am under the impression that standard HMC cannot be adapted to these two curvatures at the same time, so this is not something that the dense metric helps with.


Can you plot the three pairs plots for the three variables, ideally with divergences visualized?

Of course. I initially thought the divergence was at the “spike” of the second distribution, but it isn’t; I misread the plot. I don’t understand why the divergence is where it is.

I don’t see any divergences at all. You might see some pop up if you run longer chains (although the pairs plots then become more computationally expensive).

The bigger issue here is that you’re trying to fit a multimodal model. Stan is doing an admirable job here but the effective sample size is determined not by how well the chains explore within each component but rather how well a single chain can transition between the two components. There’s not much you can do to improve that other than find a way to make your model not multimodal. Note that things will only get worse once you make more terms parameters.
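As a quick sanity check on how well the chains move between components, the exact mean of the target can be computed by hand, because every quantity in the mixture is fixed data. The sketch below uses the values from the R setup in this thread; mixture_mean is a name introduced here for illustration.

```r
# Values from the R setup earlier in the thread.
alpha <- 0.3          # weight of the component at mu1
mu1 <- c(0, 0, 0)
mu2 <- c(1, 1, 1)

# The mean of a mixture is the weighted average of the component means.
mixture_mean <- alpha * mu1 + (1 - alpha) * mu2
print(mixture_mean)   # 0.7 0.7 0.7
```

The dense-metric fit above reports posterior means of roughly 0.2 for each y[k], well away from 0.7, which is consistent with the chains not splitting their time between the two components in the right proportions.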

You may also want to look at the step size of each chain separately. For example, one chain might spend the majority of warmup in the wider component and adapt to a large step size that will cause divergences when it tries to transition into the narrow component later on.
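One possible way to do that in R, as a rough sketch: rstan's get_sampler_params() returns one matrix per chain with a stepsize__ column, so the adapted step sizes can be read off chain by chain. The helper stepsize_by_chain and the synthetic demo list are hypothetical, made up here so the snippet runs without a fitted model.

```r
# Mean post-warmup step size per chain, given a list of per-chain
# sampler-parameter matrices (the shape get_sampler_params() returns).
stepsize_by_chain <- function(sampler_params) {
  vapply(sampler_params, function(x) mean(x[, "stepsize__"]), numeric(1))
}

# Synthetic stand-in for two chains (NOT real output from this model).
demo <- list(
  cbind(stepsize__ = rep(0.8, 5), divergent__ = c(0, 0, 1, 0, 0)),
  cbind(stepsize__ = rep(0.1, 5), divergent__ = rep(0, 5))
)
print(stepsize_by_chain(demo))

# With the fit from this thread (requires rstan):
# stepsize_by_chain(get_sampler_params(fit, inc_warmup = FALSE))
```

Chains whose step sizes differ by a large factor would suggest they adapted in different components.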


I was under the impression that plotting with

pairs(fit, condition = "divergent__")

(which is what I did) puts the divergences in the top half.

From the help for pairs.stanfit:

By default, the lower (upper) triangle of the plot contains draws with below (above) median acceptance probability. Also, if condition is not "divergent__", red points will be superimposed onto the smoothed density plots indicating which (if any) iterations encountered a divergent transition. Otherwise, yellow points indicate a transition that hit the maximum treedepth rather than terminated its evolution normally. (emphasis added)

I didn’t see any red points in your pairs() plot, which means there aren’t any divergences.
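A numeric count can complement the plot. The sketch below sums the divergent__ column that rstan's get_sampler_params() provides for each chain; divergences_by_chain and the synthetic demo list are hypothetical names and data used only so the snippet is runnable on its own.

```r
# Number of divergent transitions per chain, given a list of per-chain
# sampler-parameter matrices (the shape get_sampler_params() returns).
divergences_by_chain <- function(sampler_params) {
  vapply(sampler_params, function(x) sum(x[, "divergent__"]), numeric(1))
}

# Synthetic stand-in for two chains (NOT real output from this model).
demo <- list(
  cbind(divergent__ = c(0, 1, 0, 1)),
  cbind(divergent__ = c(0, 0, 0, 0))
)
print(divergences_by_chain(demo))   # 2 0

# With the fit from this thread (requires rstan):
# divergences_by_chain(get_sampler_params(fit, inc_warmup = FALSE))
```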

Oh yeah, I didn’t read the model right. Definitely what @betanalpha said. Something that’s actually gonna be multimodal will be hard.

Where does this mixture show up in your model? Maybe there’s a trick.

Thanks. I fixed the actual model (as I said above, by allowing more noise, which relieved the misspecification); I was just curious about this.
