Alright, I had some time to look at the math. We don't actually have to implement any of the integration stuff; CVODES will do that for us. I just want to stop being super uncertain when I say what is what. There are probably mistakes here; props to anyone who points them out.
Things are formulated as Lagrangians, and I’m not sure why. We don’t even set the derivatives equal to zero, so I’m honestly flummoxed about what it does for us. I think a lot of times you might use a Lagrangian like this to prove that solutions actually exist, but I dunno, whatever. Just keep that in mind.
So with boundless confidence, let’s get started!
First we’re going to pick something we want to get sensitivities with respect to. The way all the derivations work is that we say we have a loss function:
G(y, p) = \int_0^T g(t, y, p) dt
The derivations say we want to minimize this loss function, and show us how to compute the sensitivities \frac{d G}{d p}.
First, maybe we’re implicitly minimizing something when we’re doing this, but it’s not the goal of our sensitivity analysis. So just ignore that for now. I assume this has more to do with existence of solution than anything else.
To figure out how the loss function (G) relates to anything, let’s think about what it is we want to compute. Let’s say in our forward pass we have:
p = \text{our parameters}
y = \text{ode solution, a function of p}
L = \text{log density, a function of y}
So in the reverse pass, when we go to call the chain method of our ODE solve, we present it with:
\frac{d L}{d y}, \text{the adjoints at the output of the ODE solver}
and we want to compute
\frac{d L}{d p}, \text{the adjoints at the input of the ODE solver}
So it stands to reason that if adjoint sensitivity analysis computes \frac{d G}{d p}, then we somehow want to relate L and G. We also want to make sure that this adjoint sensitivity analysis takes \frac{d L}{d y} as input, cause that’s what we have to initialize it with.
For the first thing: G is an integral of a function of the solution of our ODE. L is not an integral; it is a function of the ODE solution at a finite set of points.
At this point, we make the simplifying assumption that our log density is a function of the final state of the ODE only. There will be a workaround for this (at the very end), but the simplest derivation allows the log density to depend only on the solution at the final time. This is not a real limitation though.
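Just to pin down the shapes of things, here’s a minimal sketch of the forward pass in Python. The oscillator f, the parameter values, and the log_density are all made up for illustration; in real life the solve is CVODES and the log density is whatever the model says.

```python
import numpy as np
from scipy.integrate import solve_ivp

def f(t, y, p):
    # Made-up right hand side: a damped oscillator with parameters p
    return np.array([y[1], -p[0] * y[0] - p[1] * y[1]])

def log_density(yT):
    # Made-up log density: only a function of the final ODE state
    return -0.5 * np.sum(yT ** 2)

p = np.array([2.0, 0.5])   # p = our parameters
y0 = np.array([1.0, 0.0])
T = 10.0

sol = solve_ivp(f, (0.0, T), y0, args=(p,), rtol=1e-8, atol=1e-8)
y_T = sol.y[:, -1]         # y = ode solution, a function of p
L = log_density(y_T)       # L = log density, a function of y
```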
So how do we do this?
\frac{d G}{d T} = g(T, y(T), p)
by Leibniz’s rule (the integrand evaluated at the upper limit). Differentiating both sides with respect to p (and swapping the order of the T and p derivatives):
\frac{d}{d T} \frac{d G}{d p} = \frac{d g(T, y(T), p)}{d p}
If our log density is a function of the final state of the ODE, we can call it g!
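Spelled out, the identification is:
\frac{d}{d T} \frac{d G}{d p} = \frac{d g(T, y(T), p)}{d p} = \frac{d L(y(T))}{d p} = \frac{d L}{d p}
so the T-derivative of the adjoint sensitivity result is exactly the \frac{d L}{d p} we owe the reverse pass.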
So this derivation will work in two stages. First we’ll compute:
\frac{d G}{d p}
And then we’ll compute
\frac{d}{d T} \frac{d G}{d p}
And that’s all we need.
This derivation is taken in bits and pieces from:
- https://epubs.siam.org/doi/abs/10.1137/S1064827501380630?journalCode=sjoce3
- https://computation.llnl.gov/sites/default/files/public/cvs_guide.pdf
- http://www.mcs.anl.gov/~hongzh/papers/SISC_FATODE_final.pdf
We write down our Lagrangian (where this comes from, who knows, just take it):
I(y) = G(y) - \int_0^T \lambda^T (\frac{\partial y}{\partial t} - f(y, p)) dt
First thing to note: along solutions of our ODE, \frac{\partial y}{\partial t} - f(y, p) = 0, so this equation actually says:
I(y) = G(y), which might seem weird but it’s necessary.
With a couple simplifying assumptions:
- I removed the p dependence from our g (it won’t matter)
- Our ODE’s f isn’t a function of time
We take full derivatives with respect to p (noting I(y) = G(y)):
\frac{d I(y, p)}{d p} = \frac{d G}{d p} = \int_0^T g_y \frac{\partial y}{\partial p} dt - \frac{\partial \int_0^T \lambda^T (\frac{\partial y}{\partial t} - f(y, p)) dt}{\partial p}
Move the partial derivative inside the integral on the right (Leibniz’s rule again; neither endpoint of the integral is a function of p):
\int_0^T \frac{\partial \lambda^T (\frac{\partial y}{\partial t} - f(y, p))}{\partial p} dt
Apply the product rule:
\int_0^T \frac{\partial \lambda^T}{\partial p} (\frac{\partial y}{\partial t} - f(y, p)) + \lambda^T \frac{\partial (\frac{\partial y}{\partial t} - f(y, p))}{\partial p} dt
Again, along solutions of our ODE, \frac{\partial y}{\partial t} - f(y, p) = 0 (this is the governing equation of our ODE), so the first term vanishes and this simplifies to:
\int_0^T \lambda^T \frac{\partial (\frac{\partial y}{\partial t} - f(y, p))}{\partial p} dt
So we get:
\frac{d G}{d p} = \int_0^T g_y \frac{\partial y}{\partial p} dt - \int_0^T \lambda^T (\frac{\partial}{\partial t} \frac{\partial y}{\partial p} - f_y(y, p) \frac{\partial y}{\partial p} - f_p) dt
And this is cool! If we could evaluate the right hand side, then we’d know what \frac{d G}{d p} is (which is nearly what we want). The part we don’t know is \lambda. This is a vector function of time (same number of outputs as the ODE). The way this works is that \lambda only ever multiplies a term that is identically zero along solutions, so we’re free to choose it to be anything we want, and the smart move is to choose it so the terms we can’t compute disappear. (Choosing \lambda = 0 would be perfectly valid, just useless: we’d be left needing the forward sensitivities \frac{\partial y}{\partial p}.)
Let’s integrate this term by parts:
\int_0^T \lambda^T \frac{\partial}{\partial t} \frac{\partial y}{\partial p} dt
To make it:
\lambda^T \frac{\partial y}{\partial p}|_0^T - \int_0^T \frac{\partial \lambda^T}{\partial t} \frac{\partial y}{\partial p} dt
Combining everything together and collecting the terms that contain \frac{\partial y}{\partial p}:
\frac{d G}{d p} = \int_0^T (g_y + \frac{\partial \lambda^T}{\partial t} + \lambda^T f_y) \frac{\partial y}{\partial p} dt - \lambda^T \frac{\partial y}{\partial p}|_0^T + \int_0^T \lambda^T f_p dt
So now, in the spirit of making the right hand side easy to compute, let’s just pick \lambda such that:
\frac{\partial \lambda^T}{\partial t} = -\lambda^T f_y - g_y
So now we have an ODE for lambda! We still need an initial condition for it, so let’s keep looking.
\frac{d G}{d p} = - \lambda^T \frac{\partial y}{\partial p}|_0^T + \int_0^T \lambda^T f_p dt
Yuck! In that first term, we’d need to evaluate \frac{\partial y}{\partial p} at time T. These are the forward sensitivities. We don’t wanna do that, because we want to avoid solving the forward sensitivity problem.
If we take
\lambda(T) = 0
That gets rid of this term, and gives us an initial condition for our ODE!
We can now actually evaluate:
\frac{d G}{d p} = \lambda^T \frac{\partial y}{\partial p}|_{t = 0} + \int_0^T \lambda^T f_p dt
We have an equation for \lambda, so we can solve for that at any time t \in [0, T] (integrating back in time is A-Okay here, apparently). The first term involves values at t = 0, which we’re cool with: \frac{\partial y}{\partial p} at t = 0 is just the derivative of the initial condition with respect to p (zero if the initial condition doesn’t depend on p). The second term is just an integral of stuff we know.
We could actually write a second ODE:
\frac{\partial r}{\partial t} = -\lambda^T f_p
r(T) = 0
If you work out the math, this means r(0) = \int_0^T \lambda^T f_p dt.
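If you want to see the backwards integration and the quadrature trick in action, here’s a sketch with a made-up f and a made-up running cost g(t, y) = -0.5 y^T y. The Jacobians f_y and f_p are hand-coded here; CVODES does all of this bookkeeping for us in real life.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Made-up problem: damped oscillator f, running cost g(t, y) = -0.5 * y.y
def f(t, y, p):
    return np.array([y[1], -p[0] * y[0] - p[1] * y[1]])

def f_y(y, p):  # Jacobian of f with respect to y (hand-coded for the sketch)
    return np.array([[0.0, 1.0], [-p[0], -p[1]]])

def f_p(y, p):  # Jacobian of f with respect to p
    return np.array([[0.0, 0.0], [-y[0], -y[1]]])

def g_y(t, y):  # gradient of the integrand g
    return -y

p = np.array([2.0, 0.5])
y0 = np.array([1.0, 0.0])
T = 10.0

# Forward solve with a dense interpolant so the backward pass can ask for y(t)
fwd = solve_ivp(f, (0.0, T), y0, args=(p,), dense_output=True,
                rtol=1e-10, atol=1e-10)

def backward_rhs(t, state):
    lam = state[:2]
    y = fwd.sol(t)
    dlam = -f_y(y, p).T @ lam - g_y(t, y)  # the lambda ODE
    dr = -f_p(y, p).T @ lam                # the r quadrature
    return np.concatenate([dlam, dr])

# lambda(T) = 0 and r(T) = 0; integrate from T back to 0
bwd = solve_ivp(backward_rhs, (T, 0.0), np.zeros(4), rtol=1e-10, atol=1e-10)
lam0 = bwd.y[:2, -1]  # lambda(0), which pairs with dy/dp at t = 0
r0 = bwd.y[2:, -1]    # r(0) = the integral of lambda^T f_p from 0 to T
```

lam0 and r0 are exactly the two pieces of the \frac{d G}{d p} formula above.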
Now we can evaluate \frac{d G}{d p}! That’s exciting, but we really wanted \frac{d}{d T}\frac{d G}{d p}, so let’s keep going.
\frac{d}{d T}\frac{d G}{d p} = \frac{d}{d T}\left(\lambda^T \frac{\partial y}{\partial p}|_{t = 0} + \int_0^T \lambda^T f_p dt\right)
The trick here is that the initial conditions of \lambda are a function of T (we say \lambda(T) = 0), so \frac{\partial \lambda}{\partial T} isn’t immediately zero.
We use Leibniz’s rule again:
\frac{d}{d T}\frac{d G}{d p} = \frac{d \lambda^T}{d T} \frac{\partial y}{\partial p}|_{t = 0} + \lambda^T f_p |_{t = T} + \int_0^T \frac{\partial \lambda^T}{\partial T} f_p dt
Again, noting \lambda(T) = 0 kills the middle term:
\frac{d g}{d p} = \frac{d \lambda^T}{d T} \frac{\partial y}{\partial p}|_{t = 0} + \int_0^T \frac{\partial \lambda^T}{\partial T} f_p dt
The trick now is to use the substitution u = \frac{\partial \lambda}{\partial T}, noting that full and partial derivatives of \lambda with respect to T are the same here (t and T are independent, and T only enters \lambda through the boundary condition).
Differentiating the \lambda ODE with respect to T gives us an ODE for u (the g_y forcing term has no T dependence, so it drops out):
\frac{\partial u^T}{\partial t} = -u^T f_y
and handling the integral the same way we handled r above, the equation above becomes:
\frac{d g}{d p} = u^T \frac{\partial y}{\partial p}|_{t = 0} + \int_0^T u^T f_p dt
And we actually already have initial conditions for u!
Remember u = \frac{\partial \lambda}{\partial T}
Previously we had:
\frac{\partial \lambda^T}{\partial t} = -\lambda^T f_y - g_y
Evaluating this at t = T (where \lambda = 0) we have:
\frac{\partial \lambda^T}{\partial t}|_{t = T} = -g_y
Also, \lambda(T) = 0 holds for every choice of T, so differentiating it with respect to T gives u^T(T) + \frac{\partial \lambda^T}{\partial t}|_{t = T} = 0. So:
u^T(T) = g_y
Now if we go waaaaaay back, remember that g_y (a.k.a. \frac{d g}{d y}, a.k.a. \frac{d L}{d y}) are the adjoints of the outputs of the ODE solver! These will be supplied to the ODE solver by a standard reverse mode autodiff pass! Please ignore that I’ve apparently screwed up partials vs. full derivatives again.
So putting it all together, what we need to do is:
- Integrate the forward ODE from t = 0 to t = T (so we know what y is)
- Integrate the adjoint ODE:
\frac{\partial u^T}{\partial t} = -u^T f_y
u^T(T) = g_y
from t = T to t = 0, accumulating the quadrature \int_0^T u^T f_p dt along the way (same trick as with r)
- Evaluate
\frac{d g}{d p} = u^T \frac{\partial y}{\partial p}|_{t = 0} + \int_0^T u^T f_p dt
And we have the sensitivities to pass up the chain! Note we didn’t use the \lambda and r ODEs in this actual final evaluation.
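To convince myself the signs are right, here’s a runnable sketch of the whole recipe on a made-up damped oscillator, with hand-coded Jacobians f_y and f_p standing in for reverse mode evaluations and scipy’s solve_ivp standing in for CVODES, checked against finite differences at the end:

```python
import numpy as np
from scipy.integrate import solve_ivp

# Made-up test problem: a damped oscillator with two parameters
def f(t, y, p):
    return np.array([y[1], -p[0] * y[0] - p[1] * y[1]])

def f_y(y, p):  # Jacobian of f w.r.t. y (stand-in for reverse mode)
    return np.array([[0.0, 1.0], [-p[0], -p[1]]])

def f_p(y, p):  # Jacobian of f w.r.t. p
    return np.array([[0.0, 0.0], [-y[0], -y[1]]])

def log_density(yT):       # only a function of the final state
    return -0.5 * np.sum(yT ** 2)

def log_density_grad(yT):  # g_y, what the reverse pass would hand us
    return -yT

y0 = np.array([1.0, 0.0])
T = 10.0

def adjoint_dLdp(p):
    # 1. Forward ODE from t = 0 to T, keeping a dense interpolant for y(t)
    fwd = solve_ivp(f, (0.0, T), y0, args=(p,), dense_output=True,
                    rtol=1e-10, atol=1e-10)

    # 2. Adjoint ODE for u, plus the quadrature q for int_0^T u^T f_p dt,
    #    integrated from t = T back to t = 0
    def rhs(t, state):
        u = state[:2]
        y = fwd.sol(t)
        return np.concatenate([-f_y(y, p).T @ u, -f_p(y, p).T @ u])

    init = np.concatenate([log_density_grad(fwd.sol(T)), np.zeros(2)])
    bwd = solve_ivp(rhs, (T, 0.0), init, rtol=1e-10, atol=1e-10)
    q0 = bwd.y[2:, -1]

    # 3. dL/dp = u^T dy/dp at t = 0 plus the quadrature. y0 doesn't depend
    #    on p here, so the first term vanishes and q0 is the whole answer.
    return q0

def fd_dLdp(p, eps=1e-6):  # brute-force finite differences for comparison
    grad = np.zeros_like(p)
    for i in range(len(p)):
        dp = np.zeros_like(p)
        dp[i] = eps
        yp = solve_ivp(f, (0.0, T), y0, args=(p + dp,), rtol=1e-12, atol=1e-12)
        ym = solve_ivp(f, (0.0, T), y0, args=(p - dp,), rtol=1e-12, atol=1e-12)
        grad[i] = (log_density(yp.y[:, -1]) - log_density(ym.y[:, -1])) / (2 * eps)
    return grad

p = np.array([2.0, 0.5])
print(adjoint_dLdp(p))  # these two should agree to several digits
print(fd_dLdp(p))
```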
What is really cool about this compared to forward mode sensitivity is:
- We do not get the (number of states) * (number of parameters) blowup in the number of equations to solve that forward sensitivities have
- The terms -u^T f_y and u^T f_p that we need on the right hand side of the adjoint system are vector-Jacobian products, so each requires only one reverse mode evaluation (see the little sketch after this list)
- Our forward ODE solve can just use plain double evaluations of the right hand side (no autodiff types needed there)
- You can hook forward and adjoint sensitivity analysis together to get higher order derivatives (CVODES does this too). Not really necessary for us right now
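About that reverse mode bullet: here’s the vector-Jacobian product shape of it, using JAX purely as an illustration (one vjp call gives u^T f_y and u^T f_p together, without ever forming either Jacobian):

```python
import jax
import jax.numpy as jnp

def f(y, p):  # the same made-up oscillator, without the t argument
    return jnp.array([y[1], -p[0] * y[0] - p[1] * y[1]])

y = jnp.array([1.0, 0.0])
p = jnp.array([2.0, 0.5])
u = jnp.array([0.3, -0.7])  # some adjoint state

# One reverse mode sweep gives u^T f_y and u^T f_p together
_, f_vjp = jax.vjp(f, y, p)
uT_fy, uT_fp = f_vjp(u)
```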
And I told you this derivation only works if the log density uses the ODE output at the final time. To extend it to log densities that use the output at many times, think of the multiple output times as a bunch of tiny, single-output-time ODE problems chained together: sweep the adjoint backwards from one output time to the next, adding in that time point’s adjoint as you pass it. We’d also need a set of equations that give us the sensitivities \frac{d g}{d y}. Turns out this isn’t too bad (check http://www.mcs.anl.gov/~hongzh/papers/SISC_FATODE_final.pdf for the solutions; the equations are probably labeled differently, so pay attention more to the inputs and outputs).
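Here’s a sketch of that restart idea (the oscillator f, the output times ts, and the per-time adjoints dLdy are all made up; in real life the dLdy values come from the reverse pass over the rest of the log density): integrate the adjoint system backwards segment by segment, and bump u by \frac{d L}{d y(t_i)} every time we pass an output point.

```python
import numpy as np
from scipy.integrate import solve_ivp

def f(t, y, p):
    return np.array([y[1], -p[0] * y[0] - p[1] * y[1]])

def f_y(y, p):  # hand-coded Jacobians again, for the sketch
    return np.array([[0.0, 1.0], [-p[0], -p[1]]])

def f_p(y, p):
    return np.array([[0.0, 0.0], [-y[0], -y[1]]])

p = np.array([2.0, 0.5])
y0 = np.array([1.0, 0.0])
ts = np.array([2.5, 5.0, 10.0])  # output times used by the log density

fwd = solve_ivp(f, (0.0, ts[-1]), y0, args=(p,), dense_output=True,
                rtol=1e-10, atol=1e-10)

# dL/dy(t_i): the adjoints for each output, supplied by the reverse pass.
# Made up here as if L = -0.5 * sum_i ||y(t_i)||^2.
dLdy = [-fwd.sol(t) for t in ts]

def rhs(t, state):
    u = state[:2]
    y = fwd.sol(t)
    return np.concatenate([-f_y(y, p).T @ u, -f_p(y, p).T @ u])

# Sweep backwards, restarting at each output time and injecting its adjoint
state = np.zeros(4)
stops = np.concatenate([[0.0], ts])  # 0, t_1, ..., t_N
for i in range(len(ts) - 1, -1, -1):
    state[:2] += dLdy[i]             # u picks up dL/dy(t_i)
    seg = solve_ivp(rhs, (stops[i + 1], stops[i]), state,
                    rtol=1e-10, atol=1e-10)
    state = seg.y[:, -1]

u0, dLdp = state[:2], state[2:]
# dLdp is dL/dp (y0 doesn't depend on p here); u0 is dL/dy0 for free
```

As a bonus, u at t = 0 is the adjoint of the initial condition, which we’d also need to pass up the chain if y_0 depends on anything upstream.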