Diagnosing posterior geometry with divergences

Occasionally I see a model with multiple divergences, but plotting the divergences in parameter space doesn’t seem to tell me anything. I wonder whether in some of these cases the divergences actually do concentrate in a particular region of parameter space, but HMC is working well enough so that when Stan samples along the trajectory that terminated in a divergence, it frequently returns a location in parameter space that is far away from where the divergence occurred.

If that makes any sense, then I wonder whether it would also make sense to provide a non-default option to, in the event of a divergence, forgo standard sampling along the trajectory and instead return a point that is at most some fixed smallish number of leapfrog steps from where the divergence actually got flagged. @betanalpha
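
Conceptually I’m imagining something like this (just a sketch; the step count and all names are hypothetical, and the “rewind” exploits the reversibility of the leapfrog integrator):

```python
# Hypothetical sketch of the proposed option: on a divergence, instead of
# sampling from the whole trajectory, step back a fixed small number of
# leapfrog steps from the flagged state and report that point. Leapfrog is
# reversible, so "rewinding" is forward integration with negated momentum.
def point_near_divergence(q_div, p_div, leapfrog, n_back=4):
    q, p = q_div, -p_div      # flip the momentum to integrate backwards
    for _ in range(n_back):   # n_back = the "fixed smallish number"
        q, p = leapfrog(q, p)
    return q                  # a point close to where the divergence was flagged
```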

Edit: this option would need to kick in only after warmup is complete, I think.


Yes, that does sound plausible to me.

It’s a Markov chain: the next transition starts from the last sample. If you pick a diverging sample instead of a proper sample then the chain must continue from the difficult region and is more likely to get stuck. I think the divergent samples must go into a separate diagnostic output stream.
The other question is, what is a good “fixed smallish number of leapfrog steps”? The divergence detection heuristic is somewhat arbitrary, and even in theory a divergence does not have a precise location. I guess the easiest solution would be to just pick a random sample from the rejected half of the trajectory. The sampler already records candidate samples as it builds the trajectory; when the trajectory blows up you could print the latest candidate instead of just discarding it, roughly as in the sketch below.
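
Very roughly, I mean something like this (a linear-trajectory caricature in Python, not Stan’s actual multinomial NUTS; `leapfrog` and `neg_H` are stand-ins for the integrator and the negative Hamiltonian):

```python
import numpy as np

# Caricature of progressive multinomial sampling along a growing trajectory.
# The sampler keeps a single running candidate; when the trajectory blows up,
# `cand` is the "latest candidate" that could be emitted as a diagnostic
# instead of being discarded.
def build_trajectory(z0, leapfrog, neg_H, n_steps, max_delta_h=1000.0):
    z, cand = z0, z0
    h0 = -neg_H(z0)                      # initial Hamiltonian
    log_sum_w = neg_H(z0)                # running log of the total weight
    for _ in range(n_steps):
        z = leapfrog(z)
        log_w = neg_H(z)                 # log multinomial weight of z
        if -log_w - h0 > max_delta_h:    # divergence flagged here
            return cand, True            # cand could go to a diagnostic stream
        log_sum_w = np.logaddexp(log_sum_w, log_w)
        if np.log(np.random.rand()) < log_w - log_sum_w:
            cand = z                     # weighted-reservoir update
    return cand, False
```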

These sorts of ideas have been discussed before:

It’s a very interesting problem.


Yes, and I suspect that in certain cases (when HMC works particularly well) there’s even a bias for these points to be further away from the problematic parts of the posterior than a sample from an average non-divergent transition, because the part of the trajectory that touches the problematic region is discarded.

Something I routinely do when diagnosing divergences from standard Stan output is to plot a pairplot that includes treedepth__ along with divergent__ and the parameters. If you have a very strongly problematic region in the posterior then there will be samples which are divergent and have a much smaller treedepth than normal samples. This is because these trajectories were terminated very early due to the big energy errors when entering the problematic region. The samples from these trajectories are always close to the problematic region, because the low treedepth wouldn’t have allowed the sampler to move far. A sketch of this kind of plot follows below.
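
Something like this, using cmdstanpy output (“model.stan”, “data.json”, and the parameter names “alpha” and “tau” are placeholders for your own model):

```python
import matplotlib.pyplot as plt
from cmdstanpy import CmdStanModel

fit = CmdStanModel(stan_file="model.stan").sample(data="data.json")
df = fit.draws_pd()  # includes the treedepth__ and divergent__ columns

div = df["divergent__"] == 1
# Divergent draws whose treedepth is below anything seen in normal draws:
low_td = div & (df["treedepth__"] < df.loc[~div, "treedepth__"].min())

plt.scatter(df.loc[~div, "alpha"], df.loc[~div, "tau"],
            s=4, c="lightgrey", label="non-divergent")
plt.scatter(df.loc[div & ~low_td, "alpha"], df.loc[div & ~low_td, "tau"],
            s=12, c="red", label="divergent")
plt.scatter(df.loc[low_td, "alpha"], df.loc[low_td, "tau"],
            s=30, c="black", marker="x", label="divergent, low treedepth")
plt.xlabel("alpha")
plt.ylabel("tau")
plt.legend()
plt.show()
```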
If you don’t see such extra-low treedepth divergences, then the model probably doesn’t have a really difficult region, and increasing adapt_delta should probably suffice to solve the problem.
Sometimes it is necessary to manually decrease the stepsize, because Stan’s automatic stepsize adaptation sometimes doesn’t “catch on to” these problems if they affect only a few samples. I do this by first running a normal adaptation run and then changing the stepsize by some fixed factor before continuing to sample, roughly as sketched below. (Implementing this in Stan itself might be a good idea.)
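
With cmdstanpy the two-stage workflow could look roughly like this (the shrink factor 0.5 is arbitrary, and “model.stan”/“data.json” are placeholders; for a faithful continuation you’d also want to initialize the second run at the last warmup draws):

```python
from cmdstanpy import CmdStanModel

model = CmdStanModel(stan_file="model.stan")

# Phase 1: an ordinary run, just to adapt the step size and metric.
warm = model.sample(data="data.json", chains=4,
                    iter_warmup=1000, iter_sampling=1)

# Phase 2: restart with adaptation off, reusing the adapted metric but
# shrinking the adapted step size by a fixed factor.
fit = model.sample(data="data.json", chains=4,
                   iter_warmup=0, iter_sampling=1000,
                   adapt_engaged=False,
                   step_size=[0.5 * s for s in warm.step_size],
                   metric={"inv_metric": warm.metric[0].tolist()})
```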

@nhuurre Thanks for mentioning my thread about developing better divergence diagnostics. It reminds me that I still need to finish that project. Not that I really need reminding; I still think of it often. It’s high on my priority list for when I find some time for things like that again. I have a working prototype coded in R; I just need to fix some (unrelated) bugs to make it easy for others to play around with, clean up the code, share it, and write up how it works.


@nhuurre mentioned some of the important points but let me follow up on a few of them and explain why Stan behaves the way it does.

“Divergence” does not describe a point but rather an unstable numerical trajectory. In practice this is defined as a numerical trajectory that includes a point where the error in the Hamiltonian function, which will be constant along exact trajectories and should oscillate for stable numerical trajectories, exceeds 1000. Technically this might result in false positives – in very high dimensions the magnitude of the stable error oscillations can in principle exceed 1000 – but I’ve never been able to construct an explicit false positive in practice.
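
In rough pseudocode the check looks like this (mirroring the description above, not Stan’s exact source):

```python
import math

MAX_DELTA_H = 1000.0  # Stan's default divergence threshold

def is_divergent(h_current: float, h_initial: float) -> bool:
    # H is constant along exact trajectories and oscillates along stable
    # numerical ones; a blow-up in the error flags the whole trajectory,
    # not any single point, as divergent.
    return (not math.isfinite(h_current)) or (h_current - h_initial > MAX_DELTA_H)
```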

When Stan encounters a diverging numerical trajectory it rejects the current trajectory expansion and then samples from the pre-existing trajectory. This ensures that the Markov transition still asymptotically converges to the target distribution. Asymptotic convergence doesn’t mean much without fast enough preasymptotic convergence, which the appearance of divergences calls into question, but it’s the safest failure mode. In any case the Stan output should not be interpreted as “a point that diverged” but rather “a point that comes from a trajectory that diverged somewhere”!

Because of this the “points from divergent trajectories” will not always fall into the problematic region of the posterior distribution. If the pre-existing trajectory is long enough then many of the states that can be sampled will actually be very far away from the problematic region. That said, the pre-existing trajectories from many divergent transitions, and hence the samples from those trajectories, will tend to concentrate near the problematic behavior. It’s the ensemble behavior of divergent transitions that provides the most diagnostic information.

As an example consider a posterior distribution that manifests a “pinch” or “spike”, especially in high dimensions. A numerical trajectory that starts far away from the problematic region can be very long before encountering the problematic region and diverging, and consequently a point sampled from that trajectory can be far away from the problematic region. If, however, you look at many trajectories that start far away but eventually fall into the problematic region and diverge then you’ll see them start to concentrate as they get closer and closer to the problematic region. If you sample from all of these trajectories then they will tend to be closer to the problematic region rather than away from it. Moreover, once you start building a trajectory in the problematic region then most trajectories will quickly start to diverge in both directions so that all sampled points are close to the problematic region; only rarely will the sampled momentum be oriented just right so that one end of the numerical trajectory can escape into more well-behaved neighborhoods and pull the sampled states away. This circumstance would trigger the tree depth heuristic @Raoul-Kima mentioned.

I’ve looked at many fits that exhibit divergences, sometimes going through hundreds of pairs plots, and I have yet to see an example where the divergences don’t actually concentrate. Almost always the problem is either not enough divergences to resolve any ensemble behavior – in which case one can just run longer Markov chains, or even more Markov chains – or, more likely, that I’m just not looking at the right pairs plot. See for example the discussion in Section 4.1 of Identity Crisis.

But let’s say that we do want more information from these divergent trajectories. What could we do?

Firstly, the architecture of Stan itself drastically limits the possibilities. The Stan Markov chain Monte Carlo library, from input/output to diagnostics, was designed to be modular and support any Markov chain Monte Carlo algorithm. In particular everything is built around a sequence of states from the constrained parameter space, augmented with only scalar diagnostic information. Passing along more states would require substantial redesigns in multiple places. Keep in mind that back in the day we didn’t know what might compete with Hamiltonian Monte Carlo; it’s only with a decade of experience that we found nothing else even worth implementing. With this hindsight, designing around just Hamiltonian Monte Carlo would make more sense, but that is another discussion in and of itself.

If we did have a diagnostic output stream available then what might we put there? Remember that a divergence isn’t a point. The divergence diagnostic indicates when a numerical trajectory has passed through a problematic region and become unstable, but not what point passed through the problematic region. In fact because unstable numerical trajectories are so explosive the terminal states can be quite far from the problematic region by the time the error hits 1000.

Consequently, once a trajectory has triggered the divergence diagnostic we’d want to look at previous states in the trajectory. Unfortunately Stan’s Hamiltonian Monte Carlo sampler doesn’t keep the entire numerical trajectory in memory at any given time, in order to keep the memory burden as low as possible. To recover those previous states we’d have to rewind the numerical integration, or start again from the initial point; either way requires more computation.
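
For example, one way to do this without ever storing the full trajectory would be to re-integrate from the stored initial state while keeping only a bounded window of recent states (a hypothetical sketch, not anything Stan currently does):

```python
from collections import deque

def recover_tail(z0, leapfrog, n_steps_to_divergence, window=8):
    # Re-run the integrator from the initial state, keeping only the last
    # `window` states; memory stays bounded, but all of the leapfrog steps
    # up to the divergence have to be recomputed.
    tail = deque(maxlen=window)
    z = z0
    for _ in range(n_steps_to_divergence):
        z = leapfrog(z)
        tail.append(z)  # older states fall out of the window automatically
    return list(tail)   # the subtrajectory near the divergent end
```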

Even then, what point would we take? Once a numerical trajectory starts to diverge we expect the Hamiltonian error to increase monotonically, but that doesn’t mean that we can just go backwards until the error stops decreasing. In general the transition from unstable to stable behavior will happen before the error bottoms out, and the state with local minimum error could be nontrivially far away from the problematic region.

Ultimately the best diagnostic power is not in any single point but rather the entire numerical trajectory, or at least some subset of the trajectory near the divergent end. There are numerous heuristics that would select reasonable subtrajectories, but we would then confront the input/output problem. How should entire subtrajectories be stored? When will storing multiple subtrajectories, in addition to the standard sampler output, introduce problematic memory burdens? If too many states have to be stored at once should the additional diagnostic output be saved only conditionally, and if so then what should the user-interface be for those checks?

Long story short, the jump from the current behavior to enhanced diagnostics might sound straightforward, but there are numerous theoretical and implementation issues that have to be addressed. And then one has to ask just how much more diagnostic information one obtains over the current diagnostics. In my opinion it’s not as much as many think, especially when the current diagnostics are used carefully.


Do I understand correctly that you assume here that the problematic region is small in several dimensions and therefore the paths must converge to hit it? If so, the model in the first few posts of the thread by @martinmodrak is in some sense a counterexample, as the problematic region is only small in one dimension and thus almost every path hits it, without having to “converge” (in some sense) in any region.

Could the model I just mentioned from the linked thread of @martinmodrak be an example of this? (can’t run checks on it right now, as I have to reinstall the framework first)

I didn’t talk about the treedepth heuristic (I assume you mean the large treedepth warning) in this thread (but did in my thread linked by @nhuurre), but about the opposite: samples from divergent iterations with super low treedepth. These are usually very close to the problematic region.

Unfortunately divergent models often sample very slowly, which makes it difficult to get enough divergent iterations to diagnose the cause of the divergences. The prototype diagnostics I built in that other thread can help with this problem, because they make the diagnostics more specific, thus requiring fewer samples to be informative. Another thing that helps is that these diagnostics don’t require the iterations to actually be flagged as divergent to find the problematic region, which is helpful for models that throw divergences rarely.
In my testing, for some models as little as a single sampling iteration was enough to diagnose the problem, especially when using the entire leapfrog path for diagnostics as opposed to just a single point from it. But it remains to be seen how well that translates to real-world problems.

This is a problem I was thinking about, but haven’t found a good solution for yet. Are there any tools out there that analyze the sampler output to find the right angle to look at the point cloud? Something like running a CCA (or one of its close relatives; I’m not sure right now which one would be the right fit) on the points, or some other machine learning tool.

Am I hijacking the thread? I’m not sure!


That was just an example. A one-dimensional barrier system will behave qualitatively similarly. The probability of a numerical trajectory that starts near the barrier expanding only away from the barrier in each of $N$ trajectory updates is $\left(\frac{1}{2}\right)^{N}$, which decays pretty quickly. In other words very few numerical trajectories will become too large before hitting the boundary.

The higher the dimension of the total space the higher the possibility of trajectories that expand transverse to the barrier, but that doesn’t change much. Either these transverse trajectories don’t curve into the barrier, in which case they don’t register as divergent, or they do, in which case all of the points that could be sampled are still very close to the boundary.

I wasn’t talking about Stan’s existing treedepth warning but rather the heuristic you suggested. Creating separate pairs plots for each recorded treedepth, or coloring the points from divergent trajectories differently, will definitely provide more information!

There’s always more information available in the entire trajectory, not just for diagnostics but also for constructing expectation value estimators. The question is whether that additional information is worth the pretty substantial I/O overhead. The possible improvement in expectation value estimators is pretty small and hence not worth the memory and logistical overhead of storing and organizing all of the trajectories, and I’m arguing that the typical gain in diagnostic power is similar.

I personally don’t believe that this is a productive avenue. The main problem is that even if such a projection could be found (there are many linear projections, but linear projections might not be sufficient and the space of nonlinear projections is even more massive) one would still have to find a way to connect that projection to meaningful parts of the model.

In my diagnostics case study I argue for a more deductive approach – use the structure of the model to identify potentially suspect, but highly interpretable, variables and then prioritize those for diagnostics.

Thanks for your thoughts!

I had some cases where it seemed quite useful when I was working on the diagnostics project in the other thread. Guess I’ll learn more once I find the time to continue that.

Always helps to have concrete examples!