I did some more thinking about this. I agree that with the current way vars work (all the autodiff is embedded in scalar elements) the full Jacobian across the ODE system is required.
The way the adjoint sensitivity analysis works, it doesn’t just chain (for lack of a better word) each output vari one at a time. It takes the adjoints attached to all the outputs and, in one go, integrates backwards and computes the contributions to the adjoints at its inputs. What doesn’t conceptually fit with scalar autodiff variables is that there’s no guarantee all the adjoints at the output of the ODE solver are ready to be integrated back at the same time, and chaining each of them individually is probably no more efficient than the forward sensitivity analysis.
However, I think the way the Math library uses a stack satisfies this condition, because all the varis created in the ODE solver get pushed onto the stack at the same time. My understanding is that the library assumes that by the time the reverse pass reaches a point in the stack, everything attached at the output has already been processed. Because all the ODE varis go on at the same place in the stack relative to everything else, by the time any single one’s chain gets called, all of their output adjoints are already final.
So couldn’t each of the little varis (each corresponding to one output from the ODE) initiate the chaining for everything? The trick is to design the varis so that only the first one whose chain gets called has any effect.
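To make that concrete, here’s a minimal standalone sketch of the trick (mock code, not actual Stan Math internals; the class and member names are made up for illustration). All the output varis share a flag, and whichever chain gets called first runs the single backward solve; the rest become no-ops:

    #include <iostream>
    #include <memory>
    #include <vector>

    // Shared between all output varis of one ODE solve.
    struct ode_adjoint_state {
      bool chained = false;  // set once the backward solve has run
    };

    // Stand-in for a vari holding one ODE output.
    struct output_vari {
      double val_;
      double adj_ = 0.0;
      std::shared_ptr<ode_adjoint_state> state_;

      output_vari(double val, std::shared_ptr<ode_adjoint_state> state)
        : val_(val), state_(std::move(state)) {}

      void chain() {
        if (state_->chained)
          return;  // a sibling already did the work
        state_->chained = true;
        // This is where the single backward adjoint ODE solve would run,
        // reading adj_ off every sibling output vari and writing the
        // results into the adjoints of the solver's inputs.
        std::cout << "backward adjoint solve runs exactly once\n";
      }
    };

    int main() {
      auto state = std::make_shared<ode_adjoint_state>();
      std::vector<output_vari> outputs = {{1.0, state}, {2.0, state}};
      // Reverse pass: chain is called on every output, but only the
      // first call has any effect.
      for (auto it = outputs.rbegin(); it != outputs.rend(); ++it)
        it->chain();
    }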
I coded this up here: https://github.com/bbbales2/adjoint/blob/master/stan_test.cpp
It’s just a linear 2D ODE that runs around in a circle (I tested a non-linear 1D adjoint outside this as well, for what it’s worth). If we write the ODE as y' = f, I hard-coded the necessary Jacobians of f. The output is ten evenly spaced samples of the first state. There is one parameter (the angular frequency squared). I just used forward Euler and wrote the integrator myself to keep the code looking simple.
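For orientation, a rough standalone sketch of that forward solve (not the code from the repo; the step count, step size, and parameter value here are made up, and all the sensitivity bookkeeping is dropped):

    #include <cstdio>

    int main() {
      const int N = 1000;                  // Euler steps (illustrative)
      const int P = 10;                    // evenly spaced output points
      const double T = 6.283185307179586;  // integrate one period
      const double dt = T / N;
      const double omega2 = 1.0;           // the single parameter, omega^2
      double y1 = 1.0, y2 = 0.0;           // start on the unit circle

      // y1' = y2, y2' = -omega2 * y1: runs around in a circle.
      for (int i = 0; i < N; ++i) {
        double dy1 = y2;
        double dy2 = -omega2 * y1;
        y1 += dt * dy1;
        y2 += dt * dy2;
        if ((i + 1) % (N / P) == 0)
          std::printf("%g\n", y1);         // output the first state only
      }
    }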
The Melicher paper @betanalpha found was useful. I also liked this one (especially for the clarity of its eq. 2.4): http://www.mcs.anl.gov/~hongzh/papers/SISC_FATODE_final.pdf
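For reference, the continuous adjoint system those papers work with looks roughly like this, written for a single output functional g(y(T)) with y' = f(y, θ) (my paraphrase in simplified notation, not a quote of eq. 2.4; check the paper for their exact conventions):

    \lambda'(t) = -\left(\frac{\partial f}{\partial y}\right)^{T} \lambda(t),
    \qquad
    \lambda(T) = \left(\frac{\partial g}{\partial y}\Big|_{t=T}\right)^{T}

    \frac{dg}{d\theta}
      = \lambda(0)^T \frac{\partial y(0)}{\partial \theta}
      + \int_0^T \lambda(t)^T \frac{\partial f}{\partial \theta} \, dt

The multiple-output version below is just this with jump conditions on \lambda at the output times.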
There’s the detail that we’re taking output at many different time points of the ODE. This still only requires one backwards ODE solve (though Melicher points out it’s a little finicky). If you look at how I did the reverse integration, you might find these lines:
    if ((i + 1) % (N / P) == 0) {
      // passing an output time: fold that output vari's adjoint
      // into the running adjoint state before continuing backwards
      l[0] += neighbors[p]->adj_;
      p -= 1;
    }
a little weird. What’s happening there is that as we integrate back in time and pass output points, we absorb their adjoints into the bigger adjoint ODE solve.
This is how the Melicher paper describes the backwards integration, though they may have implemented it a little differently. It’s a little weird to reconcile with the adjoint equations as they are given in eq. 2.4 of the Zhang paper. But think of a long ODE solve with many intermittent outputs as a sequence of small ODE solves, where you extract the output at intermediate steps and reinitialize a new ODE solver before moving on. Then the big adjoint problem becomes a bunch of smaller adjoint problems solved in reverse, where the initial condition of each little adjoint ODE is the sum of the contributions from calculations on its output plus the contributions from any ODE solves that followed it.
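Here’s a rough standalone sketch of that backward sweep for the 2D test problem above (again, not the repo’s code: the Jacobian is hard-coded, the output adjoints are faked as ones, and the parameter-gradient accumulation is left out to keep it short). Each output’s adjoint gets absorbed as we pass its time point, exactly like the snippet above:

    #include <array>
    #include <cstdio>
    #include <vector>

    int main() {
      const int N = 1000, P = 10;           // steps and outputs (illustrative)
      const double dt = 6.283185307179586 / N;
      const double omega2 = 1.0;

      // Pretend adjoints sitting on the P output varis after the
      // downstream chains have run (all ones here for illustration).
      std::vector<double> output_adj(P, 1.0);

      std::array<double, 2> l = {0.0, 0.0};  // adjoint state lambda
      int p = P - 1;
      for (int i = N - 1; i >= 0; --i) {
        if ((i + 1) % (N / P) == 0) {
          l[0] += output_adj[p];  // jump: absorb this output's adjoint
          p -= 1;                 // (outputs were of the first state only)
        }
        // lambda' = -J^T lambda with J = [[0, 1], [-omega2, 0]], so
        // stepping backward: lambda(t - dt) ~ lambda(t) + dt * J^T lambda(t).
        double d0 = -omega2 * l[1];
        double d1 = l[0];
        l[0] += dt * d0;
        l[1] += dt * d1;
      }
      // l now carries the adjoint at t = 0; a full version would also
      // accumulate lambda^T df/dtheta along the way for the parameter
      // gradient, which needs the stored forward trajectory.
      std::printf("lambda(0) = (%g, %g)\n", l[0], l[1]);
    }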
Anyway, I’m probably still missing something about how this could be used in Stan, but I’ve got answers to my questions, so I’m happy either way :D. Figured this might be useful to someone looking around the forums.