Autodiff: Use LMS to reify an autodiff AST for derivatives

An idea that came up in this paper is to use Lightweight Modular Staging (LMS) as a technique to cause our autodiff system to create an AST that we can perform a variety of optimizations on (some of which sort of come for free with the LMS approach), and even potentially JIT-compile parts of the gradient pass that don’t depend on parameters. I’m going to try to briefly summarize a couple of these ideas here and give pointers for how they could be implemented in our current Math library (leaving most of the existing code alone, I believe). Not to bury the lede, but I think this might be huge for performance.

LMS uses operator overloading to substitute a fake representation of a type in for the real thing and keep track of symbolic operations over the type (Sound familiar? This is what our autodiff does already with var). But instead of pushing pointers onto a stack, it creates an AST for all of the operations that are performed on the fake type. This AST then has some nice properties:

  • Higher level language abstractions from C++ are removed here and only low-level code remains
  • There are some optimizations that come “for free” (e.g. dead code elimination)
  • It is easier to encode other optimizations at a higher level (for example, there are simpler algorithms for code motion that use nesting instead of more complicated dataflow analysis)
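To make the "fake type" idea concrete, here is a minimal sketch. Everything in it (`Node`, `Sym`, `variable`, `show`, `f`, `staged_f`) is a hypothetical name invented for illustration, not anything from the Math library or from an LMS implementation. The point is only that the same generic function runs on `double` and on a staged type, and in the second case the result is a tree rather than a number.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical expression-tree node: a leaf (variable) or a binary operation.
struct Node {
  std::string op;                  // "var", "+", or "*"
  std::string name;                // variable name for leaves, empty otherwise
  std::shared_ptr<Node> lhs, rhs;  // operands; null for leaves
};

// Staged stand-in for a scalar. Arithmetic on Sym records structure only.
struct Sym {
  std::shared_ptr<Node> node;
};

Sym variable(const std::string& name) {
  return Sym{std::make_shared<Node>(Node{"var", name, nullptr, nullptr})};
}

Sym operator+(const Sym& a, const Sym& b) {
  return Sym{std::make_shared<Node>(Node{"+", "", a.node, b.node})};
}

Sym operator*(const Sym& a, const Sym& b) {
  return Sym{std::make_shared<Node>(Node{"*", "", a.node, b.node})};
}

// Render the recorded tree as a fully parenthesized expression.
std::string show(const std::shared_ptr<Node>& n) {
  assert(n != nullptr);
  if (n->op == "var") return n->name;
  return "(" + show(n->lhs) + " " + n->op + " " + show(n->rhs) + ")";
}

// Ordinary generic code: works on double and on Sym without modification.
template <typename T>
T f(const T& x, const T& y) {
  return x * y + x;
}

// Running f on Sym yields the tree for x * y + x instead of a number.
std::string staged_f() {
  return show(f(variable("x"), variable("y")).node);
}
```

Calling `f(2.0, 3.0)` just computes a value, while `staged_f()` returns the text of the recorded tree. A real system would run optimization passes over that tree before interpreting or compiling it.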

They also recommend switching from an AST to a “sea of nodes” graph-based intermediate representation, which has some additional benefits, such as making common sub-expression elimination (CSE) and global value numbering (GVN) fairly easy. More in the Optimizations link below.
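As a rough illustration of why a graph representation makes CSE cheap, here is a sketch of hash-consing, which is one simple form of value numbering. It is an assumption about how this could look, not the paper's algorithm, and `Graph`, `intern`, `input`, `binary`, and `cse_example_node_count` are made-up names. Nodes are keyed by `(op, lhs, rhs)`, so building the same sub-expression twice returns the same node id.

```cpp
#include <cassert>
#include <map>
#include <tuple>
#include <vector>

// Hypothetical graph IR with hash-consing: structurally identical nodes
// share a single id.
struct Graph {
  struct Node { char op; int lhs; int rhs; };
  std::vector<Node> nodes;
  std::map<std::tuple<char, int, int>, int> seen;

  // Return the existing id for (op, lhs, rhs), or create a new node.
  int intern(char op, int lhs, int rhs) {
    const auto key = std::make_tuple(op, lhs, rhs);
    const auto it = seen.find(key);
    if (it != seen.end()) return it->second;
    const int id = static_cast<int>(nodes.size());
    nodes.push_back(Node{op, lhs, rhs});
    seen.emplace(key, id);
    return id;
  }

  // Leaf for the k-th input; the input index is stored in lhs.
  int input(int k) { return intern('v', k, -1); }

  // Binary operation on two nodes that already exist in the graph.
  int binary(char op, int lhs, int rhs) {
    const int n = static_cast<int>(nodes.size());
    assert(0 <= lhs && lhs < n && 0 <= rhs && rhs < n);
    return intern(op, lhs, rhs);
  }
};

// Builds (x * y) + (x * y) and reports how many nodes were created.
int cse_example_node_count() {
  Graph g;
  const int x = g.input(0);
  const int y = g.input(1);
  const int a = g.binary('*', x, y);
  const int b = g.binary('*', x, y);  // same key as a, so no new node
  g.binary('+', a, b);
  return static_cast<int>(g.nodes.size());  // x, y, x * y, and the sum
}
```

This sketch does not normalize commutative operations, so `x * y` and `y * x` would still get separate nodes. Sorting the operand ids for commutative ops before interning would be one way to handle that.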

At the end of the day, we can take this AST and either interpret it or use it to generate native code. If we don’t change the tree based on parameters (something we can check statically), we only need to do this once and can re-use the results every leapfrog step. Even if our AST is parameter-dependent, it is likely worth performing at least some of these optimizations on every leapfrog step.

CC @Matthijs @Bob_Carpenter @cqfd



Yeah, I read about this approach. It sounds great, but I didn’t bring it up because I don’t have time to attempt an implementation. Does TensorFlow do something similar?

Sort of? From what I can tell, you have to use mostly TensorFlow function calls (though there are some overloads for simple arithmetic), which generate a static graph that is then used for both the function evaluation and the gradient.

The core idea isn’t that different from the idea of expression templates, which many of us are familiar with from working with Eigen. Basically instead of doing the operation you return a type that encodes the operation to be performed, and this eventually forms a tree or graph structure.
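For anyone who has not looked inside Eigen, here is a stripped-down sketch of the expression-template pattern. It is a toy, not Eigen's actual design, and `Expr`, `Vec`, `Add`, and `sum3_example` are hypothetical names. `a + b + c` returns a nested `Add` type that encodes the operation, and no arithmetic happens until the result is assigned to a `Vec`.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// CRTP base so operator+ only applies to our expression types.
template <typename Derived>
struct Expr {
  const Derived& self() const { return static_cast<const Derived&>(*this); }
};

struct Vec : Expr<Vec> {
  std::vector<double> data;
  explicit Vec(std::vector<double> d) : data(std::move(d)) {}

  // Evaluating an expression: one loop, no intermediate vectors.
  template <typename E>
  Vec(const Expr<E>& e) : data(e.self().size()) {
    for (std::size_t i = 0; i < data.size(); ++i) data[i] = e.self()[i];
  }

  double operator[](std::size_t i) const { return data[i]; }
  std::size_t size() const { return data.size(); }
};

// The type itself encodes "lhs + rhs"; elements are computed on demand.
template <typename L, typename R>
struct Add : Expr<Add<L, R>> {
  const L& lhs;
  const R& rhs;
  Add(const L& l, const R& r) : lhs(l), rhs(r) {
    assert(l.size() == r.size());
  }
  double operator[](std::size_t i) const { return lhs[i] + rhs[i]; }
  std::size_t size() const { return lhs.size(); }
};

template <typename L, typename R>
Add<L, R> operator+(const Expr<L>& l, const Expr<R>& r) {
  return Add<L, R>(l.self(), r.self());
}

// a + b + c has type Add<Add<Vec, Vec>, Vec>; the loop runs in Vec's constructor.
std::vector<double> sum3_example() {
  const Vec a(std::vector<double>{1.0, 2.0});
  const Vec b(std::vector<double>{10.0, 20.0});
  const Vec c(std::vector<double>{100.0, 200.0});
  const Vec r = a + b + c;
  return r.data;
}
```

Because `Add` holds references, the expression has to be consumed in the same full expression that builds it (storing it in an `auto` variable would leave dangling references to temporaries). The difference from the LMS idea is that here the tree lives in the type system at compile time, whereas a staged `var` would build it as a runtime data structure.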

What I am proposing here is doing this with our var type and performing some optimizations (and possibly JITting) on the AST before running it.


As a prototype, we should be able to do exactly that! Before traversing the autodiff stack, we should be able to get the expression graph.

I’m guessing at runtime, it’s going to hamper our speed, but… maybe we can get the compiler to spit out when the Stan program should be a static graph (no conditionals based on parameters). If it is, then we can look at the expression graph and do all sorts of neat things without having to build out the rest. (Not suggesting we don’t do that! Just trying to think of ways we can do things faster now.)

One thing we can do with that structure is to automatically determine what parts can be split out for parallelization (based strictly on the graph structure). That was stuff I looked at a long time ago for a different project.


Yeah, in the rare cases where there are conditionals on parameters I think there will be some extra work at runtime on each leapfrog step but even then I suspect the optimizations to the AST could more than make up for that (and certainly in the static graph case where we can do the optimization pass just once per model fit instead of once per leapfrog step).

CC @Erik_Strumbelj @rok_cesnovar - I think maybe this was something close to what you guys were asking me about once, but I didn’t totally understand how we could apply this to our autodiff at the time. Sorry about that!


If this is at the C++ not the Stan language level, then this branching comes up a LOT for special functions (e.g., Gamma-everything, Beta, Beta-Binomial, etc…) which are typically stitched together from approximations. In those cases the dependence is simple enough you could likely pre-calculate the graphs for both paths and cache them, or treat them as sub-graphs or, since they must be close-enough-to-smooth, maybe treat them as NOT parameter-dependent branches. Some sort of a static C++ concept of “has-branches-but-is-smooth”…

Just throwing this out there.


I totally didn’t realize that, but of course that makes sense. Makes an easy optimization pass a little more difficult, but I think you’re right and there might be ways to deal with that… (they may even cover this in the paper eventually, I only skimmed it so far).

What you’re calling an AST is just the expression graph and we’re already storing it as a graph. It’s just a graph with a specialized vari type for each node. The stack just holds the nodes whereas the concrete vari extensions hold the edges in the form of pointers to operand vari.
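For readers who have not looked at the internals, the structure described here can be sketched roughly as follows. This is a deliberately simplified toy and not the actual Math library code: `vari` and `chain()` are the names used above, but `Tape`, `make`, `grad`, `add_vari`, `multiply_vari`, and `gradient_example` are stand-ins made up for the example, and memory management is reduced to a vector of owning pointers.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Node of the expression graph: a value, an adjoint, and a propagation rule.
struct vari {
  double val;
  double adj;
  explicit vari(double v) : val(v), adj(0.0) {}
  virtual ~vari() {}
  virtual void chain() {}  // leaves have nothing to propagate
};

// Edges are the operand pointers stored in each specialized node.
struct add_vari : vari {
  vari* a;
  vari* b;
  add_vari(vari* a_, vari* b_) : vari(a_->val + b_->val), a(a_), b(b_) {}
  void chain() override { a->adj += adj; b->adj += adj; }
};

struct multiply_vari : vari {
  vari* a;
  vari* b;
  multiply_vari(vari* a_, vari* b_) : vari(a_->val * b_->val), a(a_), b(b_) {}
  void chain() override { a->adj += adj * b->val; b->adj += adj * a->val; }
};

// The stack only records nodes in creation order.
struct Tape {
  std::vector<std::unique_ptr<vari>> stack;

  template <typename V, typename... Args>
  V* make(Args&&... args) {
    std::unique_ptr<V> node(new V(std::forward<Args>(args)...));
    V* raw = node.get();
    stack.push_back(std::move(node));
    return raw;
  }

  // Reverse sweep: every node on the stack gets its chain() called.
  void grad(vari* root) {
    assert(root != nullptr);
    root->adj = 1.0;
    for (auto it = stack.rbegin(); it != stack.rend(); ++it) (*it)->chain();
  }
};

// Gradient of f(x, y) = x * y + x at (2, 3): df/dx = y + 1, df/dy = x.
std::pair<double, double> gradient_example() {
  Tape tape;
  vari* x = tape.make<vari>(2.0);
  vari* y = tape.make<vari>(3.0);
  vari* xy = tape.make<multiply_vari>(x, y);
  vari* f = tape.make<add_vari>(xy, x);
  tape.grad(f);
  return std::make_pair(x->adj, y->adj);
}
```

The graph is already there in the operand pointers. What the proposal would add is a representation of the operations that can be inspected and rewritten before the sweep, rather than opaque virtual calls.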

I’m not sure what you mean by higher-level abstractions from C++ vs. low-level code in the first bullet point. What is low-level code? And do you mean abstractions like template metaprogramming or algorithms in C++?

How do you get automatic dead-code elimination? Right now, we use that stack to manage the dynamic programming of adjoint propagation in the reverse pass. If junk that isn’t involved in the final log density gets thrown onto it, you still call the chain() methods on those vari instances.
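One common way a graph representation can skip such junk is a reachability pass from the root, sketched below. Whether this matches what the paper does is an assumption on my part. `Node`, `live_nodes`, and `dce_example` are hypothetical names, and nodes are stored in creation order so operands always come before the nodes that use them.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical graph node: operand indices into the same vector, -1 for none.
struct Node {
  int lhs;
  int rhs;
};

// Mark every node the root depends on. Operands precede their users in the
// vector, so a single sweep from the root down to index 0 is enough.
std::vector<bool> live_nodes(const std::vector<Node>& g, int root) {
  assert(root >= 0 && static_cast<std::size_t>(root) < g.size());
  std::vector<bool> live(g.size(), false);
  live[root] = true;
  for (int i = root; i >= 0; --i) {
    if (!live[i]) continue;
    if (g[i].lhs >= 0) live[g[i].lhs] = true;
    if (g[i].rhs >= 0) live[g[i].rhs] = true;
  }
  return live;
}

// Graph for root = (x * y) + x with an unused node y * y on the tape.
std::vector<bool> dce_example() {
  const std::vector<Node> g = {
      {-1, -1},  // 0: x
      {-1, -1},  // 1: y
      {0, 1},    // 2: x * y
      {1, 1},    // 3: y * y, never used by the root
      {2, 0},    // 4: (x * y) + x
  };
  return live_nodes(g, 4);
}
```

In the example, node 3 comes back as not live, so a reverse pass that consults this mask would never call its propagation step. With a static graph the mask only needs to be computed once.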

It sounds like you’re presupposing we get a single static graph for a model rather than build up our graphs dynamically as we do now. Otherwise, graph algorithms and compiler optimizations on the expression graph are probably not worthwhile compared to just spending the arithmetic to evaluate the expressions as they are.

The big savings when you have static evaluation is that you can do things like sparsity analysis in Jacobians. We do that implicitly in our representations of functions. For example, we don’t store the whole Jacobian when multiplying matrices. If each matrix is $N \times N$, the Jacobian will have $N^4$ entries. Instead, even the naive autodiff algorithm will store $2N$ pointers per output node, for a total of $2N^3$ rather than $N^4$, but a more efficient algorithm (which I believe we have now and if we don’t we will when we push the rest of the adjoint-Jacobian stuff through) is to store the input and output matrices and just do everything lazily, which is only $2N^2$ storage.
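To spell out the lazy version: for a product of two square matrices, the standard reverse-mode update only needs the two input matrices and the adjoint of the output, so the Jacobian is never formed. The bar notation for adjoints is my own choice here, not something from the thread.

```latex
C = AB, \qquad
\bar{A} \mathrel{+}= \bar{C}\, B^{\top}, \qquad
\bar{B} \mathrel{+}= A^{\top} \bar{C}
```

As a rough sense of scale, with $N = 1000$ the three storage counts quoted above come to $N^4 = 10^{12}$, $2N^3 = 2 \times 10^{9}$, and $2N^2 = 2 \times 10^{6}$.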

What’s a sea of nodes? Isn’t the expression graph already a graph? You mean a bunch of things that associate like pieces that increment the log density?

Code generation for autodiff gets tricky if you allow dynamic control. If it’s just a static graph, it’s much easier.

I’m also not sure what you mean by parameter-dependent AST. It sounds like you mean what we’ve been calling dynamic.

After all that, I still don’t understand what LMS is.

Does any of this relate to how Adept does expression unfolding? That only works up to an assignment, not up to a whole graph, unless we had more auto-type stuff.

In general, there’s a ton of work in the autodiff literature on checkpointing algorithms where you can do subgraph analysis independently. This can help with scalability as Aki and Dan are proposing to do for nested autodiff eval for GP covariance functions (that is, do autodiff one entry at a time, then save results rather than building a big sum of covariance functions).

Might be fastest to talk this through at the office - will you be around Monday afternoon?

One insight from the Demystifying paper is that you can define your own conditional(s) and generate code for both branches (loops based on parameters would need to be turned into recursive function calls). Once you can do that (and generate some AST or computation graph representation that you can do optimizations over), it seems like you can just do it once per model instead of once per leapfrog step.
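Here is a small sketch of what "define your own conditional" could look like, under my own assumptions rather than the paper's actual API. `Code`, `var`, `lit`, `cond`, and `staged_branch_example` are hypothetical names. A native `if` on a staged value can only follow one path while tracing, whereas a staged `cond` takes both branches as arguments and emits code that chooses between them when the generated code runs.

```cpp
#include <cassert>
#include <string>

// Staged value: holds generated source text rather than a number.
struct Code {
  std::string text;
};

Code var(const std::string& name) {
  assert(!name.empty());
  return Code{name};
}

Code lit(const std::string& literal) {
  assert(!literal.empty());
  return Code{literal};
}

Code operator*(const Code& a, const Code& b) {
  return Code{"(" + a.text + " * " + b.text + ")"};
}

Code operator-(const Code& a, const Code& b) {
  return Code{"(" + a.text + " - " + b.text + ")"};
}

Code operator>(const Code& a, const Code& b) {
  return Code{"(" + a.text + " > " + b.text + ")"};
}

// Staged conditional: both branches are traced and both end up in the output.
Code cond(const Code& test, const Code& then_branch, const Code& else_branch) {
  return Code{"(" + test.text + " ? " + then_branch.text + " : " +
              else_branch.text + ")"};
}

// Stages f(x) = x > 0 ? x * x : 0 - x and returns the generated expression.
std::string staged_branch_example() {
  const Code x = var("x");
  return cond(x > lit("0"), x * x, lit("0") - x).text;
}
```

Since the generated expression no longer depends on which branch a particular parameter value would take, it can be produced and optimized once per model, which is the point made above.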