Debugging a "stuck" sampler on a complex-ish model

I’m currently having some issues with Stan’s NUTS sampler. I’ve encountered them before, but last time I think I fixed the issue by either hardening the model against outputting NaNs or “fixed” it by fiddling with random seed values. This time around, I’m not seeing anything obvious in the logs that could point me towards a fix. And now that my code’s public, it’s much easier to ask for help.

My “star” Stan model is run in parallel for eight demographics. Seven of them complete without error, but STDOUT for the eighth’s sampler portion contains this oddity:

# [A TONNE OF OUTPUT OMITTED]
8093.835: Chain [1] Iteration: 10000 / 10000 [100%]  (Sampling)
8093.843: 
8103.613:  Elapsed Time: 8086.29 seconds (Warm-up)
8103.613:                7.494 seconds (Sampling)
8103.613:                8093.79 seconds (Total)
8103.613: 
# [STILL MORE OUTPUT OMITTED]
8898.718: Chain [2] Iteration: 10000 / 10000 [100%]  (Sampling)
8899.392: 
8919.936:  Elapsed Time: 8227.19 seconds (Warm-up)
8919.936:                672.141 seconds (Sampling)
8919.936:                8899.33 seconds (Total)
8919.936: 
# [SKIPPING PAST CHAIN 3]
9952.118: Chain [4] Iteration: 10000 / 10000 [100%]  (Sampling)
9953.434: 
9953.439:  Elapsed Time: 8695.15 seconds (Warm-up)
9953.439:                1258.23 seconds (Sampling)
9953.439:                9953.38 seconds (Total)
9953.439: 
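The sampling times in the blocks above can also be compared mechanically. This is a minimal sketch that assumes the log keeps the exact "seconds (Sampling)" wording shown here; the helper names and the factor-of-ten threshold are illustrative choices, not part of CmdStan or of the author's pipeline.

```python
import re
from statistics import median

# The "Elapsed Time" blocks quoted above (chain 3 was omitted in the post).
LOG = """\
8103.613:  Elapsed Time: 8086.29 seconds (Warm-up)
8103.613:                7.494 seconds (Sampling)
8919.936:  Elapsed Time: 8227.19 seconds (Warm-up)
8919.936:                672.141 seconds (Sampling)
9953.439:  Elapsed Time: 8695.15 seconds (Warm-up)
9953.439:                1258.23 seconds (Sampling)
"""

def sampling_times(log_text):
    """Return every '(Sampling)' duration in the log, in order of appearance."""
    return [float(m) for m in re.findall(r"([\d.]+) seconds \(Sampling\)", log_text)]

def suspiciously_fast(times, ratio=10.0):
    """Positions of durations more than `ratio` times shorter than the median.

    The positions index the list of durations, not Stan's chain numbers.
    """
    mid = median(times)
    return [i for i, t in enumerate(times) if t * ratio < mid]

times = sampling_times(LOG)
print(times)
print(suspiciously_fast(times))
```

Only the first duration (chain 1) is flagged with this threshold.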

It’s a bad sign when one chain completed two orders of magnitude faster than the others! My code discards the raw posterior CSVs output by Stan’s NUTS sampler, but I can still isolate one chain’s output in the compressed Parquet file it keeps around.

>>> posterior = pd.read_parquet( "excess_deaths.Age at time of death, 65 to 84 years.Males.parquet" )
>>> posterior['exp_amp'][:32:4]  # nothing special, I get the same behaviour from other parameters
0     987.64845
4     987.64845
8     987.64845
12    987.64845
16    987.64845
20    987.64845
24    987.64845
28    987.64845
Name: exp_amp, dtype: float64
>>> posterior['exp_amp'][1:32:4]
1     1162.0122
5     1240.2681
9     1294.4327
13    1428.2402
17    1114.5056
21    1081.0909
25    1453.9135
29    1326.2621
Name: exp_amp, dtype: float64

Chain one seems to have gotten “stuck,” pumping out the same value over and over again. One possible explanation is a NaN in the posterior, but…

>>> sum( posterior.isna().sum() )
0

… there are none to be found. The gradient may have become a NaN value anyway, but if that’s the case I’d expect the log probabilities to be NaN for that chain, and the Parquet files include the lp__ column. I’d also expect a much shorter burn-in period than is observed. There are some infinite values in the posterior, but that’s because one variable in “generated quantities” returns inf if nothing’s amiss. Chains two through four have ample inf values, and yet nothing’s wrong there.
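A "stuck" chain can also be detected without eyeballing slices. The sketch below assumes the draws are interleaved round-robin by chain, as the `[:32:4]` slicing earlier suggests; the function names are hypothetical, and the sample values are borrowed from the excerpt but rearranged into two chains for brevity.

```python
def repeat_fraction(draws):
    """Fraction of consecutive draw pairs that are exactly equal."""
    if len(draws) < 2:
        return 0.0
    repeats = sum(1 for a, b in zip(draws, draws[1:]) if a == b)
    return repeats / (len(draws) - 1)

def per_chain_repeat_fraction(values, n_chains=4):
    """Split round-robin interleaved draws into chains and score each one."""
    return [repeat_fraction(values[c::n_chains]) for c in range(n_chains)]

# Two-chain toy example: the first chain is frozen, the second moves every draw.
values = [987.64845, 1162.0122, 987.64845, 1240.2681, 987.64845, 1294.4327]
print(per_chain_repeat_fraction(values, n_chains=2))
```

With the real file, something like `per_chain_repeat_fraction(posterior['exp_amp'].tolist())` would be the equivalent call. A fraction near 1.0 means nearly every proposal was rejected.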

The STDERR log of the sampler run isn’t illuminating, either.

$ bzcat excess_deaths.Age\ at\ time\ of\ death\,\ 65\ to\ 84\ years.Males.sampler.err.bz2 | grep Exception
0.027: Exception: normal_lpdf: Scale parameter[5] is -4.18656e+28, but must be positive! (in 'excess_deaths.stan', line 646, column 8 to column 130)
0.031: Exception: normal_lpdf: Scale parameter[4] is -1.21467e+29, but must be positive! (in 'excess_deaths.stan', line 646, column 8 to column 130)
0.031: Exception: normal_lpdf: Scale parameter[2] is -9.44079e+28, but must be positive! (in 'excess_deaths.stan', line 646, column 8 to column 130)
0.031: Exception: normal_lpdf: Scale parameter[393] is -7.26158e+28, but must be positive! (in 'excess_deaths.stan', line 646, column 8 to column 130)
0.032: Exception: normal_lpdf: Scale parameter[2] is -4.16264e+28, but must be positive! (in 'excess_deaths.stan', line 646, column 8 to column 130)
0.033: Exception: normal_lpdf: Scale parameter[3] is -29.9696, but must be positive! (in 'excess_deaths.stan', line 646, column 8 to column 130)
0.033: Exception: normal_lpdf: Scale parameter[1] is -7.19344e+27, but must be positive! (in 'excess_deaths.stan', line 646, column 8 to column 130)
0.033: Exception: normal_lpdf: Scale parameter[356] is -1.43916e+29, but must be positive! (in 'excess_deaths.stan', line 646, column 8 to column 130)
0.035: Exception: normal_lpdf: Scale parameter[393] is -1.22123e+29, but must be positive! (in 'excess_deaths.stan', line 646, column 8 to column 130)
0.035: Exception: normal_lpdf: Scale parameter[4] is -198.739, but must be positive! (in 'excess_deaths.stan', line 646, column 8 to column 130)
0.036: Exception: normal_lpdf: Scale parameter[421] is -2091.82, but must be positive! (in 'excess_deaths.stan', line 646, column 8 to column 130)
0.080: Exception: normal_lpdf: Scale parameter[20] is -19.7683, but must be positive! (in 'excess_deaths.stan', line 646, column 8 to column 130)
0.224: Exception: Found a inf value in 'baseline', rejecting sample. (in 'excess_deaths.stan', line 331, column 1 to column 75)
4.948: Exception: Found a inf value in 'baseline', rejecting sample. (in 'excess_deaths.stan', line 331, column 1 to column 75)
5.624: Exception: Found a inf value in 'flu_dispersion', rejecting sample. (in 'excess_deaths.stan', line 338, column 1 to column 87)
5.626: Exception: Found a inf value in 'exp_amp', rejecting sample. (in 'excess_deaths.stan', line 333, column 1 to column 73)
9953.439: Exception: normal_lpdf: Scale parameter[314] is -2.38664e+06, but must be positive! (in 'excess_deaths.stan', line 646, column 8 to column 130)

That may look nasty, but other demographics can have very similar exceptions in their output. Nothing stands out to me.

I’ve looked around in this forum for others running into the same issue, and while there is some precedent, their problem was usually systemic and could be solved by eyeballing the code. My model is much too complex for that. As pointed out earlier, this problem also only shows up for one demographic out of eight, and only popped up after months of A-OK operation; but what I haven’t yet mentioned is that I have a “variant” system that currently spams those eight demographics 44 more times with slightly tweaked parameters. The only significant difference between all those successful prior runs and the current one is the updated datasets (at the start of the pipeline, it downloads the latest versions off the public internet). The only code changes have been isolated to the Python chart-generating code, and that’s well downstream. This points to something rare and intermittent causing the issue, be it within my model or CmdStan itself.

Thankfully, I put a lot of thought into how to replicate my work, which also makes replicating this bug a snap:

$ docker pull hjhornbeck/excess_deaths_canada_2010_2023:DEBUG_IMAGE
$ docker run -it --rm hjhornbeck/excess_deaths_canada_2010_2023:DEBUG_IMAGE /bin/bash
user@8c639c5cd547:~/excess_deaths_canada_2010_2023$ python3 src/model_excess_deaths.py --stripe-id 5 --stripe-count 8 --debug

If all three commands run successfully, then approximately 9,953 seconds later you should have a finished sampler run with a broken posterior. Change “5” to any other number between 0 and 7, and you’ll be working on another demographic (derived/categories.csv provides a handy list); eliminate all the options after the Python source file, and up to 32 cores of your computer will start number-crunching on every demographic. If you want to save the original posterior CSVs, the function that manages the sampler run has a parameter to handle that, but you’ll have to either manually edit the code in src/model_excess_deaths.py or enable pdb and intercept the call just before it happens.

Just want to examine the aftermath, without the wait?

$ docker pull hjhornbeck/excess_deaths_canada_2010_2023:DEBUG_POST_IMAGE

Docker images are mounted to the filesystem when run, so with root access you can poke at what’s inside from the outside. If you’d prefer to work from the inside, it’s not hard to write a Dockerfile that adds whatever tools you need to the existing image.

This is usually a sign that there are missing constraints in your model and it doesn’t have support over all the parameter values that satisfy the constraints.

Agreed. It usually means it has too large of a step size and it’s just rejecting every iteration (which seems to be the case here), or the other chains have found a much better lower step size to explore areas of high curvature.

Unfortunately, gradients can be NaN even when log densities are fine.

Bingo. This is a bug in the model implementation. Models need to be coded in Stan so that scales are never negative. So you just need to figure out what this parameter is and why it’s becoming negative. If it’s the result of something like a regression, you just need a log link function, so apply the inverse, exp(). If it’s just a raw parameter, it’s missing <lower=0> on the declaration.

I figured, but it’s good to get confirmation of this.

Bingo. This is a bug in the model implementation. Models need to be coded in Stan so that scales are never negative.

I agree in general, but I’m pretty sure that’s not the problem here. For one thing, here’s what I get when I grep for “but must be” for the men 85+ demographic, which doesn’t “stick:”

0.029: Exception: normal_lpdf: Scale parameter[66] is -8.94196, but must be positive! (in 'excess_deaths.stan', line 649, column 8 to column 144)
0.031: Exception: normal_lpdf: Scale parameter[42] is -0.0627468, but must be positive! (in 'excess_deaths.stan', line 649, column 8 to column 144)
0.031: Exception: normal_lpdf: Scale parameter[11] is -0.0366388, but must be positive! (in 'excess_deaths.stan', line 649, column 8 to column 144)
0.033: Exception: normal_lpdf: Scale parameter[37] is -1.12927, but must be positive! (in 'excess_deaths.stan', line 649, column 8 to column 144)
0.033: Exception: normal_lpdf: Scale parameter[39] is -0.313821, but must be positive! (in 'excess_deaths.stan', line 649, column 8 to column 144)
0.033: Exception: normal_lpdf: Scale parameter[43] is -0.0668501, but must be positive! (in 'excess_deaths.stan', line 649, column 8 to column 144)
0.033: Exception: normal_lpdf: Scale parameter[15] is -0.19707, but must be positive! (in 'excess_deaths.stan', line 649, column 8 to column 144)
0.062: Exception: normal_lpdf: Scale parameter[14] is -0.574631, but must be positive! (in 'excess_deaths.stan', line 649, column 8 to column 144)
2.221: Exception: excess_deaths_model_namespace::log_prob: pred_pc[2] is -0.00610573, but must be greater than or equal to 0.000000 (in 'excess_deaths.stan', line 343, column 4 to column 31)
2.298: Exception: normal_lpdf: Scale parameter[1] is -6.72626e+246, but must be positive! (in 'excess_deaths.stan', line 646, column 8 to column 130)
2.381: Exception: normal_lpdf: Scale parameter[1] is -8.51859e+103, but must be positive! (in 'excess_deaths.stan', line 646, column 8 to column 130)
9448.843: Exception: normal_lpdf: Scale parameter[1] is -5.28482e+124, but must be positive! (in 'excess_deaths.stan', line 646, column 8 to column 130)

Women 85+ looks worse; there’s even a -NaN value in the same line of code. NUTS can handle these invalid inputs, in my experience, provided they don’t happen during initialization and aren’t a common occurrence otherwise. For another, the sticking happens at every one of the thousand samples in the output, which implies I should get a thousand of roughly the same exception near the end. That’s not visible in the logs.

If it’s just a raw parameter, it’s missing <lower=0> on the declaration.

One of the components is indeed missing a <lower=0> in its declaration, but it’s not a raw parameter. Here’s the relevant portion of code:

    // ... in transformed parameters:
    vector<lower=0>[N] pred_pc;               // the model's prediction for per-capita deaths
    vector[N] dp_deaths_pc;                   // drug poisoning deaths, per-capita, interpolated

    // ... skipping ahead ...

        // determine the likelihood of the model, pre- and post-COVID
        vector[C_idx-1] prop_pre = pred_pc[:C_idx-1] + dp_deaths_pc[:C_idx-1];
        deaths_pc[:C_idx-1]      ~ normal( prop_pre + flu_pc[:C_idx-1] + excess_pc[:C_idx-1], sigma_fix + (sigma_poi.*prop_pre) );
        // ... ^^^ the above is line 646

dp_deaths_pc is almost certainly the cause of that exception. It’s generated by interpolating a cubic spline with values pulled from a multivariate Gaussian summary of another model’s posterior. Because the model favours one end of the spline being near zero at one extreme, and the next knot down has an upward slope, the spline can sometimes dip below the zero line even though all knot values and second derivatives are positive.
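The claim that a spline can go negative between positive knots is easy to check numerically. The sketch below evaluates a single cubic spline segment in the standard second-derivative form on the unit interval; the knot values and second derivatives are made-up numbers chosen to show the effect, not values from the model.

```python
def spline_segment(t, y0, y1, m0, m1):
    """Cubic spline segment on [0, 1] with knot values y0, y1 and
    second derivatives m0, m1 at the two knots."""
    return (m0 * (1 - t) ** 3 / 6 + m1 * t ** 3 / 6
            + (y0 - m0 / 6) * (1 - t) + (y1 - m1 / 6) * t)

# Hypothetical numbers: left knot near zero, right knot much higher,
# and every knot value and second derivative positive.
y0, y1, m0, m1 = 0.01, 1.0, 12.0, 0.6

grid = [i / 100 for i in range(101)]
curve = [spline_segment(t, y0, y1, m0, m1) for t in grid]
lowest = min(curve)
print(f"endpoints: {curve[0]:.3f}, {curve[-1]:.3f}; minimum on segment: {lowest:.3f}")
```

The segment hits both knots exactly, yet its minimum lies below zero, because the large curvature at the near-zero knot forces a negative slope there.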

So why didn’t I add that lower constraint? Over a year ago I ran into a strange problem: adding parameter constraints would either cause failures to initialize the model, or R-hat to spike upwards. My best guess was that adding a parameter constraint both changed the values created by Stan’s default initializer, and added an additional term to handle the change of variable. Though it wasn’t much of an effect, one or both of those might have been enough to toss the NUTS sampler into a nasty part of the parameter space. It made me shy to add constraints for a bit; switching to initialization by Pathfinder solved the initialization woes, and so I restored the bounds for most parameters. dp_deaths_pc’s problems were unrelated to that, though, and the experience of running the model without bounds led me to benchmark what was faster: adding those bounds, or catching the negative values post-hoc by analyzing the cubic spline in generated quantities and flagging whether it was always-positive via an integer. In practice, the post-hoc test turned out to be faster, as for most demographics the failure rate is under 2% (the biggest exception is men 44 and under, where it’s a whopping 34%).

Flipping through the logs shows I made that change almost a year ago, however. Since then I’ve done two full audits of that model’s code and greatly boosted performance. While I’m pretty sure adding those constraints won’t fix this “stuck” sampler issue (more on that in my next comment), it’s about time I revisited that design decision to see if it still holds up.

Apologies for wandering away from this thread for a bit, but I was waiting for some sampler runs to finish. Alas, I botched the launch so they were running at maybe 25% speed, so I’ll have to do without.

In the past few days I’ve done a bit more diving on this, and while I don’t have a smoking gun I think I’ve got a good idea of where the problem lies. The step size looked odd for the “stuck” chain, though in this case it seemed too small rather than too big. I tweaked my code so I could optionally save the warm-up samples, and started charting the “extra” parameters that CmdStan slips into the posterior, like accept_stat__ and stepsize__.
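For readers without the charts, the same two diagnostics can be summarized straight from a chain's Stan CSV output. This is a rough sketch using only the standard library; it assumes the usual layout of `#` comment lines around one header row and the data rows, and the function name and the toy CSV strings are made up for illustration.

```python
import csv
from statistics import mean

def sampler_diagnostics(csv_text):
    """Mean accept_stat__ and last stepsize__ over every data row present."""
    # Drop blank lines and the '#' comment lines Stan writes around the data.
    rows = [ln for ln in csv_text.splitlines() if ln and not ln.startswith("#")]
    accept, stepsize = [], []
    for row in csv.DictReader(rows):
        accept.append(float(row["accept_stat__"]))
        stepsize.append(float(row["stepsize__"]))
    return mean(accept), stepsize[-1]

# Toy inputs, not real sampler output.
CHAIN_STUCK = """# toy example
lp__,accept_stat__,stepsize__
-10.0,0.0,0.001
-10.0,0.0,0.001
"""
CHAIN_OK = """# toy example
lp__,accept_stat__,stepsize__
-9.5,0.9,0.02
-9.1,0.8,0.02
"""
print(sampler_diagnostics(CHAIN_STUCK))
print(sampler_diagnostics(CHAIN_OK))
```

If warm-up draws were saved to the same file, they are included in the average, so the rows would need slicing first to separate the two phases.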

I’ve made the problematic chain less transparent than the rest. As the timing data suggested, the acceptance statistic was quite healthy during the warm-up phase, but plunged to a solid 0% once sampling started.

You can see the different warm-up phases of Stan’s NUTS implementation here, at least to some extent. The “stuck” chain’s step size is right in line with those of all but one other chain during the warm-up, but winds up the smallest of the bunch by quite a distance while sampling.

This zooms in on phase III of the warm-up. The “stuck” chain wanders around, which is in line with what we expect, but near the end it happens to plunge downwards. This seems to skew the sampling phase’s step size downward.

My best guess, then, is that at the end of the warm-up phase one of the chains wandered into a narrow local maximum. This led to phase III cranking down the step size to stick within those bounds, and after it finished it happened to pick a fixed value too small to allow the sampler to escape. Thus the sampling phase was doomed to repeat the same value over and over again. Note in both of the prior figures some chains had unusually low step sizes, as well, but the accept_stat__ chart shows they didn’t “stick.”

I did a sampler run where I manually rigged stepsize_jitter to be a small value above zero, keeping everything else constant. You can see another chain had an even smaller step size during most of the warm-up phase, but all chains wound up back in the “safe zone” during the sampling phase. It’s not conclusive proof, as I’d still like to see a run where one chain had a small step size during the sampling phase, but for now it suggests a couple of potential fixes.

Tweaking how phase III settles on a step size might help, but given how rare this problem is you could just as well argue it’s working A-OK: I just happen to have a messy model and a bad random number sequence that pushes NUTS into a rough part of the parameter space. Alternatively, after staring at the chart for lp__ I’ve become convinced I take waaaay too many warm-up samples. The likelihoods settle into a good ol’ fuzzy caterpillar somewhere around the 1000 mark, but to be cautious I’ve compromised on 4000 samples. With that shorter warm-up the stuck chain has indeed freed itself, though again this is more of a band-aid than an actual fix. Thirdly, I’ve tweaked my code to allow users to modify stepsize_jitter, on the theory that the non-constant step size during the warm-up phase is why other chains didn’t “stick” despite having similar step sizes.

For the moment I’m doing sampler runs with stepsize_jitter locked to zero, to suss out how rare this problem is. After that, I may make 0.01 the “official” default for my code.
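To give a sense of scale for a jitter of 0.01, here is a small sketch of step-size jitter as commonly described: each transition perturbs the nominal step size uniformly within plus or minus the jitter fraction. That formula is an assumption about how the setting behaves rather than something stated in this thread, and the nominal step size and seed are arbitrary.

```python
import random

def jittered_step_size(nominal, jitter, rng):
    """Step size drawn uniformly from nominal * [1 - jitter, 1 + jitter)."""
    return nominal * (1.0 + jitter * (2.0 * rng.random() - 1.0))

rng = random.Random(1729)   # arbitrary seed, for reproducibility only
nominal = 0.002             # hypothetical adapted step size
draws = [jittered_step_size(nominal, 0.01, rng) for _ in range(1000)]
print(min(draws), max(draws))
```

Under this assumption a jitter of 0.01 moves the step size by at most 1% in either direction, and a jitter of zero leaves it fixed for the whole sampling phase.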

FYI, I got back the results of adding that constraint in. This isn’t a clean test, as I had also lowered the warm-up sample count to 4000 (it was 9000 before), but for the men 65-84 demographic:

>>> summary = pd.read_csv('posteriors/excess_deaths.Age at time of death, 65 to 84 years.Males.summary.csv.bz2',comment='#')
>>> summary.R_hat.describe()
count    11081.000000
mean         2.890817
std          0.559878
min          1.041669
25%          2.504507
50%          2.962739
75%          3.297728
max          4.358678
Name: R_hat, dtype: float64

Ouch! The error log during the sampler run went from having 13 exceptions to 19,923, most of which look like this:

2131.888: Exception: excess_deaths_model_namespace::log_prob: dp_deaths_pc[450] is -1.69055e-07, but must be greater than or equal to 0.000000 (in 'excess_deaths.stan', line 344, column 4 to column 36)

What’s likely happening is that NUTS’s pure HMC portion is prone to wandering into sections of the parameter space where the cubic spline in dp_deaths_pc has negative values. With that constraint in place, NUTS immediately stops the trajectory, preventing it from properly exploring the space.

Without that constraint, NUTS carries on and later adds pred_pc to dp_deaths_pc; since the former is usually larger than the latter, the standard deviation is usually positive and exceptions on line 646 are rare. dp_deaths_pc must still be positive-everywhere for the sample to be valid, but my checks for that are in generated quantities, after HMC has long since finished. Thus NUTS can temporarily wander into invalid parts of the sample space while generating a sample, giving it much more freedom to explore. Even if it picks an invalid sample, I can filter those out after-the-fact. That does mean I wind up with the occasional negative standard deviation on line 646, but Stan’s NUTS implementation doesn’t seem to mind.
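The after-the-fact filtering described here amounts to dropping every draw whose validity flag is unset. A minimal sketch, assuming each draw carries an integer flag from generated quantities; the column name `dp_always_positive` and the toy draws are hypothetical, not taken from the model.

```python
def split_valid(draws, flag="dp_always_positive"):
    """Return (valid draws, failure rate); a draw is valid when its flag is 1."""
    valid = [d for d in draws if d[flag] == 1]
    failure_rate = 1.0 - len(valid) / len(draws)
    return valid, failure_rate

# Toy draws: the second one represents a spline that dipped below zero.
draws = [
    {"exp_amp": 1162.0, "dp_always_positive": 1},
    {"exp_amp": 1240.3, "dp_always_positive": 0},
    {"exp_amp": 1294.4, "dp_always_positive": 1},
    {"exp_amp": 1428.2, "dp_always_positive": 1},
]
valid, failure_rate = split_valid(draws)
print(len(valid), failure_rate)
```

The failure rate computed this way is the quantity that sits under 2% for most demographics, per the earlier comment.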

That does mean I misremembered this bit, unfortunately:

… and the experience of running the model without bounds led me to benchmark what was faster: adding those bounds, or catching the negative values post-hoc by analyzing the cubic spline in generated quantities and flagging whether it was always-positive via an integer. In practice, the post-hoc test turned out to be faster …

Running with the constraint in place is actually faster; however, the sample it leads to is terrible. Right conclusion, wrong justification.

I’ve completed those test runs, and have an idea of how common this problem is. I’ve tested 1,504 possible combinations (47 variants times eight demographics times four chains), and of those this issue popped up twice. The first is the one I mentioned at the top of this post, the second can be recreated via this diff:

diff --git a/original/models.yaml b/parameters/models.yaml
index f18131d..023838c 100644
--- a/original/models.yaml
+++ b/parameters/models.yaml
@@ -62,11 +62,11 @@ excess_deaths:
 
   stan_params:
     version:       2.37.0        # (2.37.0)
-    iter_warmup:   9000          # (9000)
+    iter_warmup:   4000          # (9000)
     iter_sampling: 1000          # (1000)
     max_treedepth:   12          # (12)
     adapt_delta:      0.95       # (0.95)
-    seed:           120          # (120)
+    seed:          1729          # (120)
     threads:          4          # (4)
     chains:           4          # (4)
     sig_figs:         8          # (8)

This time around, it’s chain 4 for the women 85+ demographic. I didn’t check the actual posterior samples yet, but the other symptoms are there: two orders of magnitude faster completion, NaN for all R-hat values, nothing odd in STDERR or STDOUT. I did spot some other timing anomalies, where the warmup or sampling phase for one chain would take twice or half as long as the others, but they don’t look related.
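For scale, the counts quoted above (two stuck chains among 47 variants, eight demographics and four chains) translate into rates as follows. This is plain arithmetic on the numbers in the thread, with no independence assumption between chains.

```python
variants, demographics, chains = 47, 8, 4
stuck_chains = 2

total_chains = variants * demographics * chains   # individual chains tested
total_runs = variants * demographics              # four-chain sampler runs

per_chain = stuck_chains / total_chains
per_run = stuck_chains / total_runs               # each failure hit a different run

print(f"{stuck_chains} of {total_chains} chains stuck: {per_chain:.2%}")
print(f"{stuck_chains} of {total_runs} runs affected: {per_run:.2%}")
```

A per-run rate is the more practical number, since one stuck chain is enough to spoil the R-hat values for the whole run.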

I’m currently gathering posterior samples from the other failure, as well as another full run with stepsize_jitter = 0.01. The latter will take about four days, alas.