This is getting way off topic from the language meeting, but I don’t know how to move bits of threads.
Thanks, @andrjohns. The obstacle to explicitly rendering derivatives in higher-dimensions is that it’ll blow out memory. Consider what happens with a simple matrix function like matrix inverse(matrix)
. The input is N x N
, the output is N x N
, so the derivatives are (N x N) x (N x N)
, each element of which requires a pointer in memory and pointer chasing during autodiff. That is, we wind up with an \mathcal{O}(N^4) algorithm in both time and memory.
Instead, the way these are implemented for a function y = f(a, b)
are as individual updates which conceptually do this
a.adj += y.adj * dy/da;
b.adj += y.adj * dy/db;
But the key thing is that there are usually much more efficient ways to update a.adj
and b.adj
than the explicit multiplies by the Jacobians. The best source I know to read about this is:
For example, Giles has the result that if C = A^{-1}, then the update rule is
A.adjoint += -C' * C.adjoint * C;
which you can see is going to be only an \mathcal{O}(N^3) operation in time and even more critically only an \mathcal{O}(N^2) operation in memory.
The issue is then how to have users express these adjoint-Jacobin products. The way we’d have to do user-specified gradients is to have them write the adjoint updates for the operands based on the adjoint of the result. Suppose we have a user-defined function
T foo(U a, V b);
Then we can specify adjoint-Jacobin products in the following form, where we pass in everything we know about the arguments and adjoints (as in @andrjohns’s example) and the result is added to the relevant adjoint (a
in the first case, b
in the second).
U adj_jac_prod_foo_1(T y_adjoint, T y, U a, V b);
V adj_jac_prod_foo_2(T y_adjoint, T y, U a, V b);
For example, if we were coding our own inverse function, we’d use
matrix my_inv(matrix x);
matrix adj_jac_prod_my_inv_1(matrix y_adjoint, matrix y, matrix x) {
return -y' * y_adjoint * y;
}
This is fine for one-argument functions, but if there is more than one argument for which we need gradients, we’d like to reuse computations across the calls. So it’s tempting to roll the two functions together (or would be if we had tuple returns). Always computing both derivatives is also problematic because in the program, one of them might be data, so we wouldn’t need a derivative.
The proposal from @andrjohns would work for forward- and reverse-mode. The proposal I’m making doesn’t help with forward-mode, where we’d like to just write down the tangent rules directly (see Giles, again, for how that’d look optimally, then we’d need to do something similar to the above in code).