Computationally expensive three-component mixture model: suggestions for speed-up?

I have written a model that describes a population of N stars with a three-component mixture. The model runs fine: it shows no divergences and converges for small N (~10 stars, about 2.5 hours), but it takes prohibitively long on my full dataset (N=632; after four days it still hadn’t converged). I would be very grateful if anyone could look at the attached script and suggest anything (or point to a page in the manual) that could speed up convergence!

I have re-parameterized every parameter that I could. I tried vectorising the transformation of variables, but found that it didn’t change the run-time at all and made the code less human-readable. I store the main likelihood calculation in a vector, which I sum before adding it to target.
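The accumulation pattern is roughly the following (a simplified sketch, with a placeholder density and hypothetical names y, mu, and sigma standing in for the actual three-component term in stars.stan):

    model {
      vector[N] lp_star;                              // one log-likelihood term per star
      for (n in 1:N)
        lp_star[n] = normal_lpdf(y[n] | mu, sigma);   // placeholder for the real mixture term
      target += sum(lp_star);                         // single accumulation into the log density
    }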

stars.stan (5.8 KB)


It would be useful to know whether it takes forever because it’s using a small step-size, or because each step takes an enormous amount of computing time. What message does Stan print about the run-time of a single gradient eval?

Also, what step-size and tree-depth does Stan decide to use? You can find these with get_sampler_params in rstan or by looking at the CSV files that CmdStan writes.

Thanks for looking at my problem!

CmdStan reports that the run-time of a single gradient eval is 0 seconds (see below).

Gradient evaluation took 0 seconds
1000 transitions using 10 leapfrog steps per transition would take 0 seconds.
Adjust your expectations accordingly!

Running CmdStan’s stansummary on one of my chains returns:

                        Mean     MCSE   StdDev        5%       50%       95%  N_Eff  N_Eff/s  R_hat
lp__                -5.7e+03  9.0e+00  6.7e+01  -5.8e+03  -5.7e+03  -5.6e+03     55  5.4e-04    1.0
accept_stat__        6.8e-01  1.3e-01  3.3e-01   9.6e-03   8.4e-01   9.9e-01    6.9  6.9e-05    1.3
stepsize__           1.3e-03  1.8e-16  1.3e-16   1.3e-03   1.3e-03   1.3e-03   0.50  5.0e-06   1.00
treedepth__          8.5e+00  1.0e+00  2.7e+00   2.0e+00   1.0e+01   1.0e+01    6.9  6.8e-05    1.2
n_leapfrog__         7.8e+02  1.5e+02  4.0e+02   7.0e+00   1.0e+03   1.0e+03    7.0  6.9e-05    1.2
divergent__          3.2e-01  1.6e-01  4.7e-01   0.0e+00   0.0e+00   1.0e+00    8.3  8.2e-05    1.2
energy__             8.5e+03  9.2e+00  8.5e+01   8.4e+03   8.5e+03   8.7e+03     87  8.6e-04    1.0

What does a traceplot of lp__ look like? I’m wondering if you’re having issues similar to what I described in a separate thread, where the initial conditions are under-damped and “ring” for a long time.

Also, in your 10-star example problem, how tight were the posterior intervals of your parameters? I see your step-size is 1.3e-3 and you’re taking 780 leapfrog steps on average. But it doesn’t seem like this should take that long, given that a gradient eval takes 0 seconds ;-) (which really means less than the clock resolution on your computer, so the number isn’t that helpful).

How many samples are you asking Stan for?


I’ve attached a traceplot of lp__ below. I’m asking for 60000 warmup iterations and 40000 samples, but with thin=2. 20000 samples makes for a messy traceplot, so I’ve applied a Savitzky-Golay filter to the lp__ values and plotted that in green as well to make the behaviour a little clearer. What does ringing look like in a traceplot?

Here is the output of CmdStan’s stansummary for a chain from my 10-star example problem. The intervals don’t seem that wide.

                       Mean     MCSE   StdDev        5%       50%       95%  N_Eff  N_Eff/s  R_hat
lp__               -1.9e+01  1.6e-01  8.3e+00  -3.3e+01  -1.8e+01  -5.8e+00   2659  8.0e-01    1.0
accept_stat__       9.0e-01  1.7e-02  1.5e-01   6.1e-01   9.5e-01   1.0e+00     83  2.5e-02    1.0
stepsize__          2.6e-03  5.8e-16  4.1e-16   2.6e-03   2.6e-03   2.6e-03   0.50  1.5e-04   1.00
treedepth__         9.9e+00  4.4e-02  7.0e-01   1.0e+01   1.0e+01   1.0e+01    247  7.5e-02    1.0
n_leapfrog__        1.0e+03  1.1e+01  1.4e+02   1.0e+03   1.0e+03   1.0e+03    166  5.0e-02    1.0
divergent__         4.3e-02  1.8e-02  2.0e-01   0.0e+00   0.0e+00   0.0e+00    123  3.7e-02    1.0
energy__            6.9e+01  1.5e-01  1.1e+01   5.2e+01   6.9e+01   8.8e+01   5247  1.6e+00    1.0
fraction[1]         7.4e-01  2.8e-02  1.4e-01   4.8e-01   7.6e-01   9.4e-01     25  7.6e-03    1.0
fraction[2]         1.7e-01  2.7e-02  1.3e-01   1.6e-02   1.4e-01   4.3e-01     23  6.9e-03    1.0
fraction[3]         8.3e-02  8.3e-04  7.7e-02   4.5e-03   6.0e-02   2.4e-01   8462  2.6e+00    1.0
L                   2.4e-02  1.0e-04  4.6e-03   1.8e-02   2.4e-02   3.2e-02   1967  5.9e-01    1.0
vRdisp_raw          7.1e-01  1.6e-02  1.7e-01   4.7e-01   6.9e-01   1.0e+00    120  3.6e-02   1.00
vphidisp_raw        7.3e-01  1.5e-02  1.8e-01   4.7e-01   7.1e-01   1.1e+00    134  4.1e-02    1.0
vzdisp_raw          2.8e-01  7.4e-03  9.6e-02   1.5e-01   2.6e-01   4.6e-01    172  5.2e-02    1.0
vdisk_raw           1.3e-03  7.7e-03  1.0e+00  -1.7e+00   1.0e-02   1.6e+00  16964  5.1e+00   1.00
Rsolar_raw          1.3e-02  6.5e-03  9.9e-01  -1.6e+00   2.2e-02   1.6e+00  23476  7.1e+00    1.0
Usolar_raw          1.7e-01  1.1e-02  9.7e-01  -1.4e+00   1.7e-01   1.8e+00   8318  2.5e+00   1.00
Vsolar_raw         -2.1e-01  1.3e-02  9.1e-01  -1.7e+00  -2.2e-01   1.3e+00   5297  1.6e+00    1.0
Wsolar_raw          5.6e-01  1.2e-02  9.8e-01  -1.0e+00   5.6e-01   2.2e+00   7217  2.2e+00    1.0

Yikes!

So your average n_leapfrog is about 1000, which means that for each of your Stan samples you’re computing a 1000-point HMC trajectory and then selecting one of those points (as opposed to, say, JAGS or the like, where each sample corresponds to just one function eval).

So when you ask it for 100,000 samples, it’s calculating 100 MILLION points along HMC trajectories. The point of HMC is that it moves around really fast, and so you need fewer samples.

When it comes to HMC, I basically never ask for anything like the number of samples you’re asking for. For complicated models I often work with 100 to 1000 draws. For simple models I usually use Stan’s default of 2000. Maybe once I’m completely sure my model makes sense I’ll ask for a really long run, but even then, “really long” typically means far fewer than 100,000 samples.

Your inference seems good in the sense that Rhat = 1.0 for most parameters; the concern I have is that N_eff is small for vRdisp_raw, vphidisp_raw, etc. even after 100k samples. But it’s not that small! 120 effective samples still gives you pretty good inference on those parameters.

The next concern is the “lp__ gets stuck” issue. This indicates that HMC got into a region of parameter space that it couldn’t get out of. I’d be really worried if the stuckness occurred well outside the main band of lp__, because that would indicate you hadn’t yet found the real typical set, but that’s not what happens here: it gets stuck in certain corners of the typical set. I’d look at which of your parameters is getting stuck, and whether the inference is sensible in the stuck region. If the stuck region doesn’t make sense, then perhaps your priors need modification to keep Stan out of those regions. If the stuck regions do make sense, then perhaps you need to investigate what’s going on in those regions model-wise, and maybe even revise your model to allow for more scientific possibilities than you considered at first.

If you revise your model and that eliminates the stuckness on the 10-star problem, I’d then try the 600-star problem with, say, 2000 samples and 1000 warmup, look at your N_eff and Rhat, and see whether things make sense at those much smaller sample numbers.


Also, all your parameters’ sd values are O(1), your step-size is O(1e-3), and your trajectory length is O(1000), so step-size × trajectory length / sd is O(1): each trajectory only carries you about one posterior standard deviation (1.3e-3 × 780 ≈ 1).

Stan is using the full treedepth. I’d recommend trying a treedepth of 11 or 12 and then running 2000 samples. That’s roughly 4 × 2000/100000 = 0.08 times as much computing time as your current run, but it may give you better N_eff on vRdisp, vphidisp, etc.


That is a very helpful response, thanks for taking the time!

I was using much smaller numbers of samples until I found the advice in the manual suggesting that you check convergence and, if the chains haven’t converged, double everything. I was a little concerned that everyone else on the forum seemed to be using ~1000 samples…

I’ll see if I can figure out where the HMC is getting stuck and try increasing the tree depth.

So, as far as your lp__ plot goes: by the time you hit 5000 samples or so, it looks pretty well converged. The best convergence diagnostic we have is a combination of the lp__ plot and the Rhat calculation.

“Ringing” in the lp__ plot looks like the lp__ shooting upwards, then overshooting and coming back down, then up again, and down again, like an oscillation (a classic “mass on a spring” damped oscillation; a noisy version of this: https://goo.gl/images/LUsS35). It’s hard to tell, given that the whole thing probably happens in the first 10% of your plot, but look for that. If it happens, you can eliminate it by choosing better initialization points. One way to get good initialization points is to take draws from the vb (variational inference) system. Another is to run your model through Stan for a small number of iterations with a short treedepth, to bleed out the excess potential energy quickly, and then take the final sample from that short run as the input to a longer run.


It’s not the number of draws in the sample that matters; it’s the effective sample size (reported as n_eff in most of the interfaces). The MCMC standard error for estimating the mean of a parameter theta is sd(theta) / sqrt(n_eff), so you can use that as a guide to how much precision you want.
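As a worked example from your 10-star table above: vRdisp_raw has sd ≈ 0.17 and n_eff ≈ 120, so the standard error on its posterior mean is about 0.17 / sqrt(120) ≈ 0.016, which is exactly what the MCSE column reports.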

How are you measuring convergence? There are issues with mixture models that make them hard to identify, so multiple chains can get stuck in different but identical (label-switched) modes, or there may be more meaningful multimodality. Michael Betancourt wrote a case study on mixture models (web site >> users >> documentation >> case studies) that’s worth reading. One of the things it proposes is using an ordering constraint to identify the mixture components.
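The core of that identification trick looks roughly like the sketch below; this is a generic K-component normal mixture with hypothetical data y, not your model or parameterization:

    data {
      int<lower=1> N;
      int<lower=1> K;
      vector[N] y;
    }
    parameters {
      simplex[K] theta;             // mixing proportions
      ordered[K] mu;                // ordered means break the label-switching symmetry
      vector<lower=0>[K] sigma;
    }
    model {
      mu ~ normal(0, 10);
      sigma ~ normal(0, 5);
      for (n in 1:N) {
        vector[K] lps = log(theta);
        for (k in 1:K)
          lps[k] += normal_lpdf(y[n] | mu[k], sigma[k]);
        target += log_sum_exp(lps);
      }
    }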

The only other thing that’s going to be very slow in your model is

    parallax ~ normal(1. ./ r, errparallax);

because parallax is a parameter and so is r. You can replace this with a non-centered version

  parallax_std ~ normal(0, 1);
  parallax = parallax_std * errparallax + inv(r);

declaring parallax_std as a parameter and parallax as a transformed parameter.
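Spelled out by block, that would look something like this sketch, assuming parallax, r, and errparallax are all vectors of length N (adjust the types and operators to match your actual declarations):

    parameters {
      vector[N] parallax_std;       // standard-normal innovations
      // ... r and your other parameters ...
    }
    transformed parameters {
      // non-centered definition; implies parallax ~ normal(inv(r), errparallax)
      vector[N] parallax = parallax_std .* errparallax + inv(r);
    }
    model {
      parallax_std ~ normal(0, 1);
      // ... the rest of the model, using parallax exactly as before ...
    }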

It might help a bit to standardize other variables, writing

    L_std ~ normal(0, 1);
    L = 0.5 + 0.5 * L_std;

in the same way (with L_std a parameter and L a transformed parameter). That way, Stan starts the adaptation off on the right scale.


I managed to get much nicer behaviour by sampling in vrad, pmra, and pmdec instead of pecvR, pecvphi, and pecvz. I think it makes sense that sampling is easier this way round, because the observations have very small errors while the model parameters have quite an uninformative prior. It also helps that writing the model down this way round eliminates two-thirds of the matrix operations.

My original script ran for four days and hadn’t converged. My new script takes 20 minutes and has definitely converged. Thank you both for your help! Even though the “fix” ended up being a rewrite of the model, it was very helpful to have my expectations reset about what a reasonable number of samples is for HMC.
