Part I: Lambda-based Closures
Autodiff is where Stan spends 95% of its execution time, so I always devote cycles to figuring out how to make it faster and easier to use.
Motivations
- The underlying theory is really about continuations—the reverse-mode pass that applies the chain rule is just a bunch of functions queued up during the forward pass. The current autodiff system does this implicitly, but I hadn’t really thought about it formally until talking to Barak Pearlmutter and @Matthijs.
- For years, @rtrangucci and others who’ve looked closely at the matrix autodiff have been clamoring for matrix autodiff that doesn’t require explicitly copying out an Eigen matrix (which entails heap allocation).
- C++11 makes a lot of this stuff much easier than it used to be. We’ll start with a literal continuation-based implementation, then I’ll refactor the closures into custom, memory-tight objects that could be used for serious autodiff.
I’m going to make a lot of references to how things are implemented now. If you’re not familiar with the current system, we have an arXiv paper on Stan’s autodiff that explains pretty much everything.
The Code
So let’s get started with the implementation. I’m going to start with some nice C++11 tricks, then over a few posts I’ll refactor into something with custom implementations. For a bit of a spoiler, here’s a post on lambdas explaining why std::function types are so heavy. Here’s the code. I’ll go back through it repeating pieces to highlight, but I thought it’d be good if you have it all in one place:
#include <cmath>
#include <functional>
#include <limits>
#include <iostream>
#include <vector>
namespace agrad {
typedef std::vector<double> adj_t;
typedef std::function<void(adj_t&)> chain_t;
std::vector<chain_t> stack_;
std::size_t next_idx_ = 0;
inline std::size_t next_idx() {
return next_idx_++;
}
struct var {
double val_;
int idx_;
var(const var& v) : val_(v.val_), idx_(v.idx_) { }
var(double val, int idx) : val_(val), idx_(idx) { }
var(double val) : val_(val), idx_(next_idx()) { }
var() : val_(std::numeric_limits<double>::quiet_NaN()), idx_(next_idx()) { }
};
// Unused sketch of a matrix analogue; it would need an Eigen include,
// so it's commented out to keep this example compiling:
// struct matrix_var {
//   Eigen::MatrixXd val_;
//   int** idx_;
// };
inline var operator+(const var& x1, const var& x2) {
var y(x1.val_ + x2.val_);
stack_.emplace_back([=](adj_t& adj) {
adj[x1.idx_] += adj[y.idx_];
adj[x2.idx_] += adj[y.idx_];
});
return y;
}
inline var operator+(const var& x1, double x2) {
var y(x1.val_ + x2);
stack_.emplace_back([=](adj_t& adj) {
adj[x1.idx_] += adj[y.idx_];
});
return y;
}
inline var operator+(double x1, const var& x2) {
var y(x1 + x2.val_);
stack_.emplace_back([=](adj_t& adj) {
adj[x2.idx_] += adj[y.idx_];
});
return y;
}
inline var operator*(const var& x1, const var& x2) {
var y(x1.val_ * x2.val_);
stack_.emplace_back([=](adj_t& adj) {
adj[x1.idx_] += x2.val_ * adj[y.idx_];
adj[x2.idx_] += x1.val_ * adj[y.idx_];
});
return y;
}
inline var operator*(const var& x1, double x2) {
var y(x1.val_ * x2);
stack_.emplace_back([=](adj_t& adj) {
adj[x1.idx_] += x2 * adj[y.idx_];
});
return y;
}
inline var operator*(double x1, const var& x2) {
var y(x1 * x2.val_);
stack_.emplace_back([=](adj_t& adj) {
adj[x2.idx_] += x1 * adj[y.idx_];
});
return y;
}
inline var exp(const var& x) {
var y(std::exp(x.val_));
stack_.emplace_back([=](adj_t& adj) {
adj[x.idx_] += y.val_ * adj[y.idx_];
});
return y;
}
std::vector<double> grad(const var& y, const std::vector<var>& x) {
std::vector<double> adj(y.idx_ + 1, 0.0);
adj[y.idx_] = 1;
for (auto chain_f = stack_.crbegin(); chain_f != stack_.crend(); ++chain_f)
(*chain_f)(adj);
std::vector<double> dy_dx(x.size());
for (std::size_t i = 0; i < x.size(); ++i)
dy_dx[i] = adj[x[i].idx_];
return dy_dx;
}
} // namespace agrad
int main() {
agrad::stack_.clear();
agrad::next_idx_ = 0;
using agrad::var;
var x1 = 10.3;
var x2 = -1.1;
std::vector<var> x = { x1, x2 };
var y = x1 * x2 * 2 + 7;
std::vector<double> dy_dx = agrad::grad(y, x);
std::cout << "y = " << y.val_ << std::endl;
for (std::size_t i = 0; i < x.size(); ++i)
std::cout << "dy / dx[" << (i + 1) << "] = " << dy_dx[i] << std::endl;
return 0;
}
Here’s how it runs (the c++1y flag was clang’s pre-release name for C++14; -std=c++11, -std=c++14, or -std=c++17 will all work for this code).
$ clang++ -std=c++1y ad.cpp
$ ./a.out
y = -15.66
dy / dx[1] = -2.2
dy / dx[2] = 20.6
Autodiff type: var
Everything other than the example application gets dropped into the namespace agrad, which is the original namespace I used for Stan’s first autodiff (it was short for “automatic gradients”); we’ve since put everything in stan::math.
The top-level type we use is quite a bit different from our current pointer-to-implementation design.
struct var {
double val_;
int idx_;
var(const var& v) : val_(v.val_), idx_(v.idx_) { }
var(double val, int idx) : val_(val), idx_(idx) { }
var(double val) : val_(val), idx_(next_idx()) { }
var() : val_(std::numeric_limits<double>::quiet_NaN()), idx_(next_idx()) { }
};
It stores two members: a value and an index (with padding to 8-byte alignment, it’s still going to be a 16-byte object); later, I’ll typedef out the index and value types for flexibility. We almost always pass these var by reference and they all live on the function call stack (not on the heap), so the extra weight here isn’t a problem.
The main difference from what we’re doing now is that the value is stored locally. This means we don’t have to chase a pointer to get a value. It also means we won’t have to allocate values in the arena-based memory unless we need them in the reverse pass. Finally, rather than allocating a vari and pointing to it, we only maintain an index. That index is maintained with a global variable next_idx_ and a wrapper function to get the next one and increment:
std::size_t next_idx_ = 0; // global!
inline std::size_t next_idx() {
return next_idx_++;
}
From the constructor, you can see that when we construct a var
from a double, we’ll allocate a new index:
var(double val) : val_(val), idx_(next_idx()) { }
There’s no memory being allocated for that index. Everything remains local to the var
, which is great for increasing memory locality and reducing pointer chasing and subsequent cache misses (which can cost dozens of floating-point operations).
One must be very careful with globals, translation units, etc., but I’m not going to focus on that here. Also, we can optionally make our global variables thread-local; the cost is a bit of a synchronization slowdown and the benefit is full multi-threading with independent autodiff (which makes things like the parallel map just released in 2.18 easy).
The autodiff stack
Conceptually, reverse mode autodiff works by evaluating an expression, and for each subexpression, pushing a continuation to propagate the chain rule in reverse onto a stack. These are then visited in reverse order to propagate the chain rule from the final value expression down to the input arguments.
The first implementation follows the theory and literally creates a stack of continuations. We need to include the <functional>
header in order to define the appropriate types.
typedef std::vector<double> adj_t;
typedef std::function<void(adj_t&)> chain_t;
std::vector<chain_t> stack_;
The type adj_t is for a sequence of adjoints, indexed by expression id (as held in a var). The type chain_t is the type of the continuations held in the autodiff stack, and stack_ is the global variable holding the continuation stack.
The continuations are functions that work on a vector of adjoints. Each expression has an index, and that expression’s adjoint value is in the adjoint stack (which is type adj_t
and created when derivatives are needed rather than in advance). More formally, we use the C++11 functional lib std::function
to provide a type for the continuations, std::function<void(adj_t&)>
. Thus each continuation applies to a vector of adjoints to propagate the chain rule one or more steps.
To run more than one autodiff, stack_.clear()
will have to be called. We’ve swept details like that under the rug to keep this example simple. We’ll ramp up to much cleaner implementations by the third post.
It’s probably easiest to see how autodiff works for a non-trivial unary function. Here’s the full definition of exp
:
inline var exp(const var& x) {
var y(std::exp(x.val_));
stack_.emplace_back([=](adj_t& adj) {
adj[x.idx_] += y.val_ * adj[y.idx_];
});
return y;
}
First, the variable y
is constructed with the value given by exponentiating the value of the argument. Then the fun starts. Rather than having to define a new vari class for each function, we’re going to use the magic of lambdas. First, note that stack_.emplace_back(...)
is going to construct an object of the type std::function<void(adj_t&)>
(the type of elements held by stack_
) using the memory in the stack. This conveniently avoids a copy (which is heavy for std::function
). Now let’s break down the lambda expression that will be emplaced,
[=](adj_t& adj) {
adj[x.idx_] += y.val_ * adj[y.idx_];
}
The initial [=] controls the closure behavior of the lambda, with = denoting capture-by-value. The values being bound by the closure are x.idx_, y.val_, and y.idx_. Copies of these are made in the closure object with their current values. This creates a function that applies to an adj_t, which is the vector of adjoint values. The action here says to update the adjoint of x by adding the value of y (i.e., exp(x)) times the adjoint of y. The adjoints are held in the vector of scalar values in adj and are indexed by the variable index. When the chain rule’s done, we wind up with each adjoint being the derivative of the final value with respect to the expression represented by the index of the adjoint.
Let’s consider one more example, operator+
. It has three instantiations, one of which we consider here:
inline var operator+(const var& x1, double x2) {
var y(x1.val_ + x2);
stack_.emplace_back([=](adj_t& adj) {
adj[x1.idx_] += adj[y.idx_];
});
return y;
}
Now there are two arguments, but only one is a var. The value is calculated in the obvious way to construct a new var (with a freshly generated index). Then the continuation is defined to update the adjoint of x1 with the adjoint of the result y. As with exponentiation, x1.idx_ and y.idx_ are stored by value.
Applying the chain rule
This is pretty easy now with a function that takes a result variable y
and computes derivatives w.r.t. the variables in x
, returning them as a vector of floating point values.
std::vector<double> grad(const var& y, const std::vector<var>& x) {
std::vector<double> adj(y.idx_ + 1, 0.0);
adj[y.idx_] = 1;
for (auto chain_f = stack_.crbegin(); chain_f != stack_.crend(); ++chain_f)
(*chain_f)(adj);
std::vector<double> dy_dx(x.size());
for (std::size_t i = 0; i < x.size(); ++i)
dy_dx[i] = adj[x[i].idx_];
return dy_dx;
}
We start by initializing the adjoint vector of double values, sized just large enough to include slot y.idx_, with every value set to 0.0. (This sizing works because y is the final expression evaluated, so it has the largest index.) Because we take y as the result, each adjoint slot will hold the derivative of y with respect to the expression with its index. So we start by setting the derivative of y with respect to itself to 1.
Then we just walk over the stack of continuations from the last added down to the first (the iterator crbegin() is the top of the stack and crend() is the bottom). Each continuation updates the adjoints by propagating the chain rule from an expression down to its operands; it’s as easy as applying the function *chain_f to the stack of adjoints, (*chain_f)(adj).
Next we allocate a standard vector for the results and copy the adjoints for the x
values into dy_dx
to return.
Using it
Here’s the main function again, which is defined outside of the agrad
namespace.
int main() {
using agrad::var;
agrad::stack_.clear();
agrad::next_idx_ = 0;
var x1 = 10.3;
var x2 = -1.1;
std::vector<var> x = { x1, x2 };
var y = x1 * x2 * 2 + 7;
std::vector<double> dy_dx = agrad::grad(y, x);
std::cout << "y = " << y.val_ << std::endl;
for (std::size_t i = 0; i < x.size(); ++i)
std::cout << "dy / dx[" << (i + 1) << "] = " << dy_dx[i] << std::endl;
return 0;
}
First we clear the global variables (this will not free memory in the stack, just reset the size to zero, which is the behavior we want for Stan). Then we set the index for the next expression to zero.
Next, we just set two variables of type var
to 10.3 and -1.1 respectively. There is no explicit assignment, but there is an implicit var(double)
constructor which will get used. Then we group the variables into the x
vector to pass to grad
; this uses C++11 brace initializers.
Next, we have the expression we evaluate, var y = x1 * x2 * 2 + 7;
. This finds the overloaded operator*
and operator+
implementations defined in the agrad
namespace through argument-dependent lookup (ADL).
Finally, we calculate the gradient by calling agrad::grad
and assign the result to dy_dx
for printing.
Et voilà.
Warning!!!
This works and is hopefully easy to follow. But it’s going to be slow and a memory hog.
The culprit is in the way that std::function
works. There is a really great post explaining why this approach isn’t tenable (I hot linked the relevant section, but the whole thing’s super clear and very worth reading):
- Shahar Mike’s blog post: Under the hood of lambdas and std::function
Next time
We’ll fix this problem by defining our own custom continuations. Then we’ll be back to implementing something like our current vari, but without any built-in member variables.