I’m trying to generate a sequence of derivatives of a function using Stan Math. That is, given a function f(x), to generate f'(x), f''(x), f'''(x), ... up to a certain order n. What’s the best way of doing this with autodiff?
The Nomad manual is the place to look for how higher-order derivatives work: https://github.com/stan-dev/nomad/tree/master/manual. You’ll have to compile the manual yourself, though.
Suffice it to say I’m not actually sure, but the last time I had a question about it @betanalpha sent me there and that document was useful.
The Nomad manual discusses some of the theory of automatic differentiation, but not the details of how Stan currently implements higher-order autodiff.
One important commonality to any implementation, however, is that in higher-order autodiff you can’t compute all of the partial derivatives in one go – you have to repeatedly sweep back and forth to build them up from various directional derivatives. For example, second-order reverse mode lets you compute Hessian-vector products in one forward and one reverse sweep, but you still need N sweeps to fill out all of the second-order partials. If you’re careful, though, you will get the gradient components as a side effect of the higher-order sweeps.
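To make the Hessian-vector-product point concrete, here is a minimal sketch using Stan Math’s fvar<var> type rather than Nomad: the forward sweep is seeded with the direction v, and one reverse sweep over the resulting directional derivative yields H v. The toy function, the point, and the direction are made up for illustration, and the exact header path may differ between Stan Math versions.

```cpp
#include <stan/math/mix.hpp>  // header path may vary by Stan Math version
#include <iostream>
#include <vector>

int main() {
  using stan::math::fvar;
  using stan::math::var;

  // Point x and direction v (both chosen arbitrarily for this sketch).
  std::vector<double> x = {1.0, 2.0};
  std::vector<double> v = {0.5, -1.0};

  // Promote inputs: reverse-mode var values with forward tangents seeded by v.
  std::vector<fvar<var>> xv;
  for (size_t i = 0; i < x.size(); ++i)
    xv.emplace_back(var(x[i]), var(v[i]));

  // Toy scalar function f(x) = x0^2 * x1 + 3 * x1.
  fvar<var> fx = xv[0] * xv[0] * xv[1] + 3.0 * xv[1];

  // The forward sweep already produced f(x) in fx.val_ and the directional
  // derivative grad(f) . v in fx.d_; a single reverse sweep over fx.d_
  // then fills the input adjoints with the Hessian-vector product H v.
  fx.d_.grad();
  for (size_t i = 0; i < x.size(); ++i)
    std::cout << "(H v)[" << i << "] = " << xv[i].val_.adj() << "\n";

  stan::math::recover_memory();
}
```

Note that the value f(x) and the directional derivative grad(f) · v fall out of fx.val_.val() and fx.d_.val() as the lower-order side effects of the same two sweeps.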
In general you have to step back and figure out what higher-order differential operator you want to compute, figure out how to implement it with autodiff sweeps, then see what lower-order differential operators fall out of those calculations. There’s not much theory around this yet – it’s some wacky geometry that I’m still trying to wrap my head around – but I worked out the implementation of second- and third-order operators, and all of the lower-order stuff that falls out, in Section 1.3.4 of the Nomad manual.
First a warning: our higher-order autodiff has not been thoroughly tested beyond the basic C++ functions and arithmetic.
What’s most efficient depends on N and M in f:R^N -> R^M.
If everything’s unary, then var computes a gradient, fvar<var> computes the gradient of a derivative, fvar<fvar<var>> the gradient of a second derivative, and so on.
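Since the question is about a unary f, here is a minimal sketch of that nesting at a single point. The test function x^4 is chosen arbitrarily; the seeding and member access reflect my understanding of how fvar and var compose, and the header path may vary between Stan Math versions.

```cpp
#include <stan/math/mix.hpp>  // header path may vary by Stan Math version
#include <iostream>

// Toy unary function: f(x) = x^4, so f' = 4x^3, f'' = 12x^2, f''' = 24x.
template <typename T>
T f(const T& x) {
  return x * x * x * x;
}

int main() {
  using stan::math::fvar;
  using stan::math::var;

  var x0(2.0);  // reverse-mode core variable

  // Two forward-mode tangents, both seeded to 1, wrapped around the var core.
  fvar<fvar<var>> x(fvar<var>(x0, var(1.0)), fvar<var>(var(1.0), var(0.0)));

  fvar<fvar<var>> y = f(x);

  std::cout << "f(x)    = " << y.val_.val_.val() << "\n";  // value
  std::cout << "f'(x)   = " << y.d_.val_.val() << "\n";    // first derivative
  std::cout << "f''(x)  = " << y.d_.d_.val() << "\n";      // second derivative

  // One reverse sweep over the second-order tangent yields the third derivative.
  y.d_.d_.grad();
  std::cout << "f'''(x) = " << x0.adj() << "\n";

  stan::math::recover_memory();
}
```

Each extra fvar layer buys one more derivative order at that point; for the full sequence up to order n you would re-run with deeper nesting or sweep repeatedly, as described above.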
Thanks a lot, guys.