An idea that came up in this paper is to use Lightweight Modular Staging (LMS) to have our autodiff system build an AST that we can run a variety of optimizations over (some of which essentially come for free with the LMS approach), and even potentially JIT-compile the parts of the gradient pass that don’t depend on parameters. I’m going to try to briefly summarize a couple of these ideas here and give pointers for how they could be implemented in our current Math library (leaving most of the existing code alone, I believe). Not to bury the lede, but I think this might be huge for performance.
LMS uses operator overloading to substitute a stand-in representation of a type for the real thing and keep track of symbolic operations over it (sound familiar? this is what our autodiff already does with `var`). But instead of pushing pointers onto a stack, it builds an AST of all the operations performed on the stand-in type (a rough sketch follows the list below). This AST then has some nice properties:
- Higher-level C++ language abstractions are stripped away and only low-level code remains
- There are some optimizations that come “for free” (e.g. dead code elimination)
- It is easier to encode other optimizations at a higher level (and there are simpler algorithms for e.g. code motion that use nesting instead of more complicated dataflow analysis)
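To make the staging idea concrete, here’s a minimal, hypothetical sketch (not Stan Math code; `staged`, `node`, and `dot3` are names I made up) of operator overloading that records an expression tree instead of evaluating anything:

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// AST node: an operation plus its children (no evaluation happens here).
struct node {
  std::string op;                           // "const", "input:...", "+", "*"
  double value;                             // payload for constants
  std::vector<std::shared_ptr<node>> args;  // children
};
using node_ptr = std::shared_ptr<node>;

// Staged stand-in for double: operations on it record nodes instead of
// computing numbers (analogous in spirit to var, but purely symbolic).
struct staged {
  node_ptr n;
  staged(double c) : n(std::make_shared<node>(node{"const", c, {}})) {}
  explicit staged(node_ptr p) : n(std::move(p)) {}
  static staged input(const std::string& name) {
    return staged(std::make_shared<node>(node{"input:" + name, 0.0, {}}));
  }
};

staged operator+(const staged& a, const staged& b) {
  return staged(std::make_shared<node>(node{"+", 0.0, {a.n, b.n}}));
}
staged operator*(const staged& a, const staged& b) {
  return staged(std::make_shared<node>(node{"*", 0.0, {a.n, b.n}}));
}

// Generic code written against the staged type gets "staged away": the C++
// loop and temporaries disappear, leaving only arithmetic nodes in the AST.
template <typename T>
T dot3(const T* x, const T* y) {
  T acc = x[0] * y[0];
  for (int i = 1; i < 3; ++i) acc = acc + x[i] * y[i];
  return acc;
}

void print(const node_ptr& n, int depth = 0) {
  std::cout << std::string(2 * depth, ' ') << n->op
            << (n->op == "const" ? " " + std::to_string(n->value) : "") << "\n";
  for (const auto& a : n->args) print(a, depth + 1);
}

int main() {
  staged x[3] = {staged::input("x0"), staged::input("x1"), staged::input("x2")};
  staged y[3] = {1.0, 2.0, 3.0};
  staged r = dot3(x, y);  // loop fully unrolled; only + and * nodes remain
  print(r.n);             // an AST we can optimize, differentiate, or compile
}
```

The point is that the for loop and the `dot3` helper leave no trace in the AST; only the arithmetic survives, which is the “abstractions removed” property in the first bullet.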
They also recommend switching from an AST to a “sea of nodes” graph-based intermediate representation, which has some additional benefits such as nearly free common sub-expression elimination (CSE) and global value numbering (GVN). More in the Optimizations link below.
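Here’s a hedged sketch of why CSE/GVN can fall out of a graph IR almost for free (again with made-up names, and much simpler than a real sea-of-nodes IR): if nodes are created through a value-numbering table keyed on (op, operands), structurally identical sub-expressions get shared at construction time:

```cpp
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

// A node in a graph IR, identified by an integer id.
struct gnode {
  int id;
  std::string op;
  std::vector<int> args;  // ids of operand nodes
};

// Graph builder with a value-numbering table: asking for a node that is
// structurally identical to an existing one just returns the existing id.
class graph {
  std::vector<gnode> nodes_;
  std::map<std::pair<std::string, std::vector<int>>, int> memo_;
 public:
  int make(const std::string& op, const std::vector<int>& args = {}) {
    auto key = std::make_pair(op, args);
    auto it = memo_.find(key);
    if (it != memo_.end()) return it->second;  // CSE: reuse existing node
    int id = static_cast<int>(nodes_.size());
    nodes_.push_back(gnode{id, op, args});
    memo_[key] = id;
    return id;
  }
  std::size_t size() const { return nodes_.size(); }
};

int main() {
  graph g;
  int x = g.make("input:x");
  int y = g.make("input:y");
  int a = g.make("*", {x, y});
  int b = g.make("*", {x, y});  // same (op, args) key: no new node is created
  g.make("+", {a, b});
  std::cout << "a == b: " << (a == b) << ", nodes: " << g.size() << "\n";
  // prints "a == b: 1, nodes: 4" -- the repeated x*y was shared automatically
}
```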
At the end of the day, we can take this AST and either interpret it or use it to generate native code. If the structure of the tree doesn’t depend on parameter values (something we can check statically), we only need to do this once and can reuse the result on every leapfrog step. Even if our AST is parameter-dependent, it is likely worth performing at least some of these optimizations on every leapfrog step.
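As a rough illustration of the reuse part (the lowering to a flat instruction list and all the names here are assumptions of mine, not anything from the paper or the Math library), we’d stage and optimize once and then keep re-running the result with fresh parameter values; a JIT would emit native code for the same program instead of interpreting it:

```cpp
#include <iostream>
#include <vector>

// A lowered, flat form of the optimized graph: each instruction reads
// earlier results by index, so evaluating it is just a loop.
enum class op { input, constant, add, mul };
struct instr { op o; int a; int b; double c; };  // a, b: operand indices

double run(const std::vector<instr>& prog, const std::vector<double>& params) {
  std::vector<double> val(prog.size());
  for (std::size_t i = 0; i < prog.size(); ++i) {
    const instr& in = prog[i];
    switch (in.o) {
      case op::input:    val[i] = params[in.a]; break;  // read a parameter
      case op::constant: val[i] = in.c; break;
      case op::add:      val[i] = val[in.a] + val[in.b]; break;
      case op::mul:      val[i] = val[in.a] * val[in.b]; break;
    }
  }
  return val.back();
}

int main() {
  // Staged and optimized once; the structure doesn't depend on parameter
  // values, so the same program is valid for every leapfrog step.
  const std::vector<instr> prog = {
    {op::input,    0, 0, 0.0},  // 0: x = params[0]
    {op::input,    1, 0, 0.0},  // 1: y = params[1]
    {op::mul,      0, 1, 0.0},  // 2: x * y
    {op::constant, 0, 0, 2.0},  // 3: 2
    {op::mul,      3, 0, 0.0},  // 4: 2 * x
    {op::add,      2, 4, 0.0},  // 5: x * y + 2 * x
  };
  // Re-interpret (or, eventually, run JIT-compiled code for) the same
  // program with fresh parameter values each step.
  for (int step = 0; step < 3; ++step) {
    std::vector<double> theta = {1.0 + step, 2.0};
    std::cout << "step " << step << ": f = " << run(prog, theta) << "\n";
  }
}
```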
Thoughts?
CC @Matthijs @Bob_Carpenter @cqfd
Bibliography: