The adjoint sensitivity code is up. Here is a pull request on my math repo, which makes it easy to look at the changes: https://github.com/bbbales2/math/pull/2
The main magic is in two files. stan/math/rev/mat/fun/build_ode_varis.hpp takes the output of the ODE solver and wraps it up in vars, similar to the decouple_ode_state code that was there before. stan/math/rev/mat/fun/ode_vari.hpp holds the vari that coordinates the adjoint calculation. The vari in ode_vari.hpp is the one that needs a destructor, or some equivalent cleanup hook, to free the CVODES memory.
Edit: The math behind this is eq. 2.4 of http://www.mcs.anl.gov/~hongzh/papers/SISC_FATODE_final.pdf
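For readers who don't want to open the paper, here is the standard continuous adjoint formulation in my own notation. This is a sketch from memory: the symbols (y for the state, p for the parameters, g for the cost at the final time T, and the adjoint variables lambda and mu) are assumptions and may not match the paper's eq. 2.4 term for term.

```latex
% Forward problem: y' = f(t, y, p), y(t_0) = y_0(p), cost \Psi = g(y(T), p)
\begin{aligned}
\lambda'(t) &= -\left(\frac{\partial f}{\partial y}\right)^{T} \lambda(t),
  & \lambda(T) &= \left(\frac{\partial g}{\partial y}\right)^{T}\Big|_{t=T}, \\
\mu'(t) &= -\left(\frac{\partial f}{\partial p}\right)^{T} \lambda(t),
  & \mu(T) &= \left(\frac{\partial g}{\partial p}\right)^{T}\Big|_{t=T}, \\
\left(\frac{d\Psi}{dp}\right)^{T} &= \mu(t_0)
  + \left(\frac{\partial y_0}{\partial p}\right)^{T} \lambda(t_0).
\end{aligned}
```

Both lambda and mu are integrated backward from T to t_0, so the cost of the gradient does not grow with the number of parameters the way forward sensitivities do.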
Something to keep in mind is that even though the ODE produces number_of_time_steps * number_of_states outputs (and the same number of vars and varis), the adjoint solve is only hooked up to one of these varis. The rest are allocated on the non-chaining stack. This is how all the adjoints are handled together in a single backward solve.
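To make the "one backward solve handles every output" idea concrete, here is a minimal self-contained sketch. It is not the code in the pull request: the toy ODE y' = -p * y, the fixed-step RK4 integrator, the cost G = sum of y at the output times, and all function names are assumptions for illustration. The closed-form forward solution stands in for the stored forward trajectory that CVODES would provide.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Toy problem (assumed for illustration): y' = -p * y, y(0) = y0.
// The closed form stands in for the stored forward trajectory.
double forward_state(double y0, double p, double t) {
  return y0 * std::exp(-p * t);
}

// lambda: adjoint of the state; mu: accumulated sensitivity w.r.t. p.
struct AdjointState {
  double lambda;
  double mu;
};

// Adjoint right-hand side for f(y, p) = -p * y:
//   dlambda/dt = -(df/dy) * lambda = p * lambda
//   dmu/dt     = -(df/dp) * lambda = y(t) * lambda
AdjointState adjoint_rhs(double t, AdjointState s, double y0, double p) {
  return {p * s.lambda, forward_state(y0, p, t) * s.lambda};
}

// One classical RK4 step of size h (h < 0 integrates backward in time).
AdjointState rk4_step(double t, AdjointState s, double h, double y0,
                      double p) {
  AdjointState k1 = adjoint_rhs(t, s, y0, p);
  AdjointState k2 = adjoint_rhs(
      t + 0.5 * h, {s.lambda + 0.5 * h * k1.lambda, s.mu + 0.5 * h * k1.mu},
      y0, p);
  AdjointState k3 = adjoint_rhs(
      t + 0.5 * h, {s.lambda + 0.5 * h * k2.lambda, s.mu + 0.5 * h * k2.mu},
      y0, p);
  AdjointState k4 = adjoint_rhs(
      t + h, {s.lambda + h * k3.lambda, s.mu + h * k3.mu}, y0, p);
  return {s.lambda + h / 6.0 * (k1.lambda + 2.0 * k2.lambda +
                                2.0 * k3.lambda + k4.lambda),
          s.mu + h / 6.0 * (k1.mu + 2.0 * k2.mu + 2.0 * k3.mu + k4.mu)};
}

// dG/dp for G = sum_k y(ts[k]), with ts sorted in increasing order.
// A single backward sweep visits the output times from last to first and
// injects each output's adjoint (dG/dy(t_k) = 1) into lambda on arrival.
double adjoint_gradient(double y0, double p, const std::vector<double>& ts,
                        int steps_per_interval) {
  assert(steps_per_interval > 0);
  AdjointState s{0.0, 0.0};
  for (std::size_t k = ts.size(); k-- > 0;) {
    s.lambda += 1.0;  // adjoint contribution of the output at ts[k]
    const double t_hi = ts[k];
    const double t_lo = (k == 0) ? 0.0 : ts[k - 1];
    const double h = (t_lo - t_hi) / steps_per_interval;  // negative step
    for (int i = 0; i < steps_per_interval; ++i) {
      s = rk4_step(t_hi + i * h, s, h, y0, p);
    }
  }
  return s.mu;  // y0 does not depend on p, so there is no lambda(0) term
}

// Closed-form reference: d/dp sum_k y0 * exp(-p * t_k).
double analytic_gradient(double y0, double p, const std::vector<double>& ts) {
  double g = 0.0;
  for (double t : ts) g += -t * y0 * std::exp(-p * t);
  return g;
}
```

Calling adjoint_gradient(1.0, 0.5, {0.5, 1.0, 1.5, 2.0}, 200) should agree closely with analytic_gradient on the same arguments. Each extra output time adds one more stop in the backward sweep, and the sweep itself is shared by all the outputs.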
This is the little example code I was using to test: https://gist.github.com/bbbales2/5049c99674e34c3796738ff11f3890ca
Here’s the output of timing it with the adjoint sensitivity analysis (the row of numbers is the derivatives of the sum of all the ODE output states with respect to the different diffusion constants):
bbales2@frog:~/math-adjoint$ source build2.sh && time ./test2
...
-23681.4 6228.88 3667.17 1196.06 246.698 34.8322 3.559 0.274326 0.0164524 0.000761025 -5.89514e-17 -2.77336e-05 2.80233e-05 0.00114006 0.0185708 0.231901 2.21294 15.7028 80.4073 289.296 722.426
real 0m0.591s
If I time it with the forward mode sensitivity analysis:
bbales2@frog:~/math-adjoint-ref$ source build2.sh && time ./test2
...
-23681.4 6228.88 3667.17 1196.06 246.698 34.8322 3.559 0.274326 0.0164524 0.000761025 1.72543e-15 -2.77336e-05 2.80233e-05 0.00114006 0.0185708 0.231901 2.21294 15.7028 80.4073 289.296 722.426
real 0m2.194s
The adjoint method takes a significant performance hit if we need output at intermediate time points. With 10 outputs, adjoint sensitivity takes about 1.2s and forward still takes about 2.1s.
All the tests in test/unit/math/rev/mat/functor/integrate_ode_bdf_rev_test.cpp pass except these three: https://github.com/bbbales2/math/pull/2/files#diff-8f1971e1a585b016e51b051c4f473acbR125
These are the hacks I use to keep the var_stack vector from reallocating and causing segfaults when I run a nested reverse mode autodiff sweep inside the chain step of an outer reverse mode sweep: https://github.com/bbbales2/math/pull/2/files#diff-130c5a75cc427d7d41715e9fca8281f4
Here is a Google perf tools sampling of the example program from above running:
Most of the time is spent computing derivatives of the forward ODE RHS with respect to state and theta. I did some experiments and, at least for this problem, using fvars to compute the Jacobians is faster. I ran these in isolation from the actual ODE solve and haven't incorporated the change back in yet.
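As a rough picture of what computing the state Jacobian in forward mode looks like, here is a self-contained sketch. It uses a hand-rolled dual number rather than Stan's fvar, and the two-state exchange RHS and every name in it are hypothetical, so treat it as an illustration of the technique and not as the experiment I ran.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hand-rolled dual number: a value plus one tangent (directional derivative).
// A stand-in for a forward-mode autodiff type, not Stan's fvar itself.
struct Dual {
  double val;
  double tan;
};

Dual operator-(Dual a, Dual b) { return {a.val - b.val, a.tan - b.tan}; }
Dual operator*(double c, Dual a) { return {c * a.val, c * a.tan}; }

// Hypothetical two-state exchange RHS, templated so the same code runs on
// double (plain evaluation) and on Dual (evaluation plus tangent).
template <typename T>
std::vector<T> rhs(const std::vector<T>& y, const std::vector<double>& theta) {
  assert(y.size() == 2 && theta.size() == 2);
  T out_flow = theta[0] * y[0];
  T in_flow = theta[1] * y[1];
  return {in_flow - out_flow, out_flow - in_flow};
}

// Jacobian d(rhs)/dy, built one column per forward sweep: seed a unit
// tangent on y[j], evaluate the RHS, and read column j off the tangents.
std::vector<std::vector<double>> jacobian_wrt_state(
    const std::vector<double>& y, const std::vector<double>& theta) {
  const std::size_t n = y.size();
  std::vector<std::vector<double>> J(n, std::vector<double>(n, 0.0));
  for (std::size_t j = 0; j < n; ++j) {
    std::vector<Dual> y_dual(n);
    for (std::size_t i = 0; i < n; ++i) {
      y_dual[i] = {y[i], i == j ? 1.0 : 0.0};
    }
    const std::vector<Dual> f = rhs(y_dual, theta);
    for (std::size_t i = 0; i < n; ++i) {
      J[i][j] = f[i].tan;
    }
  }
  return J;
}
```

Forward mode needs one sweep per input column and reverse mode needs one sweep per output row. For a square state Jacobian the sweep counts match, so any speed difference comes down to per-sweep overhead, for example forward mode not having to record and replay a tape.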