NUTS misses U-turns, runs in circles until max_treedepth

Is it something like this, where I’ve drawn lines between the sets of points being checked for U-turns?

[Image: leapfrog_checks]

I’m still thinking in terms of the original U-turn criterion, btw. I’ve been too lazy to learn the p_sharp stuff.

(Halfway through drawing this I realized the picture is wrong: it should have been drawn entirely in momentum space, where R1+R2=R3 does in fact hold. I think it still gives you the right idea.)

It looks like this:

[Image: cycle]

Let’s say you have built a tree (going from P1 to P2), and because it is not yet a u-turn (the momenta P1 and P2 are not quite orthogonal to R1) you build a second tree (going from P3 to P4). It is also valid, so your whole tree is acceptable.

But do you stop here or continue to the next treedepth?

The U-turn criterion checks whether the momenta at the ends, P1 and P4, are parallel or antiparallel to the average momentum R3. Here they are antiparallel and tell us to stop (I drew it that way because the picture would have been too confusing if the ends overlapped). But it is possible to imagine a situation where the ends do cross and R3 points in the opposite direction to what we see here. Then standard NUTS would continue.

The proposed criterion is to also check whether P2+P3 is antiparallel to R3, which is acceptable here but would bring us to a stop if the ends were crossed (i.e., precisely when standard NUTS does not see the U-turn).
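
A minimal Python sketch of the two checks, assuming an identity metric (so p_sharp is just the momentum) and 2-D tuples for vectors. `compute_criterion` mirrors the sign test Stan’s function of that name performs (both end p_sharps must have a positive dot product with rho); `middle_check` is the extra P2+P3 test proposed above. The vectors are made up to mimic the crossed-ends case: both ends point along R3 while the join between the subtrees points backwards.

```python
# Sketch only: vectors are 2-D tuples and p_sharp is taken to be the
# momentum itself (identity metric), which is an assumption.

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def add(a, b):
    return tuple(x + y for x, y in zip(a, b))

def compute_criterion(p_sharp_left, p_sharp_right, rho):
    """Standard NUTS check: True means 'no U-turn yet, keep going'."""
    return dot(p_sharp_left, rho) > 0 and dot(p_sharp_right, rho) > 0

def middle_check(p_sharp_lmiddle, p_sharp_rmiddle, rho):
    """Proposed extra check: P2 + P3 must not point against R3."""
    return dot(add(p_sharp_lmiddle, p_sharp_rmiddle), rho) >= 0

# Made-up vectors for the crossed-ends situation.
R3 = (1.0, 0.0)
P1, P4 = (0.5, 1.0), (0.5, -1.0)    # ends point along R3
P2, P3 = (-1.0, 0.3), (-1.0, -0.3)  # momenta around the join point backwards

print(compute_criterion(P1, P4, R3))  # True: standard NUTS keeps doubling
print(middle_check(P2, P3, R3))       # False: the extra check stops the tree
```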

But how are you getting the start of the right subtree? What is the modified build_tree code?

It’s just p_sharp_dummy. p_sharp_dummy starts its life as the right endpoint of the left subtree, but because it’s not normally used it gets reused as the left end of the right subtree. I store a copy of it in between. build_tree proceeds deterministically, so there’s no possibility of confusion. (In my first implementation I forgot that transition() moves nondeterministically, and the bug made the trajectories twice as long.)

Here’s the full code.

      bool build_tree(int depth, ps_point& z_propose,
                      Eigen::VectorXd& p_sharp_left,
                      Eigen::VectorXd& p_sharp_right,
                      Eigen::VectorXd& rho,
                      double H0, double sign, int& n_leapfrog,
                      double& log_sum_weight, double& sum_metro_prob,
                      callbacks::logger& logger) {
        // Base case
        if (depth == 0) {
          this->integrator_.evolve(this->z_, this->hamiltonian_,
                                   sign * this->epsilon_,
                                   logger);
          ++n_leapfrog;

          double h = this->hamiltonian_.H(this->z_);
          if (boost::math::isnan(h))
            h = std::numeric_limits<double>::infinity();

          if ((h - H0) > this->max_deltaH_) this->divergent_ = true;

          log_sum_weight = math::log_sum_exp(log_sum_weight, H0 - h);

          if (H0 - h > 0)
            sum_metro_prob += 1;
          else
            sum_metro_prob += std::exp(H0 - h);

          z_propose = this->z_;
          rho += this->z_.p;

          p_sharp_left = this->hamiltonian_.dtau_dp(this->z_);
          p_sharp_right = p_sharp_left;

          return !this->divergent_;
        }
        // General recursion
        // XXX replaces p_sharp_dummy
        Eigen::VectorXd p_sharp_lmiddle(this->z_.p.size());

        // Build the left subtree
        double log_sum_weight_left = -std::numeric_limits<double>::infinity();
        Eigen::VectorXd rho_left = Eigen::VectorXd::Zero(rho.size());

        bool valid_left
          = build_tree(depth - 1, z_propose,
                       p_sharp_left, p_sharp_lmiddle, rho_left,
                       H0, sign, n_leapfrog,
                       log_sum_weight_left, sum_metro_prob,
                       logger);

        if (!valid_left) return false;

        // Build the right subtree
        ps_point z_propose_right(this->z_);

        // XXX p_sharp_dummy switches roles here, make a copy instead
        Eigen::VectorXd p_sharp_rmiddle = p_sharp_lmiddle;

        double log_sum_weight_right = -std::numeric_limits<double>::infinity();
        Eigen::VectorXd rho_right = Eigen::VectorXd::Zero(rho.size());

        bool valid_right
          = build_tree(depth - 1, z_propose_right,
                       p_sharp_rmiddle, p_sharp_right, rho_right,
                       H0, sign, n_leapfrog,
                       log_sum_weight_right, sum_metro_prob,
                       logger);

        if (!valid_right) return false;

        // Multinomial sample from right subtree
        double log_sum_weight_subtree
          = math::log_sum_exp(log_sum_weight_left, log_sum_weight_right);
        log_sum_weight
          = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree);

        if (log_sum_weight_right > log_sum_weight_subtree) {
          z_propose = z_propose_right;
        } else {
          double accept_prob
            = std::exp(log_sum_weight_right - log_sum_weight_subtree);
          if (this->rand_uniform_() < accept_prob)
            z_propose = z_propose_right;
        }

        Eigen::VectorXd rho_subtree = rho_left + rho_right;
        rho += rho_subtree;

        // XXX the new check is here
        if ((p_sharp_lmiddle + p_sharp_rmiddle).dot(rho_subtree) < 0)
          return false;

        return compute_criterion(p_sharp_left, p_sharp_right, rho_subtree);
      }
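
To see which states end up in the four p_sharp slots at each merge, here is a hypothetical index-only Python sketch of the recursion. It tracks state indices instead of p_sharp vectors and drops the energies, multinomial sampling, and checks. The right end of the left subtree plays P2 (p_sharp_lmiddle) and the left end of the right subtree plays P3 (p_sharp_rmiddle); since the recursion is deterministic, each merge always sees the same four states.

```python
# Hypothetical index-only sketch of build_tree's bookkeeping, not Stan code.

def build_tree(depth, start, merges):
    """Subtree of 2**depth states beginning at state index `start`.
    Returns (left_end, right_end) and appends (P1, P2, P3, P4) to
    `merges` for every merge performed inside this subtree."""
    if depth == 0:
        # Base case: one leapfrog step, so both ends are the same state.
        return start, start
    half = 2 ** (depth - 1)
    # Left subtree: its right end becomes P2 (p_sharp_lmiddle).
    p1, p2 = build_tree(depth - 1, start, merges)
    # Right subtree: its left end becomes P3 (p_sharp_rmiddle).
    p3, p4 = build_tree(depth - 1, start + half, merges)
    merges.append((p1, p2, p3, p4))
    return p1, p4

merges = []
print(build_tree(2, 0, merges))  # (0, 3)
print(merges)                    # [(0, 0, 1, 1), (2, 2, 3, 3), (0, 1, 2, 3)]
```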

Here’s a better picture:

[Image: circle]

Black dots are the position samples. Red lines are momenta along the trajectory. Green lines are summed momenta over the subtrees. The NUTS criterion says that the tree grows until either end of the green line makes an angle of at least 90 degrees with the red line.

(The labels refer not to the points but to the lines, which are understood as vectors. I didn’t draw the arrowheads because the picture already feels too busy. Compare with the other picture if you’re confused.)

In addition to the NUTS criterion (P1 vs R3 and P4 vs R3) there are four possible comparisons one could make:

  1. P2+P3 vs R3. This is the most permissive, in that if it triggers then all the others will trigger too. This is what I implemented.
  2. P2 vs R3 and P3 vs R3
  3. P2+P3 vs R1 and P2+P3 vs R2
  4. P2 vs R2 and P3 vs R1. This is the strictest: it will trigger if any of the others triggers. This is also how Ben Bales interpreted it in his picture.
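
Reading “vs” as a sign test on a dot product, the four comparisons can be sketched in Python as below (identity metric assumed, so p_sharp is the momentum; each function returns True when that criterion would stop the tree). The randomized loop checks the ordering stated in the list, assuming both subtrees have already passed their own U-turn checks (P2·R1 > 0 and P3·R2 > 0): criterion 1 firing implies 2 and 3 fire, and 2 or 3 firing implies 4 fires. Without that assumption the implication to criterion 4 can fail.

```python
import random

# Sketch: identity metric assumed; "vs" read as a dot-product sign test.

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def add(a, b):
    return tuple(x + y for x, y in zip(a, b))

# Each function returns True when the criterion detects a U-turn (stop).
def crit1(P2, P3, R1, R2):
    return dot(add(P2, P3), add(R1, R2)) < 0

def crit2(P2, P3, R1, R2):
    R3 = add(R1, R2)
    return dot(P2, R3) < 0 or dot(P3, R3) < 0

def crit3(P2, P3, R1, R2):
    S = add(P2, P3)
    return dot(S, R1) < 0 or dot(S, R2) < 0

def crit4(P2, P3, R1, R2):
    return dot(P2, R2) < 0 or dot(P3, R1) < 0

def random_vec(rng, dim=2):
    return tuple(rng.gauss(0.0, 1.0) for _ in range(dim))

rng = random.Random(0)
checked = 0
while checked < 10000:
    P2, P3, R1, R2 = (random_vec(rng) for _ in range(4))
    # Keep only configurations where both subtrees passed their own checks.
    if dot(P2, R1) <= 0 or dot(P3, R2) <= 0:
        continue
    c1, c2, c3, c4 = (c(P2, P3, R1, R2) for c in (crit1, crit2, crit3, crit4))
    assert not c1 or (c2 and c3)  # 1 fires => 2 and 3 fire
    assert not (c2 or c3) or c4   # 2 or 3 fires => 4 fires
    checked += 1
print("ordering held on", checked, "valid configurations")
```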

The first two can still miss some U-turns, while criteria 3 and 4 guarantee (up to numerical rounding errors) that wraparound U-turns are always seen.

I don’t know which one is best, but I guess my current preference is 3.

Your argument is that one subtree might just miss the optimal trajectory, so the trajectory has to double by adding a new subtree, and then that expanded trajectory might miss again, setting off something of an infinite loop? I can see the first subtree just missing, but I don’t see how the first expansion wouldn’t trigger the termination criterion – the trajectory end points should be contracting towards each other, which should satisfy the termination condition.

Let’s presume that there is an issue with the first expansion because the checks against the ends of the doubled trajectory don’t trigger the termination criterion.

You’re then proposing additional checks beyond those at the boundaries of each subtree, right? For example, if we wanted to push just a little bit further, we might check against the left-most and right-most states in each subtree (left image) or the first state and the average momenta between the ends of the two subtrees being merged (right image).

I can see how this would help avoid a doubling (at the cost of more checks and carrying around more than just the subtree end points, at least one more state at each level of the recursion), and it is worth exploring.

The endpoints are contracting towards each other and then overshoot before the condition is checked. Yes, this should be impossible, as two less-than-U-turns cannot add up to more than a full circle. But the one leapfrog step separating the subtrees allows them to rotate towards each other. Your pictures are misleading because they depict the trajectories as continuous. Look at the last picture I posted above and consider that NUTS validity of the subtrees is consistent with any angle whatsoever between the vectors P2 and P3. Checking whether R3 is antiparallel to P1 or P4 is not enough.

By P2 + P3 vs. R1, do you mean take one additional leapfrog step? Or do you mean check P2 vs R1 and P3 vs R1?

P2 vs R2 and P3 vs R1 look like the original u-turn condition to me the way I’m reading this.

The thing I was talking about looks like the left image in @betanalpha’s third set of plots.

Thanks everyone for taking the time to draw pictures, btw.

I mean check the vector sum of P2 and P3. This is equivalent to checking both P2 vs R1 and P3 vs R1 but continuing even if one of them is negative as long as their sum is positive.

The original u-turn condition for the left subtree is P1 vs R1 and P2 vs R1.
The original u-turn condition for the right subtree is P3 vs R2 and P4 vs R2.
The original u-turn condition for the full tree is P1 vs R3 and P4 vs R3.

P2 vs R2 and P3 vs R1 are new checks ignored by the original u-turn criterion, which implicitly assumes that P2=P3, or “close enough”. That’s true in the limit of vanishing stepsize, but when the entire tree is only 16 leapfrog steps it’s far too crude.

Thank you for prompting me to do so. It’s been super helpful for clarifying my own thinking about this.

I know the continuous picture is an idealization, but I’m trying to get a handle on exactly what pathology might be present, to ensure that whatever solution we consider not only resolves that pathology but also doesn’t have any negative influences, such as premature termination of other trajectories. Your initial investigation is intriguing and worth following up, but we have to figure it out properly before considering a solution!

Excuse the continued idealized cartoons, but this is the circumstance you think might be happening, right?


The gap in the right figure is due to the leapfrog step between the two subtrees, which pushes the two subtrees to overlap on the other side of the level set.

Wouldn’t the proper checks then be


In other words, check between any two subtrees before merging them?

Okay, the arguments in this thread have been more difficult to follow than they should have been and that is entirely my fault. Let’s go over some mistakes I’ve made.

Firstly, I’ve provided both code and a picture but I’ve used different names for the same objects. Here’s a quick reference:
P1 = p_sharp_left
P2 = p_sharp_lmiddle
P3 = p_sharp_rmiddle
P4 = p_sharp_right
R1 = rho_left
R2 = rho_right
R3 = rho_subtree

Next, about that picture… it’s supposed to illustrate the new criterion yet depicts a situation where the criterion is redundant and doesn’t even trigger! That’s not helpful, that’s just confusing. Let’s walk through the sampling process step-by-step.

The idea is to build the tree until you find a u-turn. You build a tree with four leapfrog steps. It looks like this:
[Image: left]
There is no u-turn here, so you continue and build a second tree. It looks like this:
[Image: right]
(Please ignore the fading last node of the first tree; it is not part of this tree.) There is no u-turn here either, so you merge the trees into a single big tree. Now the question is: does the big tree have a u-turn? The No-U-Turn criterion we all know and love assumes the tree looks like this:
[Image: wrong]
and tells you: “If the ends of the tree (green) are moving in the general direction of the momenta summed over the whole tree (blue) then there is no u-turn.”

But the tree looks like this:
[Image: fulltree]
The No-U-Turn criterion gives you the wrong answer.

Reading Betancourt’s posts again made me understand how to think about this in terms of continuous trajectories. The trajectory is continuous but segmented into timesteps and the u-turn condition is checked using the midpoints of the segments, not their boundaries. Suppose the speed and stepsize are such that it takes 14.5 timesteps to go around the great circle once. Once a trajectory is 8 segments long it starts to contract and should be stopped. However, the discretization error effectively chops off half a segment from both ends and makes the actual comparison with points only 7 segments apart. These are still expanding and the trajectory is doubled to 16 segments. The discretization error again removes one segment and we are left with 15 segments. But since that’s more than 14.5 it wraps around and the trajectory is expanding once again. No evidence of a u-turn has been seen. My proposed solution has been to check if the motion during the “skipped segment” in the middle of the trajectory is in fact antiparallel to the total motion between the checked endpoints.
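
The arithmetic in this argument can be checked with an idealized model (an assumption, not Stan’s integrator): on a 2-D isotropic Gaussian, treat the momentum as a unit vector that rotates by a fixed angle per leapfrog step, with one revolution every 14.5 steps. The Python sketch below evaluates, for trees of n = 2, 4, 8, … states, the standard check (end momenta vs the summed momentum) and a middle check (the momenta on either side of the join, each against the summed momentum).

```python
import math

# Idealized assumption (not Stan's integrator): the momentum is a unit
# vector rotating by OMEGA per leapfrog step, 14.5 steps per revolution.
OMEGA = 2 * math.pi / 14.5

def momentum(k):
    return (math.cos(k * OMEGA), math.sin(k * OMEGA))

def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]

def rho(n):
    """Summed momentum over states 0 .. n-1 of a tree with n states."""
    ps = [momentum(k) for k in range(n)]
    return (sum(p[0] for p in ps), sum(p[1] for p in ps))

def standard_ok(n):
    """Standard check: True means no U-turn seen at the ends, keep doubling."""
    r = rho(n)
    return dot(momentum(0), r) > 0 and dot(momentum(n - 1), r) > 0

def middle_ok(n):
    """Middle check: momenta on both sides of the join between the halves."""
    r = rho(n)
    half = n // 2
    return dot(momentum(half - 1), r) > 0 and dot(momentum(half), r) > 0

for n in (2, 4, 8, 16, 32, 64, 128):
    print(n, standard_ok(n), middle_ok(n))
```

In this idealization the standard check passes at every doubling up to 64 states (at 16 states the momenta span more than a full revolution, so the sum points back along the end momenta) and first rejects the 128-state tree, while the middle check rejects the 16-state tree. Real leapfrog trajectories only approximate a uniform rotation, so treat the exact numbers as illustrative.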

Moving on.

When you file a bug report you should always include a short piece of code that demonstrates the issue, something that can be found nowhere in my posts. Oops.

import pystan
import pystan.diagnostics
model = pystan.StanModel(model_code="""
  data {
    int N;
  } parameters {
    vector[N] x;
  } model {
    x ~ std_normal();
  }""")
fit = model.sampling(data=dict(N=200), chains=20,
                     control=dict(max_treedepth=12),
                     refresh=0, check_hmc_diagnostics=False)
pystan.diagnostics.check_treedepth(fit, verbose=3)
WARNING:pystan:24 of 20000 iterations saturated the maximum tree depth of 12 (0.12 %)
WARNING:pystan:Run again with max_treedepth larger than 12 to avoid saturation

I didn’t include an RNG seed because Stan doesn’t guarantee reproducibility across machines anyway. While the code doesn’t fail every time, it fails often enough that you should have no problem seeing the treedepth warning. You can make the warning happen almost every time by using the default max_treedepth, but I wanted the most dramatic example.
For this model the treedepth at which trajectories circle around is 4: that is what the chains that didn’t blow up have, and also what you get with N=500. Those 24 out of 20,000 transitions apparently managed to do 12-4=8 doublings too many. Since a trajectory that stops at depth 4 has already turned around, covering at least half a revolution, 2^8 = 256 times that length means looping around the origin at least 128 times without noticing a single problem. This is alarming.

And one more thing. I mentioned four alternative new criteria and said that two of them could still miss some u-turns. That was just a guess. Here’s a proper counterexample, a u-turn that the first criterion says isn’t there. (Intermediate points omitted.)


Here R3 and P2 both point downward, and even though P3 points upward, it is so small that it has little effect on the average direction at the point where the subtrees join.
This means (p_sharp_lmiddle + p_sharp_rmiddle).dot(rho_subtree) > 0, so my original proposal would not terminate this trajectory.

Looking at the picture we see that there is no room for P3 to align with R3 without also antialigning with R2 (and violating the u-turn condition on the right subtree). If the lengths of R1 and R2 were reversed we could mirror the image and get the same constraint for P2. So checking if either of p_sharp_lmiddle.dot(rho_subtree) or p_sharp_rmiddle.dot(rho_subtree) is negative is in fact a loophole-free condition, contrary to my guess.
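
A toy version of this counterexample in Python, with made-up 2-D vectors (an assumption, not values read off the picture): R3 and P2 point down, and P3 is a short vector pointing up. The summed test (criterion 1) sees nothing, while testing P2 and P3 separately against R3 (criterion 2) catches the U-turn.

```python
# Made-up vectors (assumption), not read off the picture.

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

R3 = (0.2, -1.0)   # summed momentum of the merged tree, pointing down
P2 = (0.0, -1.0)   # right end of the left subtree, pointing down
P3 = (0.1, 0.3)    # left end of the right subtree: short, pointing up

# Criterion 1: sign of (P2 + P3) . R3
summed = dot((P2[0] + P3[0], P2[1] + P3[1]), R3)
# Criterion 2: P2 . R3 and P3 . R3 separately
separate = (dot(P2, R3), dot(P3, R3))

print(summed > 0)                    # True: criterion 1 misses the U-turn
print(any(d < 0 for d in separate))  # True: criterion 2 catches it
```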

I know you know. It just irks me when the whole problem is the discretization error but you draw a smooth curve and say it can’t happen. Anyway, now that I get the “segmented curve” approach I mentioned in the post above, it doesn’t bother me at all.

This is a serious concern. I fear that S-shaped trajectories are in danger. Figure 2 in the Hoffman and Gelman 2014 NUTS paper is a trajectory zig-zagging down a narrow corridor. I think it survives my proposal but it does make me worried.

What’s your opinion on the “momentum polygon” representation I’ve been using? The momentum vectors are placed end-to-end. Subtree momenta are exact and the u-turn criterion is a straightforward statement about the internal angles. Initially I thought it was also accurate in position space when positions are placed at the centers of their respective momenta, but since then I’ve gotten myself confused again and don’t know anymore.

Yes, but what test do you propose? You can’t compare the p_sharps directly to each other and even if you could it wouldn’t help. All the cheap checks are of the form p_sharp_*.dot(rho_*) > 0 and any symmetric combination of those should be sufficient.

I think a discretization argument like this is the way to go, but isn’t it true that because of this you can have two half u-turns and get more than a full u-turn?

This is because when we put two uturns together, there’s an extra leapfrog in between them? Is that what you’re saying?

Also, I don’t see a reason to separate the momentums and positions in the pictures.

In a lot of leapfrog integrators you have momentums at half-steps and positions at full steps, but Stan needs positions and momentums at the end of every step. With regards to the U-turn criteria I don’t think we need to be thinking about these half-steps.

A couple of the states in the trajectory disappeared in the fourth image. I thought it might have something to do with this.

I’m not so sure. We’re just adding checks that would have also been there in the additive scheme, right?

It seems like the problem is that we’re too stiff in the distances between points we’re checking over. So we check between certain points L^1 steps apart, then L^2 steps apart, then L^3 steps apart, etc.

And we want to tactically do some checks of lower length? Especially to avoid these things that double a bunch? @avehtari what’s your take on this?

I think we could fix the immediate problem here with checks like the ones on the left of the third figure here: NUTS misses U-turns, runs in circles until max_treedepth.

When I get home from this trip I’m going to have to put together a test suite of models to systematically check the performance changes.

It’s tricky to try to align the momenta spatially because of the half-steps in the leapfrog integrator – the distance between points isn’t aligned with either the final or initial momenta. I haven’t wrapped my head around a deeper geometric meaning. Doesn’t mean that something’s not there, just that I’m exhausted from teaching at the moment.

My intuition is to stick with something that compares subtree end points to the rho along that subtree because it makes the necessary reversibility easier to verify – if the checks aren’t sufficiently symmetric then the sampler will pick up an appreciable bias.

Checking between subtrees being merged may not be of much help unless a U-turn just happens to be snuck into that niche. I think the more useful check will be between the first points of the merged subtrees (and, by symmetry, the last points of the merged subtrees), which would terminate trajectories that just missed the U-turn after the last doubling. Visually:

[Image: best_extended_checks]

Again, I’ll have to verify on a test suite to provide some proper evidence.

I don’t know what kind of checks you’re thinking of here or what the additive scheme is. I’ve been badly inconsistent about what I propose because there are a couple of variants, and the first one I tried is probably not the best. So I’ll try to be clear here. Currently build_tree() has a vector called p_sharp_dummy, which is there only so it can be passed to recursive calls. I propose replacing it with two variables, p_sharp_lmiddle and p_sharp_rmiddle, to separate its two uses. The last line of build_tree, which currently reads

return compute_criterion(p_sharp_left, p_sharp_right, rho_subtree);

is then replaced with

return (compute_criterion(p_sharp_left, p_sharp_right, rho_subtree)
     && compute_criterion(p_sharp_lmiddle, p_sharp_rmiddle, rho_subtree));

This is sufficient to prevent the looping.

In other news, I finally did what I should have done a week ago and just hacked Stan to dump the trajectory to std::cout when it goes bad. Here’s what a pathological trajectory actually looks like:
[Image: bad]
And some more with a larger stepsize. Orange arrows indicate the momentum at each step and the red arrow in the center is the sum of all momenta.


It is perhaps surprising that the ends do not meet. One must remember that the U-turn criterion in Stan is purely a condition on the momenta; the positions are irrelevant. The trajectory stops if the momentum at either end makes an angle of at least 90 degrees with the sum. In all of these, the momentum has already looped around even though the positions are yet to pass.

It’s one of two ways to build the set of trajectories to sample from outlined in: https://arxiv.org/pdf/1701.02434.pdf and https://arxiv.org/pdf/1601.00225.pdf

Takes a little to digest cause there’s lots of parts, but you should go through it cause how you build the tree and what U-turn checks you need are super connected.

Well don’t dig too far back in my post history haha.

Nice plots. I like them. Did you actually get a 2d example to do this? Or is this picking a couple dimensions out of a much higher dimensional problem?

I guess the rhos that Betancourt is talking about are the sums of momentums. But change in position is the integral of momentum, and just summing the momentums with a fixed stepsize along a leapfrog trajectory is going to give you a good approximation to that change in position up to a scaling factor, and we only care about directions.

So all the red arrows point in the same direction as a vector between the unconnected endpoints. Given that and the fact that I don’t know the Riemannian stuff, I’ve happily stuck with the old U-turn criteria. The differences are in A.4.2 here: https://arxiv.org/pdf/1701.02434.pdf

My buddy @arya wrote a NUTS implementation in R for experimenting with this sorta algorithmic stuff. I’ve been working on it some more. Haven’t gotten it finished yet (it seems to work – just badly organized at this point), but it takes in Stan models/data, and the NUTS implementation as written is about 200 lines of non-recursive R that makes it easy to get debug info out :D. Gotta review a couple pull reqs first, but I’ll try to get this up within a month (hardly soon, but sigh).

Rewriting one for yourself might be worth it too since you’re getting in deep on this. I think I was down on rewriting it when Arya originally did it cause it looked so complicated, but Betancourt did a good job writing stuff up so it isn’t so bad in the end.

Yes yes

It’s 2d. Theoretically higher dimensions look the same, just rotated randomly. I didn’t see an easy way to pick the directions of interest in higher dimensions so giving a 2d model a bad stepsize (0.43) was the fastest option. Dimensionality matters only if you let the model choose its own stepsize.

Yes.

I thought so too (cf. “momentum polygons” upthread) but apparently that approximation is quite rough.

That means rho is going in the same direction as endpoint momenta; NUTS thinks the trajectory is straight. Since the final point is behind the initial one a distance vector would go in the opposite direction. But these pictures aren’t all that important as they come from an abnormally large stepsize. The smaller stepsize picture looks as expected.

Really cool!

Too deep, I’d say. Yesterday I spent way too much time staring at the NUTS implementation in Chi Feng’s MCMC Gallery with stepsize 0.43 just to see if the problem would show up. (It did.) I was planning on taking a break because I’ve been taking this discussion too personally and @betanalpha already understands the bug better than I do.

Nice! This is the picture I had in my head, and one of the first exercises I was going to do in the testing. You have to be careful when drawing even cartoon trajectories because symplectic integrator trajectories lie on closed surfaces – they can’t spiral in or out, which significantly constrains their behavior.

At the same time, being on a closed surface (which can be shown to be a perturbation of the exact energy level set) doesn’t mean that the discrete trajectories will close, unless you can set the step size to something that evenly divides the period on that modified level set.

In other words this is pretty much the expected behavior, at least for the one-dimensional Gaussian target.

It all depends on the step size. The reason the two aren’t exactly the same is that the symplectic leapfrog integrator that we use is a second-order method and the change in positions isn’t just a scaled version of the final momenta after each update.
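
A small Python sketch of this point, assuming a 2-D standard normal target (so the gradient of the potential is just x), an identity mass matrix, and the step size 0.43 used earlier in the thread; it is a plain leapfrog, not Stan’s integrator. Each step moves the position by epsilon times the half-step momentum, so the total displacement is exactly epsilon times the sum of half-step momenta, whereas the rho that NUTS accumulates sums the full-step momenta and is a different vector.

```python
# Assumptions: 2-D standard normal target (grad U(x) = x), identity mass
# matrix, step size 0.43. Plain-Python leapfrog, not Stan's integrator.
EPS = 0.43

def leapfrog(x, p, eps=EPS):
    """One leapfrog step; also returns the half-step momentum used."""
    p_half = [pi - 0.5 * eps * xi for pi, xi in zip(p, x)]
    x_new = [xi + eps * ph for xi, ph in zip(x, p_half)]
    p_new = [ph - 0.5 * eps * xi for ph, xi in zip(p_half, x_new)]
    return x_new, p_new, p_half

def run(x0, p0, n):
    x, p = list(x0), list(p0)
    half_sum = [0.0, 0.0]  # sum of half-step momenta
    rho = [0.0, 0.0]       # sum of full-step momenta, like NUTS' rho
    for _ in range(n):
        x, p, p_half = leapfrog(x, p)
        half_sum = [a + b for a, b in zip(half_sum, p_half)]
        rho = [a + b for a, b in zip(rho, p)]
    displacement = [a - b for a, b in zip(x, x0)]
    return displacement, half_sum, rho

disp, half_sum, rho = run((1.0, 0.0), (0.0, 1.0), n=1)
print(disp)                         # equals EPS * half_sum up to rounding
print([EPS * h for h in half_sum])
print([EPS * r for r in rho])       # a different vector
```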

Nice!

I just want to emphasize that this has been a great, constructive conversation, so thanks to all involved. Please feel free to continue – I’ll follow along but won’t be able to really put together rigorous tests until I get back home from traveling/teaching/gallivanting in a few weeks.

Lol, 0.425 was the stepsize where I was able to get the radon model to go crazy on me. Wonder if that’s a coincidence or not. That’s cool that it’s working in a 2D model.

Ditto. The internet is intimidating.

Those animations are great