Divergent transitions when sampling from Dirichlet distribution

Hello all,

I have been trying to debug a model with lots of divergent transitions. The model includes a Dirichlet prior with most of the mass concentrated on one component. In the process of trying to isolate the source of divergent transitions in my model, I found that simply sampling from a Dirichlet distribution resulted in a relatively large number of divergent transitions, but also that this depended on how I ordered the prior.

If I put the parameter with the large concentration as the first parameter in the vector, I get 2-3% divergent transitions. On the other hand, if I put the parameter with the large concentration last, there are no divergent transitions:

> ## Sample from Dirichlet with alpha = c(20, 0.1, 0.1, ..., 0.1)
> t(sapply(get_sampler_params(fit1, inc_warmup=FALSE), colMeans))
     accept_stat__ stepsize__ treedepth__ n_leapfrog__ divergent__ energy__
[1,]     0.7990522  0.2981183       3.380       12.230       0.018 17.24487
[2,]     0.8768769  0.2446349       3.560       13.575       0.026 17.23890
[3,]     0.8386421  0.2954739       3.388       12.306       0.034 17.60014
[4,]     0.8900672  0.2370209       3.590       13.564       0.041 17.18930
>
> ## Sample from Dirichlet with alpha = c(0.1, 0.1, ..., 0.1, 20)
> t(sapply(get_sampler_params(fit2, inc_warmup=FALSE), colMeans))
     accept_stat__ stepsize__ treedepth__ n_leapfrog__ divergent__ energy__
[1,]     0.8415861 0.09923680       4.133       23.700           0 17.94246
[2,]     0.9099294 0.08892594       4.387       28.132           0 17.33387
[3,]     0.9150783 0.07699482       4.487       30.936           0 18.04562
[4,]     0.8939827 0.09872973       4.347       27.772           0 17.88692

I am wondering whether this is ‘expected’ behaviour. I imagine it may be related to the ‘stick breaking’ transform used for the simplex parameterization (a rough sketch of that transform, as I understand it, is below). Is there an intuitive explanation for why it occurs, which might suggest the most efficient parameterization for more complex models? I also note the caution on page 40 of the manual that high-dimensional simplex parameters might require smaller step sizes. Any guidance on the most efficient parameterizations would be useful.
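
For reference, here is my rough R reimplementation of the stick-breaking transform as I understand it from the manual (so the details are my assumption, not Stan's internal code). The first component is broken off the stick first, which makes me suspect that putting a large concentration there changes the geometry the sampler sees:

## Rough sketch of the stick-breaking simplex transform (my reimplementation,
## not Stan internals): maps unconstrained y in R^(K-1) to the K-simplex.
stick_break <- function(y) {
  K <- length(y) + 1
  x <- numeric(K)
  stick <- 1                        # length of stick remaining
  for (k in 1:(K - 1)) {
    z <- plogis(y[k] - log(K - k))  # break proportion; y = 0 gives the uniform simplex
    x[k] <- stick * z
    stick <- stick - x[k]
  }
  x[K] <- stick                     # last component takes the remainder
  x
}

stick_break(rep(0, 9))  # returns rep(0.1, 10), the centre of the 10-simplex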

Full Stan code for the above example is pasted below.

Many thanks,
Jeff

library(rstan)

## Model 1: large concentration first (alpha = c(20, 0.1, ..., 0.1))
mod1 <- '
transformed data {
  vector[10] alpha;
  alpha[1] = 20;
  for(i in 2:10)
    alpha[i] = 0.1;
}
parameters {
  simplex[10] x;
}
model {
  x ~ dirichlet(alpha);
}
'

## Model 2: large concentration last (alpha = c(0.1, ..., 0.1, 20))
mod2 <- '
transformed data {
  vector[10] alpha;
  for(i in 1:9)
    alpha[i] = 0.1;
  alpha[10] = 20;
}
parameters {
  simplex[10] x;
}
model {
  x ~ dirichlet(alpha);
}
'

## Fit both models with rstan defaults (4 chains, 2000 iterations, half warmup)
fit1 <- stan(model_code = mod1)
fit2 <- stan(model_code = mod2)


## Sampling from Dirichlet with alpha = c(20, 0.1, 0.1, ..., 0.1)
t(sapply(get_sampler_params(fit1, inc_warmup=FALSE), colMeans))

## Sampling from Dirichlet with alpha = c(0.1, 0.1, ..., 0.1, 20)
t(sapply(get_sampler_params(fit2, inc_warmup=FALSE), colMeans))

For this case, the parameterization of the simplex that rescales gamma(alpha, 1) primitives avoids the divergence problem and the ordering arbitrariness. See

For the case you are probably actually referring to (i.e. with data), it is hard to tell in advance which parameterization will work best.
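
For concreteness, here is a sketch of that rescaled-gamma construction for the example above (my own transcription, so treat the exact code as an assumption rather than canonical): draw independent gamma(alpha[k], 1) primitives and divide by their sum, which yields a dirichlet(alpha) simplex with no ordering to choose.

mod3 <- '
transformed data {
  vector[10] alpha;
  alpha[1] = 20;
  for(i in 2:10)
    alpha[i] = 0.1;
}
parameters {
  vector<lower=0>[10] g;   // unnormalized gamma primitives
}
transformed parameters {
  simplex[10] x;
  x = g / sum(g);          // rescaled gammas are jointly dirichlet(alpha)
}
model {
  g ~ gamma(alpha, 1);     // implies x ~ dirichlet(alpha)
}
'

fit3 <- stan(model_code = mod3)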


Hi Ben,

Many thanks for the guidance again. Yes, for the full model with data, both approaches (the rescaled gamma and re-ordering the simplex so that the parameter with the most concentration comes last) reduce the number of divergent transitions to very few. But both also result in adaptation to a very small step size and slow sampling.

I guess no free lunch here. I’ll play around a bit more…

Thanks,
Jeff


Not yet—still on the to-do list, along with vectorizing our derivatives for the transforms and allowing varying lower and upper bounds in an array.

That’s a quote of a quote, but the original quote doesn’t seem to be there any more!