On adjoint sensitivity of ODE/PDE

Regarding adjoint sensitivity of ODE discussed in today’s meeting.

Like @Bob_Carpenter pointed out, the concept is just like forward AD vs reverse AD. The gist of adjoint sensitivity is that given a function u constrained by

F(u, t;\theta)=0

we are interested in some functional of u

J(u; \theta) = \int g(u(\theta))

and its sensitivity regarding \theta

dJ/d\theta.

For example, u could be the trajectory of a particle and J the energy dissipated along the way, from time t_0 to t_1. One way to find the adjoint is from the optimal control perspective, where J and the constraint in the constrained optimization problem

\min_{\theta} J(u;\theta)\\ u \text{ solves constraint } F(u,t;\theta) = 0

are used to construct Lagrangian

L=J + \lambda^TF

Then the adjoint is just the Lagrange multiplier \lambda that satisfies

F_u^T\lambda = -J_u

After solving for \lambda, one can calculate the sensitivity

dJ/d\theta = \lambda^T F_{\theta}

From the last equation, we can interpret adjoint as a measure of the sensitivity of the constraint, i.e., how much F_{\theta} contributes to the total change of J.
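To make the recipe concrete, here is a minimal sketch on a made-up linear constraint F(u, \theta) = A u - B \theta = 0 with J(u) = c^T u, so that F_u = A, F_\theta = -B and J_u = c (J has no explicit \theta dependence here, which is the case the formula above covers). The matrices, the helper names and the finite-difference check are all illustrative assumptions, not anything from Stan or CVODES.

```python
# Toy check of the adjoint recipe on F(u, theta) = A u - B theta = 0, J(u) = c . u.
# All numbers below are made-up illustration values.

def solve2(M, b):
    """Solve a 2x2 linear system M x = b by Cramer's rule."""
    det = M[0][0] * M[1][1] - M[0][1] * M[1][0]
    return [(b[0] * M[1][1] - M[0][1] * b[1]) / det,
            (M[0][0] * b[1] - b[0] * M[1][0]) / det]

def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def transpose(M):
    return [list(col) for col in zip(*M)]

A = [[4.0, 1.0], [2.0, 3.0]]               # F_u
B = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]    # F_theta = -B
c = [1.0, 5.0]                             # J_u

def J(theta):
    u = solve2(A, matvec(B, theta))        # primal solve: A u = B theta
    return sum(ci * ui for ci, ui in zip(c, u))

# One adjoint solve, whatever len(theta) is: F_u^T lam = -J_u
lam = solve2(transpose(A), [-ci for ci in c])
# dJ/dtheta = lam^T F_theta, with F_theta = -B
grad_adjoint = [-sum(lam[i] * B[i][k] for i in range(2)) for k in range(3)]

# Finite-difference reference: one extra primal solve per parameter
theta0, eps = [0.3, -1.2, 0.7], 1e-6
grad_fd = []
for k in range(3):
    tp = list(theta0)
    tp[k] += eps
    grad_fd.append((J(tp) - J(theta0)) / eps)
```

The adjoint route needs one extra linear solve in total, while the finite-difference loop needs one per entry of \theta, which is the scaling argument in miniature.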

The advantage of using the adjoint is that its solution is independent of the size of the parameter \theta, so if \dim(\theta) \gg \dim(J), we gain by solving the adjoint system only once.

In the context of ODEs, the constraint F would be the actual ODE to be solved, only this time \lambda would be time dependent, just like u, and we need to solve both the prime and the adjoint ODEs. This is actually easier than it sounds, because the Jacobian of the adjoint system, F_u^T, is the transpose of the Jacobian of the prime system, so the two ODEs share the same stability properties, and in general we can use the same scheme to treat them.

CVODES/IDAS are able to solve both the forward prime problem and the backward adjoint problem. Where Stan’s AD comes in is when we need the Jacobians F_u and J_u for the adjoint ODE. In particular, we don’t explicitly need F_u but only F_u^T\lambda, the directional derivative in the \lambda direction, which can be computed directly with AD. As a matter of fact, CVODES/IDAS provide an AD interface for this purpose.

For PDEs the idea is the same, except this time the evaluation of F_u^T\lambda is much more complex. Unlike in the ODE case, F is not naturally in R^n, with n the size of the ODE, but in a certain function space. A lot of effort goes into discretizing F into a finite-dimensional version F_h before any of the AD magic can take place. This is why code like ADIFOR takes another approach to AD, by transforming the entire process of F\rightarrow F_h into an AD version, so we can go ahead and calculate F_{hu}^T\lambda. The bottom line is that for PDEs there is no easy way to get F_u out of the box by using Stan Math on a certain functor, and we have to rely on external PDE solvers for that. This is the logic behind forward_pde. The function expects a user-supplied function calculating J(u(\theta)) and dJ/d\theta, assuming u is a solution from external ODE/PDE solvers. Another related item is adj_jac_apply by @Bob_Carpenter and @bbbales2, where a user-supplied F_u can be incorporated into the AD chain easily. I’m working on using adj_jac_apply as the backend of forward_pde.


Thanks Yi! I’m excited about this, as I really think it can speed up our ODE fitting. I’m happy to help out wherever I can.

I’m going to read more over the weekend then maybe look at the Stan code with @syclik next week. Do you have any good resources or papers you’d suggest?

Also here’s the old thread I referenced in the meeting.

Maybe @bbbales2 can talk about what was in the way in his last try. To me the implementation is not much different from forward sensitivity for ODE/DAEs. CVODES manual gives you a good start. Could you talk about the exact model and target function? I wonder what form your target function J is.

I’m gonna go ahead and @wds15 to bring him into this thread cause he’ll want to be involved in this too. This is a follow up from the discussion in the meeting yesterday.

That itself was a follow up from me offhandedly mentioning to @Bob_Carpenter that the ODE integrators weren’t working well for me in a certain case.

My goal in talking about this stuff was mostly to just highlight what I see as opportunities for someone to do things better. I’m not speaking with authority here so if you see flaws point em’ out :P.

An exact model that I am currently working with a buddy on is here (actually I omitted a couple sections but the big stuff is there):

functions {
  real[] f(real t, real[] c, real[] theta, real[] x_r, int[] x_i) {
    real D = theta[1];
    real u0 = theta[2];
    real dx = x_r[1];
    real cinf = x_r[2];
    real Cc = x_r[3];
    int N = size(c);
    real u = u0 * (1 - c[1]);
    real dcdx0 = u * (c[1] - Cc) / D;
    real f[N];
    
    f[1] = D * ((c[2] - c[1]) / dx - dcdx0) / (0.5 * dx) - u * dcdx0;
    for (i in 2:(N - 1)) {
      f[i] = D * (c[i + 1] - 2 * c[i] + c[i - 1]) / (dx * dx) - u * (c[i + 1] - c[i - 1]) / (2 * dx);
    }
    f[N] = D * (cinf - 2 * c[N] + c[N - 1]) / (dx * dx) - u * (cinf - c[N - 1]) / (2 * dx);

    return f;
  }
}

data {
  int N;
  real x[N];
  real y[N];
  real cinf;
  real Cc;
  int M;
  real dt;
  real dx;
  int S;
}

parameters {
  real<lower = 0.0> D;
  real<lower = 0.0> u0;
  real<lower = 0.0> sigma;
}

transformed parameters {
  real theta[2] = { D, u0 };
  real ymean[S];
  {
    vector[S] y_ = to_vector(y0);
    for(i in 1:M) {
      y_ = y_ + dt * to_vector(f(0.0, to_array_1d(y_), theta, x_r, x_i));
    }
    ymean = to_array_1d(y_);
  }
  //real yhat[N] = integrate_ode_rk45(f, y0, 0.0, { tf }, theta, x_r, x_i, 1e-8, 1e-8, 2000)[1,];
}

model {
  D ~ normal(26.0, 1.0);
  //u0 ~ normal(0.1, 0.1);
  
  y ~ normal(ymean, sigma);
}

I can’t share data for it, but it’s a 1D diffusion advection, method of lines discretization, 2 parameters, and about 50 states. It’s pretty unglamorous please don’t judge me :P.

If you look at that model you’ll notice we’re just using forward Euler to do the integration. Integrating using the rk45 or the bdf integrators in Stan with all sorts of tolerances is super slow (aka I wasn’t patient enough to let it finish, at least 10x slower). Forward Euler takes about a minute.

First of all, lemme get some notation. The forward sensitivity equations we’re solving are:

y' = f
s' = f_y s + f_p

where

s = \frac{\partial y}{\partial p}

are our sensitivities. We also have initial conditions for this.

So, just to be clear, for each parameter we want the sensitivity with respect to, we add N equations (where N is the number of states in our ODE).
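As a tiny worked instance of those forward sensitivity equations, here is a sketch for the scalar ODE y' = -p y with y(0) = 1, so f_y = -p and f_p = -y. The test ODE, the hand-rolled RK4 stepper and the function names are my own illustrative choices (not Stan or CVODES code); the analytic answer is s(T) = -T e^{-pT}.

```python
import math

def rk4_step(rhs, z, h):
    """One classical Runge-Kutta step for z' = rhs(z)."""
    k1 = rhs(z)
    k2 = rhs([zi + 0.5 * h * ki for zi, ki in zip(z, k1)])
    k3 = rhs([zi + 0.5 * h * ki for zi, ki in zip(z, k2)])
    k4 = rhs([zi + h * ki for zi, ki in zip(z, k3)])
    return [zi + h / 6.0 * (a + 2 * b + 2 * c + d)
            for zi, a, b, c, d in zip(z, k1, k2, k3, k4)]

def forward_sensitivity(p, T=2.0, steps=200):
    # Augmented state z = [y, s]: y' = f = -p*y and s' = f_y*s + f_p = -p*s - y.
    def rhs(z):
        y, s = z
        return [-p * y, -p * s - y]
    z = [1.0, 0.0]   # y(0) = 1 does not depend on p, so s(0) = 0
    h = T / steps
    for _ in range(steps):
        z = rk4_step(rhs, z, h)
    return z

y_T, s_T = forward_sensitivity(p=0.5)
print(y_T, s_T, -2.0 * math.exp(-1.0))   # s_T should be close to -T*exp(-p*T)
```

With one state and one parameter the augmented system has 2 equations; with N states and P parameters it would have N(1 + P).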

Anyway, one thing that bothered me about the current ODE solvers, and that makes them seem hopeless for that model, is that because there are 50 states in the ODE (N = 50) we gotta do 50 reverse mode autodiffs to build the Jacobian f_y. And in reality, because this is embedded in an implicit solver inside CVODES, for every timestep we might build that Jacobian many times.

Now something I got wrong in the meeting was saying that we could get rid of the reverse mode autodiffs and replace them with a forward mode autodiff. f_y s can be done with one forward mode autodiff, but f_p still requires either one forward mode autodiff per parameter, or one reverse mode autodiff per state of the ODE (and if you’re doing reverse mode autodiffs, you might as well compute the whole f_y as well). For my problem, obviously forward mode is faster because I do not have many parameters. Forward mode can also be faster cause you don’t have to mess with an autodiff stack, but it’s not a clear, knock it outta the ballpark win like I said in the meeting. (@syclik so I take back what I told you yesterday – I don’t have an explanation for that experiment you did with the tests)
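To illustrate the counting in that paragraph, here is a minimal sketch with a hand-written dual-number class standing in for forward-mode autodiff (a hypothetical toy, not Stan's fvar). It uses the Lorenz right-hand side as the example f: the product f_y s comes out of one forward pass seeded with s, while f_p takes one pass per parameter.

```python
class Dual:
    """Minimal forward-mode dual number: a value plus one tangent."""
    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan

    @staticmethod
    def lift(x):
        return x if isinstance(x, Dual) else Dual(x)

    def __add__(self, o):
        o = Dual.lift(o)
        return Dual(self.val + o.val, self.tan + o.tan)
    __radd__ = __add__

    def __sub__(self, o):
        o = Dual.lift(o)
        return Dual(self.val - o.val, self.tan - o.tan)

    def __rsub__(self, o):
        return Dual.lift(o) - self

    def __mul__(self, o):
        o = Dual.lift(o)
        return Dual(self.val * o.val, self.tan * o.val + self.val * o.tan)
    __rmul__ = __mul__

def lorenz(y, p):
    sigma, rho, beta = p
    return [sigma * (y[1] - y[0]),
            y[0] * (rho - y[2]) - y[1],
            y[0] * y[1] - beta * y[2]]

def jvp_states(y, p, s):
    """f_y @ s from ONE forward pass: seed the states with tangent s."""
    out = lorenz([Dual(yi, si) for yi, si in zip(y, s)], p)
    return [o.tan for o in out]

def grad_param(y, p, k):
    """Column k of f_p: one forward pass per parameter."""
    seeded = [Dual(pi, 1.0 if i == k else 0.0) for i, pi in enumerate(p)]
    return [Dual.lift(o).tan for o in lorenz(y, seeded)]

y, p = [1.0, 2.0, 3.0], [10.0, 28.0, 8.0 / 3.0]
print(jvp_states(y, p, [1.0, 0.5, -1.0]))
print([grad_param(y, p, k) for k in range(3)])
```

So the full forward sensitivity right-hand side costs roughly one pass per parameter, which is cheap when there are few parameters, as in the 2-parameter model above.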

Plot twist! The right hand side of the adjoint sensitivity equation requires us to compute (eq: 2.22: https://computation.llnl.gov/sites/default/files/public/cvs_guide.pdf):

f_y^T u = u^T f_y

which is just a single reverse mode pass (fill up our adjoints u in the direction we want to go, and chain away!)! So maybe this does get rid of the scaling problem on the sensitivity ODE rhs?

I really think it would be cool to tackle this rhs scaling issue. It would really make the desire to implement custom ODE rhs Jacobians a lot smaller (I think there are a few of us that want these).

Now, on to the other question, which I’m going to turn into a failed attempt to explain How Adjoint Sensitivity Works. I fail. Seriously, but this is my attempt at least to motivate it.

So something I want to clarify here that got bounced around in the last adjoint sensitivity thread is that the J function in our case is the target of our Stan model.

So the output of our ODE factors into our target something like:

target += normal_lpdf(ode_solution, sigma);

It could also be something like:

tmp = arbitrary_stan_code(ode_solution);
target += normal_lpdf(tmp, sigma);

This can be a drop in replacement for the current ODE integrators. I wanna emphasize that!

All the normal formulations of adjoint sensitivity analysis make J an integral of some loss function, but it can be a simpler thing too. It can just be a function of the outputs of the ODE, which is how we build our loss functions (log densities) in Stan.

I like the eqs. here: http://www.mcs.anl.gov/~hongzh/papers/SISC_FATODE_final.pdf , but I’m having trouble re-deriving the adjoint sensitivity stuff myself, but I’m gonna keep working on it.

The best way I’ve come up with to motivate myself about what is happening is to look at the forward and adjoint sensitivity problems for an already-discretized system.

So if we have the eqs:

u' = f(u)

And we discretize them with forward euler:

\frac{u_n - u_{n - 1}}{h} = f(u_{n - 1})

And rearrange:

u_n = u_{n - 1} + h f(u_{n - 1})

Where h is our timestep. So we can look at the derivatives of

\frac{\partial u_n}{\partial p} = \frac{\partial u_{n - 1}}{\partial p} + h f_{u}(u_{n - 1}) \frac{\partial u_{n - 1}}{\partial p}

(there would be an extra eq. if f was a function of p that I didn’t include). So this says the same thing as before for the continuous case. If I want to compute sensitivities, I add N equations for each parameter I want sensitivities with respect to (where N is the number of states in the ODE).

So if we were writing reverse mode autodiff in Stan, we’d take the values \frac{\partial u_n}{\partial p} at the end of the forward stepping and inject them into a precomputed_gradients call.

Now, what does the adjoint sensitivity problem look like for the discretized case?

Well, we’ll build it up just like we’d do for any other function we wanted to use reverse mode autodiff on in Stan.

Let’s pretend the update

u_n = u_{n - 1} + h f(u_{n - 1})

is actually just a function mapping old states to new states:

u_n = g(u_{n - 1})

If we want to do reverse mode autodiff on a function, we need the Jacobian

\frac{\partial u_n}{\partial u_{n - 1}} = g_u(u_{n - 1})

If step n was in fact the last step in our forward stepping of our discrete system, and we took those values and computed a log density (call this L), then when we run the reverse mode autodiff eventually we’ll have the values:

\frac{\partial L}{\partial u_n}

And if we want to propagate this back a timestep?

\frac{\partial L}{\partial u_{n - 1}} = (\frac{\partial L}{\partial u_n})^T g_u(u_{n - 1})

And again?

\frac{\partial L}{\partial u_{n - 2}} = (\frac{\partial L}{\partial u_{n - 1}})^T g_u(u_{n - 2})

And we can keep doing that all the way back through, the operators g_u(u_{n - i}) being the chains of the i^\text{th} forward step.

The takeaway is we solve the sensitivity problem without adding on a bunch of equations by taking advantage of the fact that at the end of the day we only need the gradients of one thing (the log density) with respect to a bunch of other things.

In my bad understanding, the continuous forward sensitivity problem corresponds to that forward discrete problem, and then there’s this continuous adjoint sensitivity problem that corresponds to the reverse mode discrete problem.
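Here is a small sketch of that reverse sweep through a forward Euler solve. Unlike the equations above, this toy f does depend on the parameters, so each step also contributes h f_p to the parameter adjoints (that is the "extra eq." mentioned earlier). The two-state system, its hand-coded Jacobians and the loss are made up for illustration.

```python
def f(u, p):
    return [p[0] * u[1], -p[1] * u[0] - u[0] * u[1]]

def f_u(u, p):   # hand-coded Jacobian with respect to the states
    return [[0.0, p[0]], [-p[1] - u[1], -u[0]]]

def f_p(u, p):   # hand-coded Jacobian with respect to the parameters
    return [[u[1], 0.0], [0.0, -u[0]]]

def euler_forward(p, u0=(1.0, 0.5), h=0.01, n=100):
    """Forward Euler, keeping every state for the reverse sweep."""
    traj = [list(u0)]
    for _ in range(n):
        u = traj[-1]
        du = f(u, p)
        traj.append([u[i] + h * du[i] for i in range(2)])
    return traj

def loss(u):     # stand-in for the log density L(u_n)
    return u[0] ** 2 + u[1]

def grad_reverse(p, h=0.01, n=100):
    traj = euler_forward(p, h=h, n=n)
    uN = traj[-1]
    abar = [2.0 * uN[0], 1.0]           # dL/du_n at the last step
    pbar = [0.0, 0.0]
    for k in range(n - 1, -1, -1):      # walk the steps backwards
        u = traj[k]
        Ju, Jp = f_u(u, p), f_p(u, p)
        # parameter adjoints: abar^T (h f_p)
        for j in range(2):
            pbar[j] += h * sum(abar[i] * Jp[i][j] for i in range(2))
        # state adjoints: abar^T g_u = abar^T (I + h f_u)
        abar = [abar[j] + h * sum(abar[i] * Ju[i][j] for i in range(2))
                for j in range(2)]
    return pbar

print(grad_reverse([1.0, 2.0]))
```

The backward loop carries only one adjoint vector the size of the state, no matter how many parameters there are; the price is storing (or recomputing) the forward trajectory.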

Alright this post is big enough and probably contains enough falsehoods already, but lemme list the two downsides to adjoint sensitivity analysis:

  1. You need to integrate an ODE every time you call chain (in forward mode you have all the sensitivities cached)

  2. The reverse integration is a little tricky, and for every point in time you use the state of the ODE to increment your log density, you have to restart your integrator. This can be a burden.

And to answer your question Yi, another thing that was in my way is I didn’t understand how the chainable_alloc varis worked (so there’s lots of memory issues I describe in that old thread). Now I do! Hoooray!

Adjoint won’t help you with a model of 2 parameters for a MOL discretized PDE, if like you said your target is

target += normal_lpdf(ode_solution, sigma);

Because in this case your J is essentially the vector of your ODE state, so \dim(J)>>\dim(\theta). Forward sensitivity should be better.

Also, note that if you are using Stan’s CVODES lib to fit this MOL model, the performance is suboptimal because current ODE solver uses a dense solver, while the MOL system is sparse.

That’s not true though.

J is the log density.

Here’s the simple test code I was running before when I rigged this up: ODE example · GitHub

You’re right, I misread it as a multivariate normal.

@bbbales2, @arya, are you doing central difference with your advection term in the above code? Have you verified the solution? Central difference is not a good idea for advection; you need upwind.

@yizhang Loool, thanks for that. I just did central cause I didn’t know what to do. I remembered from taking PDE classes one or the other could go unstable depending on the direction of my flow. I will send a note to my buddy.

The posterior predictives look like the data and the parameters seem physical, which is encouraging. But as far as actual verification, I don’t trust it as much as I want to. I wouldn’t build things based on the parameters we get.

Before sending it out to inference, you’ll want to verify the numerical solver first. In this case there’s almost certainly some bias added because of the central difference scheme, even though the outcome may “seem physical”.
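A quick way to see the concern is a sketch of pure advection (no diffusion term at all) on a periodic grid, stepped with forward Euler, once with central and once with upwind differences. The grid size, the Courant number nu = a dt / dx = 0.5 and the square-pulse initial condition are arbitrary illustration choices, and this is not the model from the thread.

```python
def step_central(u, nu):
    """Forward Euler + central difference for u_t + a u_x = 0 (periodic)."""
    n = len(u)
    return [u[i] - 0.5 * nu * (u[(i + 1) % n] - u[i - 1]) for i in range(n)]

def step_upwind(u, nu):
    """Forward Euler + first-order upwind difference, for a > 0 (periodic)."""
    return [u[i] - nu * (u[i] - u[i - 1]) for i in range(len(u))]

def run(step, nu=0.5, n=100, steps=200):
    u = [1.0 if 40 <= i < 60 else 0.0 for i in range(n)]   # square pulse
    for _ in range(steps):
        u = step(u, nu)
    return u

central = run(step_central)
upwind = run(step_upwind)
max_central = max(abs(x) for x in central)
max_upwind = max(abs(x) for x in upwind)
```

In this setup max_upwind stays at or below 1 (the pulse just gets smeared), while max_central blows up. A model with a real diffusion term, like the one above, can tolerate central differences when diffusion dominates at the grid scale, so treat this as a reason to check the solver rather than proof that it is wrong.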


Hi Ben!

I have been staring at the CVODES manual quite a bit to absorb the adjoint stuff. What I honestly still don’t get is how to deal with time… anyway (you solved that, so it’s gotta work somehow).

But you are absolutely right in that an adjoint approach would be much better. The problem is that we would need to define a cost function, which is not easily doable until we have lambdas in Stan. Once lambdas are in Stan we should definitely code up an adjoint ODE solving thing… it will be much much faster whenever there are many states and/or many parameters but we are actually only interested in a few derivatives of the cost function, which grabs some states from the ODEs and the user’s data.

In the meantime you could try to resurrect my code which I did last year: https://github.com/stan-dev/rstan/tree/rstanode-proto/rstanode/integrate_parallel

What I have done there is to code up automatically generated analytic Jacobians. This will speed up your problem massively and you are then implicitly taking advantage of sparsity to some extent (since the Jacobian is initialized to zero and you need to put in the non-zeros). The code used to work with stan-math from a year ago (so it relies on ode_system). It shouldn’t be too hard to get it going again, I think/hope.

So in short I think we really want at some point adjoint ODE stuff whenever Stan is ready for it in terms of the language. I did not expect that it would ever be possible, but since lambda’s and closures are in the air we may actually get there.

The other interesting avenue is exploring sparsity, but that’s not clear how that would land in Stan. Maybe with the Stan 4 refactor we can easily introduce a suitable sparse type… who knows.

Cheers,
Sebastian

@wds15 Looool that is some cool code @ the integrate parallel stuff. I will put that in my back pocket.

But the cost function as a limitation thing is incorrect. This sits seamlessly in the autodiff stack. This is my code from before where I did this: https://gist.github.com/bbbales2/5049c99674e34c3796738ff11f3890ca#file-test2-cpp-L73

The cost function is the log density.

My aim back then was to make ODE integration a double only operation. Doing so allowed me to openMP parallelize things. This functionality of the code is now obsolete.

I see… but how does this help us? We need to feed the cost function into CVODES. Sounds like I need to read your code closely, since you are right in that the cost function is implicitly defined by the AD tree and the way we sweep over it. Nice idea.

I need to work on my understanding of these adjoint equations to say the least haha. But I think the problem with their presentation in most contexts is that they are over-complicated for our use case.

The loss function is often written something like (G being our loss) G(y_T, p) = \int_0^T g(y, t, p) dt because I think those sorts of loss functions are used where this stuff is originally formulated.

However, you can also do the derivation with a simple function of the final state of the system, G(y_T, p) = f(y_T, p). (this is noted in the article https://epubs.siam.org/doi/abs/10.1137/S1064827501380630?journalCode=sjoce3)

So what’s weird about that is that you only get to use the last output in your f. Think about an ODE solve where you produce N outputs as N little ODE solves, each with 1 output. So you accumulate adjoints at the output of one, use the adjoint equation to propagate it back, accumulate some more adjoints, propagate back, etc.

This highlights the real drawback of adjoint sensitivity analysis in that it breaks the reverse solve into N little integrations. Restarting the adjoint ODE solve can be a performance problem (as noted here: https://arxiv.org/abs/1606.04406 ).
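The accumulate-then-propagate pattern is easy to see on a discretized toy problem. This sketch uses forward Euler on u' = -p u with a squared-error loss at four output steps; the observation values and step indices are made up. In the backward sweep, the adjoint of the loss is injected each time an output step is crossed, which is where a real adaptive integrator would have to stop and restart.

```python
def solve(p, u0=1.0, h=0.01, n=100):
    us = [u0]
    for _ in range(n):
        us.append(us[-1] + h * (-p * us[-1]))   # forward Euler for u' = -p*u
    return us

OUT = {25: 0.8, 50: 0.6, 75: 0.5, 100: 0.4}     # step index -> made-up observation

def loss(p):
    us = solve(p)
    return sum(0.5 * (us[k] - d) ** 2 for k, d in OUT.items())

def grad_adjoint(p, h=0.01, n=100):
    us = solve(p, h=h, n=n)
    abar, pbar = 0.0, 0.0
    for k in range(n, 0, -1):
        if k in OUT:                       # inject dL/du_k at an output time
            abar += us[k] - OUT[k]
        pbar += abar * h * (-us[k - 1])    # step k-1 -> k: d u_k / d p = -h * u_{k-1}
        abar *= 1.0 + h * (-p)             # d u_k / d u_{k-1} = 1 - h * p
    return pbar

print(grad_adjoint(0.7))
```

Between output steps the sweep is one uninterrupted backward pass, so the cost of the restarts grows with the number of observation times rather than with the number of parameters.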

I gotta figure out a better way to derive this for myself so my explanations aren’t so hand-wavy.

Alright, I had some time to look at the math. We don’t actually have to implement any of the integration stuff. CVODES will do that for us. I just want to be able to stop being super uncertain when I say what is what. There are probably mistakes here. Props for pointing them out.

Things are formulated as Lagrangians, and I’m not sure why. We don’t even set the derivatives equal to zero, so I’m honestly flummoxed about what it does for us. I think a lot of times you might use a Lagrangian like this to prove that solutions actually exist, but I dunno, whatever. Just keep that in mind.

So with boundless confidence, let’s get started!

First we’re going to pick something we want to get sensitivities with respect to. The way all the derivations work is we say we have a loss function:

G(y, p) = \int_0^T g(t, y, p) dt

The derivations say we want to minimize this loss function, and show us how to compute the sensitivities \frac{d G}{d p}.

First, maybe we’re implicitly minimizing something when we’re doing this, but it’s not the goal of our sensitivity analysis. So just ignore that for now. I assume this has more to do with existence of solution than anything else.

To figure out how the loss function (G) relates to anything, let’s think about what it is we want to compute. Let’s say in our forward pass we have:

p = \text{our parameters}
y = \text{ode solution, a function of p}
L = \text{log density, a function of y}

So in the reverse pass when we go to call the chain method of our ODE solve we present it with:

\frac{d L}{d y}, \text{the adjoints at the output of the ODE solver}

and we want to compute

\frac{d L}{d p}, \text{the adjoints at the input of the ODE solver}

So it stands to reason that if adjoint sensitivity analysis computes \frac{d G}{d p}, then we somehow want to relate L and G. We also want to make sure that this adjoint sensitivity analysis takes as input \frac{d L}{d y}, cause that’s what we have to initialize it.

For the first thing, G is an integral that is a function of the solution of our ODE. L is not. It is a function of our ODE at a finite set of points.

At this point, we make the simplifying assumption that our log density is a function of the final solution of the ODE. There will be a workaround for this but the simplest derivation allows the log density only to be a function of the final time. This is not a real limitation though.

So how do we do this?

\frac{d}{d T} \frac{d G}{d p} = \frac{d g(T, y(T), p)}{d p}

by (Leibniz’s rule)

If our log density is a function of the final state of the ODE, we can call it g!

So this derivation will work in two stages. First we’ll compute:

\frac{d G}{d p}

And then we’ll compute

\frac{d}{d T} \frac{d G}{d p}

And that’s all we need.

This derivation is taken in bits and pieces from:

  1. https://epubs.siam.org/doi/abs/10.1137/S1064827501380630?journalCode=sjoce3
  2. https://computation.llnl.gov/sites/default/files/public/cvs_guide.pdf
  3. http://www.mcs.anl.gov/~hongzh/papers/SISC_FATODE_final.pdf

We write down our Lagrangian (where this comes from, who knows, just take it):

I(y) = G(y) - \int_0^T \lambda^T (\frac{\partial y}{\partial t} - f(y, p)) dt

First thing to note, \frac{\partial y}{\partial t} - f(y, p) = 0, so this equation actually says:

I(y) = G(y), which might seem weird but it’s necessary.

With a couple simplifying assumptions:

  1. I removed the p dependence from our g (it won’t matter)
  2. Our ODE’s f isn’t a function of time

We take full derivatives with respect to p (noting I(y) = G(y)):

\frac{d I(y, p)}{d p} = \frac{d G}{d p} = \int_0^T g_y \frac{\partial y}{\partial p} dt - \frac{\partial \int_0^T \lambda^T (\frac{\partial y}{\partial t} - f(y, p)) dt}{\partial p}

Move the partial inside the integral on the right (Leibniz rule, and neither endpoint of the integral is a function of p):

\int_0^T \frac{\partial \lambda^T (\frac{\partial y}{\partial t} - f(y, p))}{\partial p} dt

Apply the chain rule

\int_0^T \frac{\partial \lambda^T}{\partial p} (\frac{\partial y}{\partial t} - f(y, p)) + \lambda^T \frac{\partial (\frac{\partial y}{\partial t} - f(y, p))}{\partial p} dt

Again, along solutions of our ODE, \frac{\partial y}{\partial t} - f(y, p) = 0 (this is the governing equation of our ODE), so this simplifies:

\int_0^T \lambda^T \frac{\partial (\frac{\partial y}{\partial t} - f(y, p))}{\partial p} dt

So we get:

\frac{d G}{d p} = \int_0^T g_y \frac{\partial y}{\partial p} dt - \int_0^T \lambda^T (\frac{\partial}{\partial t} \frac{\partial y}{\partial p} - f_y(y, p) \frac{\partial y}{\partial p} - f_p) dt

And this is cool! If we could evaluate the equation on the right hand side, then we’d know what \frac{d G}{d p} is (which is nearly what we want). The part we don’t know is \lambda. This is a vector function of time (same number of outputs as ODE). I think how this works is we can choose \lambda to be anything we want it to be. I’m really not sure. It seems like we could choose \lambda = 0 and everything would stop working, but whatever.

Let’s integrate by parts this term:

\int_0^T \lambda^T \frac{\partial}{\partial t} \frac{\partial y}{\partial p} dt

To make it:

\lambda^T \frac{\partial y}{\partial p}|_0^T - \int_0^T \frac{\partial \lambda^T}{\partial t} \frac{\partial y}{\partial p} dt

Combining everything together and matching terms that contain \frac{\partial y}{\partial p}

\frac{d G}{d p} = \int_0^T (g_y + \frac{\partial \lambda^T}{\partial t} + \lambda^T f_y) \frac{\partial y}{\partial p} dt - \lambda^T \frac{\partial y}{\partial p}|_0^T + \int_0^T \lambda^T f_p dt

So now in the spirit of making the right hand side easy to compute, let’s just pick \lambda such that:

\frac{\partial \lambda^T}{\partial t} = -\lambda^T f_y - g_y

So now we have an ODE for lambda! We still need an initial condition for it, so let’s keep looking.

\frac{d G}{d p} = - \lambda^T \frac{\partial y}{\partial p}|_0^T + \int_0^T \lambda^T f_p dt

Yuck! In that first term, we’d need to evaluate \frac{\partial y}{\partial p} at time T. These are the forward sensitivities. We don’t wanna do that, because we want to avoid solving the forward sensitivity problem.

If we take

\lambda(T) = 0

That gets rid of this term, and gives us an initial condition for our ODE!

We can now actually evaluate:

\frac{d G}{d p} = \lambda^T \frac{\partial y}{\partial p}|_{t = 0} + \int_0^T \lambda^T f_p dt

(The boundary term only survives at the lower limit t = 0, which is where the plus sign comes from.)

We have an equation for \lambda, so we can solve for that at any time t \in [0, T] (integrating back in time is A-Okay here, apparently). The thing on the left are values at t = 0, which we’re cool with. The thing on the right is just an integral of stuff we know.

We could actually write a second ODE:

\frac{\partial r}{\partial t} = -\lambda^T f_p
r(T) = 0

If you work out the math, this means r(0) = \int_0^T \lambda^T f_p dt.

Now we can evaluate \frac{d G}{d p}! That’s exciting, but we really wanted \frac{d}{d T}\frac{d G}{d p}, so let’s keep going.

\frac{d}{d T}\frac{d G}{d p} = \frac{d}{d T} \left( \lambda^T \frac{\partial y}{\partial p}|_{t = 0} + \int_0^T \lambda^T f_p dt \right)

The trick here is that the initial conditions of \lambda are a function of T (we say \lambda(T) = 0), so \frac{\partial \lambda}{\partial T} isn’t immediately zero.

We use Leibniz’s rule again:

\frac{d}{d T}\frac{d G}{d p} = \frac{\partial \lambda^T}{\partial T} \frac{\partial y}{\partial p}|_{t = 0} + \lambda^T f_p |_{t = T} + \int_0^T \frac{\partial \lambda^T}{\partial T} f_p dt

Again, noting \lambda(T) = 0

\frac{d g}{d p} = \frac{\partial \lambda^T}{\partial T} \frac{\partial y}{\partial p}|_{t = 0} + \int_0^T \frac{\partial \lambda^T}{\partial T} f_p dt

The trick now is to use the substitution u = \frac{\partial \lambda}{\partial T}, noting full and partial derivatives of \lambda with respect to T are the same (I think).

To get an ODE for u, differentiate the \lambda ODE with respect to T. Neither f_y nor g_y depends on T (they are evaluated along the forward solution), so the g_y term drops out:

\frac{\partial u^T}{\partial t} = -u^T f_y

For the integral we can do something similar to what we did with the equations for r above and say:

\frac{\partial q}{\partial t} = -u^T f_p
q(T) = 0

so that q(0) = \int_0^T u^T f_p dt. Which simplifies the equation above to:

\frac{d g}{d p} = u^T \frac{\partial y}{\partial p}|_{t = 0} + q(0)

And we actually already have initial conditions for u!

Remember u = \frac{\partial \lambda}{\partial T}. The condition \lambda(T) = 0 holds for every T, so differentiating it with respect to T gives:

\frac{\partial \lambda^T}{\partial t}|_{t = T} + \frac{\partial \lambda^T}{\partial T}|_{t = T} = 0

Previously we had:

\frac{\partial \lambda^T}{\partial t} = -\lambda^T f_y - g_y

Evaluating this at T (where \lambda = 0) we have:

\frac{\partial \lambda^T}{\partial t}|_{t = T} = -g_y

So

u^T(T) = g_y

Now if we go waaaaaay back, remember that \frac{d g}{d y} are adjoints of the outputs of the ODE solver! These will be supplied to the ODE solver by a standard reverse mode autodiff pass! Please ignore that I’ve apparently screwed up partials vs. full derivatives again.

So putting it all together, what we need to do is:

  1. Integrate the forward ODE from t = 0 to t = T (so we know what y is)

  2. Integrate the adjoint ODE and its quadrature:

\frac{\partial u^T}{\partial t} = -u^T f_y
u^T(T) = g_y
\frac{\partial q}{\partial t} = -u^T f_p
q(T) = 0

from t = T to t = 0

  3. Evaluate

\frac{d g}{d p} = u^T \frac{\partial y}{\partial p}|_{t = 0} + q(0)

And we have the sensitivities to pass up the chain! Note we didn’t use the \lambda and r ODEs in this actual final evaluation.
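As a sanity check on the three-step recipe, here is a sketch for y' = -p y, y(0) = 1, with g = y(T)^2, whose exact gradient is dg/dp = -2 T e^{-2pT}. It follows the final-time form in the CVODES guide cited in the list above: an adjoint with \partial u^T/\partial t = -u^T f_y and u^T(T) = g_y, plus a quadrature accumulating \int_0^T u^T f_p dt. The test ODE, the RK4 stepper and the choice to re-integrate y backwards alongside the adjoint (instead of checkpointing the forward solve) are my own simplifications.

```python
import math

def rk4(rhs, z, t0, t1, steps):
    """Classical RK4 from t0 to t1; a negative step integrates backwards."""
    h = (t1 - t0) / steps
    for _ in range(steps):
        k1 = rhs(z)
        k2 = rhs([a + 0.5 * h * b for a, b in zip(z, k1)])
        k3 = rhs([a + 0.5 * h * b for a, b in zip(z, k2)])
        k4 = rhs([a + h * b for a, b in zip(z, k3)])
        z = [a + h / 6.0 * (b + 2 * c + 2 * d + e)
             for a, b, c, d, e in zip(z, k1, k2, k3, k4)]
    return z

def adjoint_gradient(p, T=1.5, steps=300):
    # 1. forward solve of y' = f(y, p) = -p*y, doubles only
    (yT,) = rk4(lambda z: [-p * z[0]], [1.0], 0.0, T, steps)

    # 2. backward solve from T to 0 of [y, u, q]
    def rhs(z):
        y, u, q = z
        f_y, f_p = -p, -y
        return [-p * y, -u * f_y, -u * f_p]

    g_y = 2.0 * yT                       # g(y) = y^2 at the final time
    y0, u0, q0 = rk4(rhs, [yT, g_y, 0.0], T, 0.0, steps)

    # 3. assemble; dy/dp at t = 0 is zero because y(0) = 1 is fixed
    dy0_dp = 0.0
    return u0 * dy0_dp + q0

print(adjoint_gradient(0.8), -3.0 * math.exp(-2.4))
```

The backward system here has N + P = 2 equations (one adjoint state, one quadrature), and the only derivative information it needs per right-hand-side call is u^T f_y and u^T f_p.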

What is really cool about this compared to forward mode sensitivity is:

  1. We do not have the number of states * number of parameters blow up in the number of equations to solve
  2. The terms -u^T f_p that we need to evaluate on the right hand side of the adjoint ODE equations require only one reverse mode evaluation
  3. Our forward mode ODE can actually just use double evaluations of the right hand side.
  4. You can hook forward and adjoint sensitivity analysis together to get higher order derivatives (CVODES does this too). Not really necessary for us right now

And I told you that this derivation would only work if you only use the ODE output at the final time. To extend this to where you can use the output at many times in your log density, think of using the output at multiple different time values as doing a bunch of tiny, single time output ODEs. We’d also need to get a set of equations that give us the sensitivities \frac{d g}{d y}. Turns out this isn’t too bad (check http://www.mcs.anl.gov/~hongzh/papers/SISC_FATODE_final.pdf for the solutions – the equations are probably labeled differently – pay attention more to the inputs and outputs).


What a post! Cool. Just to comment on the above for now: This double only thing only holds for non-stiff ODEs. For the stiff ones, we need the Jacobian of the ODE functor wrt the states.

… now my second pass in reading…

For the Jacobian of the implicit timestep solve – this is another thing I didn’t know till recently, but you don’t need the actual Jacobian for this either. If you’re using an iterative solver in the Newton iterations for the implicit solve, you end up needing to compute Jacobian vector products, which you can get with finite differences (CVODES will do this automagically if you don’t give it the Jacobian) or fvars. It’s pretty cool. There are a lot of moving pieces there I didn’t describe well though.
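The finite-difference version of a Jacobian-vector product is short enough to sketch. This is a generic one-sided difference on the Lorenz right-hand side, written by hand for illustration; it is not a copy of what CVODES does internally, and the step size eps is an arbitrary choice.

```python
def lorenz(y, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    return [sigma * (y[1] - y[0]),
            y[0] * (rho - y[2]) - y[1],
            y[0] * y[1] - beta * y[2]]

def jvp_fd(f, y, v, eps=1e-7):
    """Approximate f_y(y) @ v with one extra right-hand-side evaluation."""
    f0 = f(y)
    f1 = f([yi + eps * vi for yi, vi in zip(y, v)])
    return [(b - a) / eps for a, b in zip(f0, f1)]

y, v = [1.0, 2.0, 3.0], [1.0, 0.5, -1.0]
approx = jvp_fd(lorenz, y, v)
# Exact product, from the analytic Jacobian of the Lorenz system at y
exact = [-5.0, 25.5, 2.0 + 0.5 + 8.0 / 3.0]
print(approx, exact)
```

An iterative linear solver only ever asks for products like this, so the N by N Jacobian never has to be built; the trade-off is that the result is approximate, where a forward-mode pass would give the same product to machine precision.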

If you end up coding something up to test things are working, use the FATODE eqs. Solid chance I screwed something up here hahaha.

Well that was anti-climactic. Two problems!

Adjoint ODE branch is here: https://github.com/bbbales2/math/tree/feature/adjoint-ode

For the Lorenz test problem (3 states, 3 parameters), forward sensitivity is about 20-30% faster than adjoint sensitivity. rk45 blows them both away cause it’s a non-stiff problem, but that and the harmonic oscillator were the easiest things to play with since they were already coded up.

So the forward sensitivity problem is going to have 12 equations, and adjoint will have 3 going forward and 6 backward. Not a huge difference, but I’d still have liked to see adjoint eke out the win because of the simplicity of evaluating the right hand side of the adjoint sensitivity equations with autodiff. For the harmonic oscillator, adjoint was about 50% slower or something (2 states, 1 parameter).

The problem seems to be though that the backwards integration is just harder than the forward integration. It takes around 2x the right hand side evaluations for the backwards integration as the forward sensitivity does. So even if the autodiff is more efficient, it’s just gotta do more work. (edit: each autodiff right hand side for the reverse mode should take 3x less work than for the fwd in this case)

There’s a lot going on in Sundials though. I tried to wiggle the common knobs, and these results seemed stable, but I really didn’t get a good sense of what is going on.

Now we could make this problem bigger or choose a different system and tip the scales in the favor of adjoint, but I’m not super keen on that. My original tests were with like 20 states and 20 parameters. In that case the forward sensitivity problem would have 420 equations and the adjoint one would have 20 forward and 40 in reverse. But I’d really like it to work better for tiny problems cause I think that’s more representative of what people work on (and One Size Fits All solutions are really nice).
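The equation counts quoted in these posts follow a simple pattern, sketched here with two hypothetical helper functions: forward sensitivity carries the states plus one sensitivity vector per parameter, while the adjoint approach carries the states forward, then one adjoint per state plus one quadrature per parameter backward.

```python
def forward_size(n_states, n_params):
    """Equations in the forward sensitivity system: N * (1 + P)."""
    return n_states * (1 + n_params)

def adjoint_size(n_states, n_params):
    """(forward, backward) equation counts for the adjoint approach: (N, N + P)."""
    return n_states, n_states + n_params

for name, n, p in [("harmonic oscillator", 2, 1), ("Lorenz", 3, 3), ("20 x 20", 20, 20)]:
    print(name, forward_size(n, p), adjoint_size(n, p))
```

This counts equations only; as noted above, the backward integration can need more right-hand-side evaluations per equation, so the counts alone do not predict run time.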

Second of all, there’s a memory problem in how adjoint sensitivity uses the autodiff. During the chain of its vari it needs to call nested autodiff. This can cause segfaults and such if the stack tries to reallocate itself (cause the nested autodiff is growing) while it is busy reading itself (for the outer autodiff).

Na… one size fits all does not really work here, I think. Though a key advantage of the adjoint stuff is that if from the N states we only need a single state to end up in the likelihood (which is often the case in PK/PD problems), then the adjoint is much more efficient - even if the system is small. At least this is how I understood it.

Oops… that’s a bummer. However, you are probably using the system way outside its originally intended use case. Would using forward mode to get the sensitivities of the ODE RHS solve this?