On adjoint sensitivity of ODE/PDE

To follow up… if my understanding is right, then adjoint sensitivity analysis can take advantage of using only a few of the N states in the log lik. That would be a huge win, since many of my models involve many states but I only need one of them to show up in the log lik. If that can be exploited, then adjoint will kill forward mode in most settings… but we would need a way to let users say which states they actually need.

I had to think about this a while but no, I don’t think there’s a way to simplify the adjoint equations any more.

The way to see this is to look at eq. 2.13 in http://www.mcs.anl.gov/~hongzh/papers/SISC_FATODE_final.pdf.

In particular,

\lambda' = -f_y^T \lambda
With \lambda_i(t_F) = (g_i)_y

(g_i)_y in this case are the adjoints of our log density with respect to each of the outputs of the ODE (the y indicates they’re adjoints with respect to part of the solution of the ODE, and the i indicates them individually). Now that I look back, though, the FATODE authors say the i index runs over output functions, so the two sets of i indices are different. Be warned haha.

Some of these may be zero, so the terminal values of \lambda at t = t_F (the initial conditions of the backward solve) may be only sparsely non-zero. But that doesn’t mean the gradients \lambda_i corresponding to the zero entries stay zero: f_y^T couples the components, so they pick up non-zero values as soon as we integrate backward. We still have to integrate the full \lambda system.

(All that notation assumes we’re only using the last time point of the ODE in the log density)
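A small numerical sketch of that point, under made-up assumptions: a toy 3-state linear ODE y' = A y (so f_y = A), a log density that only touches state 0 at t_F, and a hand-rolled fixed-step RK4 integrator standing in for CVODES. The sparse terminal condition still produces a dense \lambda(t_0).

```python
import numpy as np

def rk4(rhs, x, t0, t1, n_steps):
    """Fixed-step RK4 from t0 to t1; t1 < t0 just means a negative step (backward in time)."""
    h = (t1 - t0) / n_steps
    t = t0
    for _ in range(n_steps):
        k1 = rhs(t, x)
        k2 = rhs(t + h / 2, x + h / 2 * k1)
        k3 = rhs(t + h / 2, x + h / 2 * k2)
        k4 = rhs(t + h, x + h * k3)
        x = x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return x

# Hypothetical 3-state system: the off-diagonal terms couple neighbouring states.
A = np.array([[-1.0, 0.5, 0.0],
              [0.3, -0.8, 0.2],
              [0.0, 0.4, -0.5]])

# Adjoint ODE \lambda' = -f_y^T \lambda, solved backward from t_F = 2 to t_0 = 0.
lam_tF = np.array([1.0, 0.0, 0.0])  # (g)_y: only state 0 appears in the log density
lam_t0 = rk4(lambda t, lam: -A.T @ lam, lam_tF, 2.0, 0.0, 200)
print(lam_t0)  # every entry is non-zero, so every component had to be integrated
```

State 2 never feeds state 0 directly (A[0, 2] = 0), yet \lambda_2(t_0) still ends up non-zero because it is reached through state 1.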

Using fewer outputs can reduce the number of integration restarts (the backward solve has to stop at each output time to add in that time’s contribution), but I don’t think that was a big problem.
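To make the restart point concrete: if the log density uses y at several output times t_1 < ... < t_K, then \lambda jumps by (g_k)_y at each t_k, so the backward solve is split into K segments with a restart at each jump. This is a minimal sketch assuming a toy linear ODE, Gaussian terms on state 0 only, made-up observation times and data, and fixed-step RK4 instead of CVODES. It checks the resulting gradient with respect to y(t_0), which is \lambda(t_0), against central finite differences.

```python
import numpy as np

def rk4(rhs, x, t0, t1, n_steps):
    """Fixed-step RK4 from t0 to t1 (t1 < t0 is fine: the step is just negative)."""
    h = (t1 - t0) / n_steps
    t = t0
    for _ in range(n_steps):
        k1 = rhs(t, x)
        k2 = rhs(t + h / 2, x + h / 2 * k1)
        k3 = rhs(t + h / 2, x + h / 2 * k2)
        k4 = rhs(t + h, x + h * k3)
        x = x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return x

A = np.array([[-1.0, 0.5, 0.0],
              [0.3, -0.8, 0.2],
              [0.0, 0.4, -0.5]])
obs_times = [0.5, 1.0, 2.0]   # hypothetical output times
data = [0.8, 0.6, 0.3]        # hypothetical measurements of state 0

def log_lik(y0):
    """Unit-variance Gaussian log density on state 0 at each output time."""
    total, y, t = 0.0, y0, 0.0
    for tk, dk in zip(obs_times, data):
        y = rk4(lambda s, x: A @ x, y, t, tk, 100)
        total += -0.5 * (y[0] - dk) ** 2
        t = tk
    return total

def adjoint_grad(y0):
    """d log_lik / d y0 via a backward sweep that restarts at every output time."""
    ys, y, t = [], y0, 0.0
    for tk in obs_times:                        # forward pass: keep y(t_k)
        y = rk4(lambda s, x: A @ x, y, t, tk, 100)
        ys.append(y)
        t = tk
    lam = np.zeros_like(y0)
    t_grid = [0.0] + obs_times
    for k in reversed(range(len(obs_times))):
        jump = np.zeros_like(y0)
        jump[0] = -(ys[k][0] - data[k])         # (g_k)_y: only state 0 enters g_k
        lam = lam + jump                        # discontinuity: the solver must stop here
        lam = rk4(lambda s, x: -A.T @ x, lam, t_grid[k + 1], t_grid[k], 100)
    return lam

y0 = np.array([1.0, 0.5, 0.2])
g = adjoint_grad(y0)
eps = 1e-6
fd = np.array([(log_lik(y0 + eps * e) - log_lik(y0 - eps * e)) / (2 * eps) for e in np.eye(3)])
print(np.max(np.abs(g - fd)))  # tiny: adjoint and finite differences agree
```

With only the last output time the loop runs once and there are no interior restarts, which is the case the notation above assumes.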

To learn more about what’s going on, I think I’d have to break all this stuff out of Stan and look at it more closely. There’s the inevitable danger that the CVODES people did everything right and there’s nothing to fiddle with, but that’s the risk of delving further haha.

But like, that’s a linear time-varying system, one step up from linear time-invariant. It just seems like if there’s any system that’s easy to solve, it’s that! And half the equations are literally just an integral (look at u').
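A sketch of why u' is "just an integral", assuming a made-up Lotka-Volterra model with parameters p = (a, b, c, d), a Gaussian log density on state 0 at t_F only, and fixed-step RK4. In the backward right-hand side, the \lambda block is linear in \lambda with coefficients that vary in time through y(t) (the linear time-varying part), and the u block reads \lambda and y but never u, so it is a pure quadrature. For brevity y is re-integrated backward alongside \lambda; CVODES instead checkpoints the forward solution, and backward re-integration can be unstable for stiff or long problems.

```python
import numpy as np

def rk4(rhs, x, t0, t1, n_steps):
    """Fixed-step RK4 from t0 to t1 (negative step when t1 < t0)."""
    h = (t1 - t0) / n_steps
    t = t0
    for _ in range(n_steps):
        k1 = rhs(t, x)
        k2 = rhs(t + h / 2, x + h / 2 * k1)
        k3 = rhs(t + h / 2, x + h / 2 * k2)
        k4 = rhs(t + h, x + h * k3)
        x = x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return x

def f(y, p):
    a, b, c, d = p
    return np.array([a * y[0] - b * y[0] * y[1], -c * y[1] + d * y[0] * y[1]])

def f_y(y, p):
    a, b, c, d = p
    return np.array([[a - b * y[1], -b * y[0]],
                     [d * y[1], -c + d * y[0]]])

def f_p(y, p):
    return np.array([[y[0], -y[0] * y[1], 0.0, 0.0],
                     [0.0, 0.0, -y[1], y[0] * y[1]]])

T, DATA, N = 1.0, 1.5, 200          # hypothetical horizon, measurement, step count
Y0 = np.array([2.0, 1.0])           # initial state, independent of p

def log_lik(p):
    yT = rk4(lambda t, y: f(y, p), Y0, 0.0, T, N)
    return -0.5 * (yT[0] - DATA) ** 2

def adjoint_grad(p):
    yT = rk4(lambda t, y: f(y, p), Y0, 0.0, T, N)
    lam_T = np.array([-(yT[0] - DATA), 0.0])    # (g)_y: only state 0 enters
    u_T = np.zeros(4)                           # (g)_p = 0 for this log density

    def rhs(t, z):
        y, lam = z[:2], z[2:4]                  # z[4:] (= u) is never read
        return np.concatenate([f(y, p),                  # re-solve y backward
                               -f_y(y, p).T @ lam,       # \lambda' = -f_y^T \lambda
                               -f_p(y, p).T @ lam])      # u' = -f_p^T \lambda
    z0 = rk4(rhs, np.concatenate([yT, lam_T, u_T]), T, 0.0, N)
    return z0[4:]                               # d log_lik / dp = u(t_0)

p = np.array([1.0, 0.5, 1.0, 0.3])
grad = adjoint_grad(p)
eps = 1e-6
fd = np.array([(log_lik(p + eps * e) - log_lik(p - eps * e)) / (2 * eps) for e in np.eye(4)])
print(np.max(np.abs(grad - fd)))  # tiny: the quadrature gives the parameter gradient
```

Because nothing depends on u, only the N-dimensional \lambda system is a real ODE; the u entries could be accumulated by any quadrature rule along the way.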

@wds15 I threw some more changes into the branch yesterday. Here are notes to go along with them.

  1. What you said about not needing to integrate back all the equations made me realize I was actually including too many things in the reverse ODE. Some of the eqs. are just integrals (again, see eq. 2.13 in http://www.mcs.anl.gov/~hongzh/papers/SISC_FATODE_final.pdf: you only need to actually solve the \lambda' ODE, not the u' ODE). So if the forward problem has N states, the adjoint ODE is an N-state linear time-varying ODE.

  2. There are various bits of commented-out code in there now for using the iterative solvers. I also included the full Jacobian calculation for the direct linear solvers (which I originally didn’t want to do). None of these things really helped the test case.

  3. It seems like for this problem the adjoint ODE takes a little over 2x the time of the forward ODE to solve (forward ODE here meaning just the ODE solution, not forward sensitivity). This is still surprising to me, but maybe the numerics of the adjoint problem are harder than I’m giving them credit for. The benchmark is a non-stiff, chaotic set of ODEs, which seems like a terrible pairing for what CVODES is meant to solve.

  4. Either way, BDF is a linear multistep method, which, combined with the fact that this is a linear time-varying system, means the implicit solve at each step is just a linear system. It should converge in one Newton iteration. This makes me think the gears are turning fast, but to hold the tight tolerances the solver is having to take super tiny timesteps. That’s just the working theory though. It hasn’t been proven or anything.

  5. I think the next step in this stuff is benchmarking it on a different set of ODEs.
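The one-Newton-iteration argument in item 4 can be checked on the simplest BDF method (backward Euler). For \lambda' = -J(t)^T \lambda the implicit residual G(x) = x - \lambda_n + h J(t_{n+1})^T x is affine in x, so a single Newton step with the exact iteration matrix I + h J^T lands on the solution from any starting guess; higher-order BDF has the same form with extra history terms. The time-varying J here is hypothetical, and the sketch assumes the exact, current Jacobian: CVODES reuses a possibly stale iteration matrix, so in practice it can still take more than one iteration. This says nothing about the tiny-timestep part of the working theory.

```python
import numpy as np

def JT(t):
    """Hypothetical time-varying f_y(t)^T for a 2-state adjoint system."""
    return np.array([[-1.0 + 0.5 * np.sin(t), 0.3],
                     [0.5, -0.8 + 0.2 * np.cos(t)]])

def residual(x, lam_n, t_next, h):
    """Backward Euler residual for lam' = -JT(t) @ lam: G(x) = x - lam_n - h * (-JT @ x)."""
    return x - lam_n + h * (JT(t_next) @ x)

def bdf1_step_one_newton(lam_n, t_next, h, x_guess):
    """One backward Euler step, solved with exactly one Newton iteration."""
    iteration_matrix = np.eye(len(lam_n)) + h * JT(t_next)   # exact dG/dx
    return x_guess - np.linalg.solve(iteration_matrix, residual(x_guess, lam_n, t_next, h))

# Backward adjoint sweep from t_F = 2 to t_0 = 0 with signed step h = -0.1.
lam, t, h = np.array([1.0, 0.0]), 2.0, -0.1
worst = 0.0
for _ in range(20):
    bad_guess = np.array([123.0, -45.0])   # deliberately far from the answer
    lam_next = bdf1_step_one_newton(lam, t + h, h, bad_guess)
    worst = max(worst, np.linalg.norm(residual(lam_next, lam, t + h, h)))
    lam, t = lam_next, t + h
print(worst)  # round-off level: the implicit equation is solved after one iteration
```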
