Differentiating conditionals and while loops

Silly question: how does Stan deal with densities specified using while loops and, in particular, conditionals (if-then-else)? I just compiled and ran some code like that without problems. Still, I don’t know how to differentiate those functions. How does HMC obtain a gradient for them?
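To be concrete, the kind of thing I mean is a model along these lines (a made-up sketch rather than the actual code I ran):

parameters {
  real theta;
}
model {
  // the log density uses a different branch depending on the sign of theta
  if (theta > 0) {
    target += normal_lpdf(theta | 1, 0.5);
  } else {
    target += normal_lpdf(theta | -1, 2);
  }
}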

Having looked a bit at the implementation, I get the feeling that automatic differentiation may just be used to effectively compute a smooth approximation to these functions, but I’d like to double check with the experts.

Thanks!


Definitely not a silly question!

Automatic differentiation is used to evaluate the gradient at a point. When there are conditionals and loops, Stan builds an expression graph from the computations that were actually executed and autodiffs through that. You might find some insight in the arXiv paper: [1509.07164] The Stan Math Library: Reverse-Mode Automatic Differentiation in C++

Using constructs like while loops and conditionals, it’s very easy to create functions that are non-differentiable. But the reason they’re not disallowed is that they make it possible to do really neat things that couldn’t be done otherwise. Our first implementation of the ordinary differential equation solver used automatic differentiation through an iterative algorithm in order to get gradients.

I may be reading too much into this, but I think the confusion is that Stan does not build a function that represents the gradient. At the point at which you ask for a gradient, it evaluates the function, builds an expression graph at that point, and then computes the gradient using the chain rule.

Hope that helps.


In addition to what @syclik said (all of which is spot on!), it may help to think about the rule for differentiating. If you have

f(u) = cond(u) ? g(u) : h(u)

then

f'(u) = cond(u) ? g'(u) : h'(u)

It just does everything piecewise. Up to you to make sure it’s continuously differentiable at the condition boundaries.
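For example, here’s a hypothetical Stan function (smooth_piece is just an illustrative name) that is continuously differentiable at the boundary u == 0, because the two branches agree there in both value and derivative; replacing square(u) with plain u would keep the function continuous but leave a kink in the gradient at 0.

functions {
  real smooth_piece(real u) {
    if (u > 0) {
      return square(u);   // value 0 and derivative 0 at u == 0
    } else {
      return 0;           // matches the other branch's value and derivative at u == 0
    }
  }
}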

Loops work by just building up more of the expression graph, with conditions behaving the same way as above: they just determine which expression gets differentiated.

Right. Thanks, @syclik, @Bob_Carpenter! That’s very clear.

So, if I understand correctly, the interpretation of conditionals like this is OK because they would only cause non-differentiability at a set of measure 0, meaning that we never encounter those points in running HMC?

While constructs I find a little harder to interpret. I guess the assumption is that, at any point, the while loop terminates after finitely many passes, resulting in a finite expression tree that we can just differentiate?

One difference between Stan and what people might ultimately want from a general-purpose probabilistic programming language is the ability to write programs that are actually non-terminating (like servers) but which exhibit intelligent probabilistic behaviour (for instance, do smart crypto stuff). I’m just speculating here, of course, but I’m trying to see how we should construct a semantic framework for understanding probabilistic programming.

Another, related thing I’d be slightly worried about with while constructs, more so than case constructs, is that we can define densities that become non-differentiable on sets of non-zero measure. (I think of while loops as defining countable sums of densities. Is this right?) I guess it’s just the responsibility of the programmer not to write code like that.

I guess this points to the following rule for differentiating while loops:
Let f be defined through a while loop.

  • If f terminates for input u, unroll the while loop and differentiate as usual.
  • If f diverges at u, so does f'.
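For instance, here’s the sort of thing I have in mind (a made-up Stan sketch, with truncated_exp just an illustrative name): for any x at which the loop terminates, the iterations that actually executed form a finite expression tree, and that is what gets differentiated; the number of iterations is effectively treated as a constant.

functions {
  real truncated_exp(real x) {
    real term = 1;
    real total = 1;
    int n = 1;
    // stop once the next term is negligible (or after a safety cap)
    while (fabs(term) > 1e-8 && n < 100) {
      term *= x / n;
      total += term;
      n += 1;
    }
    return total;
  }
}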

No. HMC requires C^1 densities, and even cusps over a set of measure 0 will break the algorithm. The intuitive way to think about it is that the problem arises not when you land exactly on a discontinuity but when you pass one: a numerical integrator jumps over the discontinuity and ignores it, consequently simulating the wrong dynamics.

This is ultimately what it comes down to with conditionals and loops!

The implicit assumption is that if f converges after a finite time, then the derivative of the trace will also converge to the derivative of f by that time. This is trivial in some cases (a loop over a finite mixture, for example), and it can be made to work okay in other cases (it’s how all of our initial linear algebra and ODE solvers worked, and how some of our transcendental functions still work). In those latter cases, however, it’s almost always better to work out the gradient by hand.

Thanks, Michael! Clearly, HMC is ill-defined if we cannot differentiate our density, and I would expect the Hamiltonian dynamics to be discontinuous at a point where the density is discontinuous. However, could we give a meaning to HMC for densities that are differentiable almost everywhere by taking smooth approximations to the density and running HMC on those? (These would have very rapid behaviour near the point where the density is discontinuous, approximating the jump in the limit.) Surely we would expect the solution to a Hamiltonian dynamical system to depend continuously on the density, on suitably chosen spaces? In that case, we could probably still trust the answers HMC gives for any density defined using case constructs. That would be useful to know. Or is it very clear that HMC returns nonsense if I give it a density with a discontinuity?

No, that limiting procedure does not smoothly converge to the correct answer (various papers have tried and failed). The correct dynamics requires integrating across delta functions which induces “kicks” in the trajectories. This is all well and good if you land exactly on one of the discontinuities, but the discrete steps of numerical integrators don’t have the required precision and there’s no way to interpolate back towards a discontinuity that you jumped past without breaking the desired properties, such as reversibility of the integrator.

Good to know. So I just shouldn’t trust any results if there’s a discontinuity. I wonder if it’d be useful to provide users with a primitive that gives a reasonably accurate smooth approximation to if-then-else, for when they really want to do inference in a model that would otherwise not have a smooth density. Is that something people ever want to do? Would it give sensible results? In my mind, it shouldn’t matter terribly, conceptually, whether I work with a density with a discontinuity or with some smooth approximation to it. Surely I can’t have enough confidence in my statistical model anyway to have a preference between the two?
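Something like the following is what I have in mind (just a sketch; soft_if_else is a made-up name, and k is a sharpness parameter the user would have to choose):

functions {
  // a smooth stand-in for (cond > 0 ? a : b); larger k gives a sharper switch
  real soft_if_else(real cond, real a, real b, real k) {
    real w = inv_logit(k * cond);
    return w * a + (1 - w) * b;
  }
}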

It depends on how severe the kink is and whether the sampler ever gets anywhere near it. If it does cause problems, you should get lots of divergences and other warnings, because you’ll fall off the Hamiltonian in a non-recoverable way. One place this might be a problem is if a kink messes with the Hamiltonian, but not enough to cause a rejection; in that case, the new points are unlikely to be selected.

Hey, what about a continuous function that is differentiable almost everywhere, such as ReLU, which is non-differentiable only at x = 0? Is the output from Stan valid?
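For example (a made-up sketch, with relu and softplus just illustrative names):

functions {
  real relu(real x) {
    return fmax(0, x);              // continuous, but the gradient jumps at x == 0
  }
  real softplus(real x, real s) {
    return s * log1p_exp(x / s);    // smooth surrogate; approaches relu(x) as s -> 0
  }
}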

Relatedly, is there any way in Stan to incorporate analytical gradient information? For example, we know the exact gradient of ReLU for x < 0 and x > 0.

Thanks a lot!

Y

You will tend to get divergences if the leapfrog integrator crosses the discontinuity, even if there is no chance that it will hit the discontinuity. But you may be able to crank adapt_delta up to overcome them if the change is not too rapid through the discontinuity.

Dealing with analytical gradients is not too difficult, but you have to write the function in C++ and call it from Stan.

Thanks for the reply and the pointer!
But what I am interested in is slightly different from what you described.
My function is continuous, without any discontinuity, so there won’t be any missed jump in the leapfrog integrator. The only concern is that it’s non-differentiable at 0. Would that also cause trouble for the leapfrog integrator and hence divergences? Thanks a lot!

Y

I believe that Ben is referring to the discontinuity in the gradient induced by the kink in the ReLU curve. If you have a kink in a high-dimensional space then gradient-based algorithms can sometimes be well-behaved, but in general kinks will tend to span hyperplanes, the algorithms can’t avoid them, and you end up with subtle but important problems. In particular, just about every theoretical result for gradient-based algorithms assumes no kinks anywhere. This includes optimization in addition to Hamiltonian Monte Carlo.

I see your point now! Thanks a lot!