User-defined gradients updates

@andrjohns: What have you been thinking about that? I saw your link to the PR, but there are 30+ files touched and it’s in the math lib, so I’m not sure what the language changes will look like.

I’m still stuck on how to do this without having to represent Jacobians explicitly; representing them explicitly is a dealbreaker. For efficiency, we’ll need to express reverse-mode derivatives as adjoint-vector products. And then there’s the issue of how to define multi-argument functions efficiently. Did you have a design document somewhere?

I haven’t gotten to the language stage yet, I’m still ironing out the C++ implementation.

The basic setup is that you specify a tuple of arguments, a functor for the value, and a tuple of functors for the gradients. The gradient functors take the value of the function as their first argument, followed by all of the input arguments.

Using the hypot function as an example, this looks like

template <typename T1, typename T2>
inline auto hypot(const T1& a, const T2& b) {
  // Functor for calculating the return value
  auto val_fun = [&](auto&& x, auto&& y) {
    using std::hypot;
    return hypot(x, y);
  };

  // Functors for calculating gradient wrt each input
  // Where d/dx = x / hypot(x, y)  &  d/dy = y / hypot(x, y)
  // elt_divide() function allows for both matrix & scalar inputs here
  auto grad_fun_a
      = [&](auto&& val, auto&& x, auto&& y) { return elt_divide(x, val); };
  auto grad_fun_b
      = [&](auto&& val, auto&& x, auto&& y) { return elt_divide(y, val); };

  // Forward tuple of input arguments, functor for value,
  // and tuple of gradient functors
  return function_gradients(std::forward_as_tuple(a, b),
                            std::forward<decltype(val_fun)>(val_fun),
                            std::forward_as_tuple(grad_fun_a, grad_fun_b));
}

The function_gradients functor then iterates over the input arguments and gradient functor tuples to calculate the gradient for non-primitive inputs and update the .adj_ or .d_ members where needed.
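To make that concrete, here’s a rough, self-contained toy of the lock-step tuple iteration I mean. It’s not the actual PR code: it’s scalar-only and takes the adjoint of the output directly instead of deferring the update to the reverse pass, but the shape is the same.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <tuple>
#include <utility>

// Toy scalar "var": just a value and an adjoint.
struct toy_var {
  double val;
  double adj = 0;
};

// Update the adjoint of the I-th argument: the gradient functor gets the
// value of the function first, then the values of every argument.
template <std::size_t I, typename ArgTuple, typename GradTuple,
          std::size_t... Js>
void toy_accumulate(double out_adj, double value, ArgTuple& args,
                    GradTuple& grads, std::index_sequence<Js...>) {
  std::get<I>(args).adj
      += out_adj * std::get<I>(grads)(value, std::get<Js>(args).val...);
}

// Walk the argument tuple and the gradient-functor tuple in lock-step.
template <typename ArgTuple, typename ValFun, typename GradTuple,
          std::size_t... Is>
double toy_function_gradients(double out_adj, ArgTuple args, ValFun&& val_fun,
                              GradTuple grads, std::index_sequence<Is...> is) {
  double value = val_fun(std::get<Is>(args).val...);
  // C++17 fold: one adjoint update per argument / gradient-functor pair.
  (toy_accumulate<Is>(out_adj, value, args, grads, is), ...);
  return value;
}

int main() {
  toy_var a{3.0}, b{4.0};
  auto args = std::tie(a, b);
  auto grads = std::make_tuple(
      [](double val, double x, double) { return x / val; },   // d/dx
      [](double val, double, double y) { return y / val; });  // d/dy
  double v = toy_function_gradients(
      1.0, args, [](double x, double y) { return std::hypot(x, y); },
      grads, std::make_index_sequence<2>{});
  std::printf("v = %f, a.adj = %f, b.adj = %f\n", v, a.adj, b.adj);
  // prints v = 5, a.adj = 0.6, b.adj = 0.8
}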

Through some tuple black-magic, this compiles out to analytic gradients for var, fvar<T>, Matrix<var>, and var<Matrix>. I still need to do some testing and development with more complex matrix functions to see how the framework will need to generalise/change for those, though.
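So, assuming the hypot definition above is wired into reverse mode (it isn’t merged yet, this is just the intended behaviour), usage with var would look like:

stan::math::var a = 3.0, b = 4.0;
stan::math::var c = hypot(a, b);  // c.val() == 5
c.grad();
// a.adj() == 0.6 and b.adj() == 0.8, with no Jacobian ever stored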


This is getting way off topic from the language meeting, but I don’t know how to move bits of threads.

Thanks, @andrjohns. The obstacle to explicitly rendering derivatives in higher dimensions is that it’ll blow out memory. Consider what happens with a simple matrix function like matrix inverse(matrix). The input is N x N and the output is N x N, so the Jacobian is (N x N) x (N x N), each element of which requires a pointer in memory and pointer chasing during autodiff. That is, we wind up with an \mathcal{O}(N^4) algorithm in both time and memory (for N = 1000, that’s already 10^{12} partials).

Instead, the way these are implemented for a function y = f(a, b) is as individual updates which conceptually do this

a.adj += y.adj * dy/da;
b.adj += y.adj * dy/db;

But the key thing is that there are usually much more efficient ways to update a.adj and b.adj than the explicit multiplies by the Jacobians. The best source I know to read about this is Giles’s collection of matrix derivative results for forward and reverse mode algorithmic differentiation.

For example, Giles has the result that if C = A^{-1}, then the update rule is

A.adjoint += -C' * C.adjoint * C';

which you can see is going to be only an \mathcal{O}(N^3) operation in time and even more critically only an \mathcal{O}(N^2) operation in memory.
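In case it helps, the derivation is just standard matrix calculus: differentiating A C = I gives dA \, C + A \, dC = 0, so dC = -C \, dA \, C; matching \operatorname{tr}(\bar{C}^\top dC) = \operatorname{tr}(\bar{A}^\top dA) then gives \bar{A} = -C^\top \bar{C} \, C^\top, which is exactly the update above.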

The issue is then how to have users express these adjoint-Jacobian products. The way we’d have to do user-specified gradients is to have them write the adjoint updates for the operands based on the adjoint of the result. Suppose we have a user-defined function

T foo(U a, V b);

Then we can specify adjoint-Jacobian products in the following form, where we pass in everything we know about the arguments and adjoints (as in @andrjohns’s example) and the result is added to the relevant adjoint (a in the first case, b in the second).

U adj_jac_prod_foo_1(T y_adjoint, T y, U a, V b);

V adj_jac_prod_foo_2(T y_adjoint, T y, U a, V b);

For example, if we were coding our own inverse function, we’d use

matrix my_inv(matrix x);

matrix adj_jac_prod_my_inv_1(matrix y_adjoint, matrix y, matrix x) {
  return -y' * y_adjoint * y';
}

This is fine for one-argument functions, but if there is more than one argument for which we need gradients, we’d like to reuse computations across the calls. So it’s tempting to roll the two functions together (or would be if we had tuple returns). Always computing both derivatives is also problematic because in the program, one of them might be data, so we wouldn’t need a derivative.
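If we did have tuple returns, the rolled-together version would look something like this (hypothetical syntax, nothing like this exists in the language yet):

(U, V) adj_jac_prod_foo(T y_adjoint, T y, U a, V b);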

The proposal from @andrjohns would work for forward- and reverse-mode. The proposal I’m making doesn’t help with forward-mode, where we’d like to just write down the tangent rules directly (see Giles, again, for how that’d look optimally; then we’d need to do something similar to the above in code).

moved


Thanks Bob. I’m not entirely sure I follow the issue with memory, though.

If I have a look at the reverse-mode specification for inverse, the gradients are given by:

  reverse_pass_callback([res, res_val, arena_m]() mutable {
    arena_m.adj() -= res_val.transpose() * res.adj_op() * res_val.transpose();
  });

That’s the adjoint-Jacobian product you’re after, and it’s also what the proposed user-defined gradients would aim to compile to. So it looks like, for matrix parameters, the gradient functors will also need to take the adjoint as an argument.

Once I add the adjoint as an input, this means that the equivalent user-defined specification for the inverse function would be:

template <typename MatT>
inline auto inverse(const MatT& mat) {
  // Functor for calculating the return value
  auto val_fun = [&](auto&& m) {
    return inverse(m);
  };

  // Functor for calculating gradient wrt each input
  auto grad_fun_m = [&](auto&& val, auto&& adj, auto&& m) { 
    return val.transpose() * adj * val.transpose();
  };

  // Forward tuple of input argument, functor for value,
  // and tuple of gradient functor
  return function_gradients(std::forward_as_tuple(mat),
                            std::forward<decltype(val_fun)>(val_fun),
                            std::forward_as_tuple(grad_fun_m));
}

Does that cover what you were thinking about or am I off-track?


Just a quick note: since these lambdas are made locally, use std::move instead of std::forward into the function_gradients functor. And I wouldn’t use [&]; once those functors go on the reverse pass callback stack, the things the lambda captures by reference may not exist anymore.
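Concretely, for the inverse example above that would look something like this (just a sketch of the changed lines, untested):

  // lambdas no longer capture by reference: [&] -> []
  auto val_fun = [](auto&& m) { return inverse(m); };
  auto grad_fun_m = [](auto&& val, auto&& adj, auto&& m) {
    return val.transpose() * adj * val.transpose();
  };

  // move the locally constructed lambdas instead of forwarding references
  return function_gradients(std::forward_as_tuple(mat),
                            std::move(val_fun),
                            std::forward_as_tuple(std::move(grad_fun_m)));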


Ah of course, good catch

Sorry for not responding earlier. The problem with memory is that you can’t explicitly represent the Jacobian without blowing out memory, but you can easily compute an adjoint-Jacobian product without blowing it out. Anything that forces the functor to return an explicit gradient or Jacobian won’t scale to multivariate functions.

The problem I was having is not with your code, but with your code comments, which say “Functor for calculating gradient wrt each input”, despite the fact that the functor never calculates gradients. This is the kind of code comment I’m always talking about when I say code comments can make code less understandable (everyone always asks me for examples but they’re everywhere, including in things I write—I’m not trying to pick on this example that hasn’t been code reviewed and is only in Discourse!). The other two code comments are the kind of redundant code comment I think is also harmful for reading code because it just gets in the way. I like Google’s coding standard, which says to rename, use functional units, and only use code comments as a last resort.

What I would’ve liked to have seen doc for is the lambda functions: what are the three arguments to grad_fun_m (which I would suggest renaming to adj_jac_prod_m, or even better, just adj_jac_prod, since it’s clearly a function by virtue of being a lambda)? I’d also recommend overloading function_gradients for the unary and binary cases so you don’t need to forward as tuples in each of the definitions, but that’s just low-level syntactic sugar; it’s clear here what’s going on with that function.
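For what it’s worth, the kind of sugar I have in mind is just something like this (hypothetical, not anything that’s in the math library):

// Unary convenience overload that wraps its arguments in tuples and defers
// to the general tuple-based version. In practice it would need a SFINAE
// constraint so it doesn't collide with that general overload.
template <typename Arg, typename ValFun, typename GradFun>
inline auto function_gradients(Arg&& arg, ValFun&& val_fun,
                               GradFun&& adj_jac_prod) {
  return function_gradients(
      std::forward_as_tuple(std::forward<Arg>(arg)),
      std::forward<ValFun>(val_fun),
      std::forward_as_tuple(std::forward<GradFun>(adj_jac_prod)));
}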

Oh I see what you mean, I completely agree. I’ll be working up a design doc over the weekend to put the implementation in more concrete terms, so that should be a bit more helpful than my afterthought comments.