Adjoint task force for ODEs

I’d like to take a proper stab at implementing the adjoint method in Stan. This past thread presents a discussion on the topic. Two references: the CVODES user guide, and a write-up on the discrete ODE adjoint method that also reviews the continuous case.

From a discussion with @bbbales2, the main bottleneck is that in Stan we cannot do nested autodiff inside the reverse-mode pass, which we would need to do here. What would it take to enable this feature without pre-allocating an excessive amount of memory for the autodiff stack?

@bbbales2, @wds15, @betanalpha, @yizhang, @Bob_Carpenter

We really need this. Yes!

Just one thought: the AD tape is stored as a thread-local pointer… which you can change!

So while the reverse pass runs, you can swap out the currently active AD tape pointer for a new instance. Then you do your nested AD, collect the result, and swap the old pointer back in. It’s a bit dirty, but could work.
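To make the pointer-swap idea concrete, here is a minimal sketch. It does not use Stan Math's real types: `ToyTape`, `active_tape`, `ScopedTapeSwap`, and `nested_work` are all hypothetical names made up for illustration, and the "tape" is just a vector of doubles. The only point is the pattern: install a fresh tape through the thread-local pointer, do the nested work, and restore the old pointer when the scope ends.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for an AD tape; NOT Stan Math's actual type.
struct ToyTape {
  std::vector<double> adjoints;
};

// One active tape per thread, reached through a thread-local pointer.
thread_local ToyTape* active_tape = nullptr;

// RAII guard: installs a fresh tape on construction and restores the
// previously active tape on destruction (also if an exception is thrown).
class ScopedTapeSwap {
 public:
  ScopedTapeSwap() : saved_(active_tape) { active_tape = &nested_; }
  ~ScopedTapeSwap() { active_tape = saved_; }
  ScopedTapeSwap(const ScopedTapeSwap&) = delete;
  ScopedTapeSwap& operator=(const ScopedTapeSwap&) = delete;

 private:
  ToyTape* saved_;  // tape that was active before the swap
  ToyTape nested_;  // fresh tape used for the nested computation
};

// Pretend nested computation: records on whatever tape is active and
// returns a plain double, so nothing leaks onto the outer tape.
double nested_work() {
  ScopedTapeSwap guard;
  assert(active_tape != nullptr);
  active_tape->adjoints.push_back(1.0);
  active_tape->adjoints.push_back(2.0);
  return active_tape->adjoints[0] + active_tape->adjoints[1];
}
```

Calling `nested_work()` while some outer tape is active leaves that outer tape untouched, and the pointer is back to the outer tape once the call returns. Whether a real AD stack can be swapped this cheaply depends on what else hangs off the pointer, which this toy ignores.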

I am happy to help… though one option is really to integrate with the AMICI project. Their devs are very friendly and, as I understood it, they even have second order ready. Moreover, they generate code for all the derivatives symbolically and already exploit sparsity… my impression was that lots of work has been done there.

I would like to see this happen as well!

The adjoint solve stuff for RK45 would be trickier than CVODES. CVODES is just switching some flags. RK45 would require some thought.

If we could get this sorted out for the next release (which is pretty do-able), then we could get the adjoint stuff for the release 6 months from now.

I don’t think we need RK45 with adjoint. Let’s concentrate on CVODES and turn on that switch. That’s enough work already.

Having this in Stan 6 months from now would be really cool. Maybe even with parameter pack support?

What worries me is that we would have to pass two functions to CVODES… the ODE right-hand side and the cost function. Though as I recall from what @bbbales2 said, there was some way that would not require passing in the cost function.

I want to get that done in the next release. Code is basically there (for both math and @rok_cesnovar has a prototype for stanc3 too), but I assume it’ll take a while to feed it through the system.

lp is our cost function so I continue to assert we don’t need to worry about this :D.

Yes, that’s a valid assumption. I am not yet clear on how you do that, but I guess you somehow abuse that AD tape for representing the function.

Just a few questions in this regard:

  • Does what you have in mind also work for hierarchical models like population pharmacokinetic models where I have many patients and for each patient I am solving an ODE?
  • While it is correct that lp is the cost function… this does include the priors and all that. These are irrelevant for the ODE cost function. I mean, a problem we have by design in Stan is that we do not really know what the data likelihood term is. Will this incur overhead when using your approach?

It interfaces with the autodiff stack the same way as the current ODE solver does.

It just accepts adjoints at its output and spits out different adjoints at the input.

The difference is in the reverse pass: with the current ODE stuff we’re doing the equivalent of a matrix-vector multiply. With the new stuff we’ll solve an ODE to compute the same thing.
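A toy comparison of the two reverse-pass strategies, as a sketch only. Everything here is an assumption made for illustration: the scalar ODE u' = -theta * u, the fixed-step RK4 helper `rk4`, and the names `grad_forward` and `grad_adjoint`. None of it is Stan Math or CVODES code. `grad_forward` integrates the forward sensitivity s = du/dtheta and then multiplies by the incoming adjoint `u_bar` (the scalar version of the matrix-vector multiply). `grad_adjoint` instead solves an ODE backward in time for lambda and accumulates the gradient as a quadrature.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>

// Fixed-step classical RK4 for an autonomous system y' = rhs(y) on [0, t_end].
template <std::size_t N, typename F>
std::array<double, N> rk4(F rhs, std::array<double, N> y, double t_end,
                          int steps) {
  const double h = t_end / steps;
  auto axpy = [](const std::array<double, N>& a, double c,
                 const std::array<double, N>& b) {
    std::array<double, N> r{};
    for (std::size_t i = 0; i < N; ++i) r[i] = a[i] + c * b[i];
    return r;
  };
  for (int n = 0; n < steps; ++n) {
    auto k1 = rhs(y);
    auto k2 = rhs(axpy(y, 0.5 * h, k1));
    auto k3 = rhs(axpy(y, 0.5 * h, k2));
    auto k4 = rhs(axpy(y, h, k3));
    for (std::size_t i = 0; i < N; ++i)
      y[i] += h / 6.0 * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i]);
  }
  return y;
}

// Forward sensitivities: integrate (u, s) with s = du/dtheta, then the
// reverse pass is just s(T) * u_bar.
double grad_forward(double theta, double u0, double T, double u_bar,
                    int steps) {
  using S = std::array<double, 2>;
  S end = rk4(
      [theta](const S& v) { return S{-theta * v[0], -v[0] - theta * v[1]}; },
      S{u0, 0.0}, T, steps);
  return end[1] * u_bar;
}

// Adjoint: forward pass computes u(T) only; reverse pass integrates
// (u, lambda, q) in reversed time tau = T - t, starting from lambda(T) = u_bar.
// Recomputing u backward like this is only reasonable for a toy problem.
double grad_adjoint(double theta, double u0, double T, double u_bar,
                    int steps) {
  using S1 = std::array<double, 1>;
  using S3 = std::array<double, 3>;
  S1 uT = rk4([theta](const S1& v) { return S1{-theta * v[0]}; }, S1{u0}, T,
              steps);
  S3 end = rk4(
      [theta](const S3& z) {
        return S3{theta * z[0],    // du/dtau      = -f(u, theta)
                  -theta * z[1],   // dlambda/dtau = (df/du) * lambda
                  -z[0] * z[1]};   // dq/dtau      = lambda * df/dtheta
      },
      S3{uT[0], u_bar, 0.0}, T, steps);
  return end[2];  // q accumulated over [0, T] = d(target)/d(theta)
}
```

Both functions should agree with the closed form -T * u0 * exp(-theta * T) * u_bar up to integration error. The forward version carries one extra state per (state, parameter) pair, while the adjoint version carries one lambda per state plus one quadrature per parameter, which is why it tends to pay off when there are many parameters.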

You don’t want to implement a checkpointing scheme yourself. I suggest we drop RK45 + adjoint.

It does sound annoying.

But I guess before this we need to be able to do nested gradients in reverse mode. Does anyone want to volunteer to take a first pass at how to do that?

The sooner we know our options the better. I can look at this maybe next week.

I may have screwed up and vastly overestimated this problem lol: https://github.com/stan-dev/math/pull/1856

I went and investigated it today. If that’s the fix, then uuhh, sorry about delaying adjoint sensitivities a couple years. Bit of an oversight.

Great, I see the PR has been merged! If my understanding is correct, we can now run autodiff inside the chain() method, which means doing operations on the adjoint w.r.t. the target distribution.

So how does this play out for the ODE solver? During the forward pass, we solve the regular ODE system and get the solution u(t); and during the reverse pass, we solve the adjoint system. The problem is that the systems are coupled, meaning I need u(t) to solve the adjoint system. Does this mean we need to solve the regular system both during the forward and reverse pass?
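For reference, the coupling can be written out for the simplest case. This is a sketch under assumptions that are not from the thread: a single output time T, an initial condition u_0 that does not depend on the parameters, and the symbols θ (parameters), λ (adjoint state), and ū (the adjoint of u(T) handed in by the reverse pass).

```latex
% Forward system, solved during the forward pass
\frac{du}{dt} = f(u, \theta, t), \qquad u(0) = u_0

% Adjoint system, solved backward from t = T to t = 0 during the reverse pass
\frac{d\lambda}{dt} = -\left(\frac{\partial f}{\partial u}\right)^{\!T} \lambda,
\qquad \lambda(T) = \bar{u}

% Adjoints returned to the autodiff stack
\bar{\theta} = \int_0^T \left(\frac{\partial f}{\partial \theta}\right)^{\!T} \lambda \, dt,
\qquad \bar{u}_0 = \lambda(0)
```

Both Jacobians, ∂f/∂u and ∂f/∂θ, are evaluated at u(t), which is exactly why the reverse pass needs the forward solution at every time it visits.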

ps (but let’s save it for another post): we can also improve the algebraic solver using this.

The checkpointing is just magically handled by CVODES internally. It handles the memory, and then we put the ODE on the special stack that gets destructors called.

As far as what Stan needs to implement, there’s like a big checklist on page 139 here: https://computing.llnl.gov/sites/default/files/public/cvs_guide.pdf .

It shouldn’t be too hard. Most of those things are optional. If you want to work on it, work from this branch of the ODEs: Variadic argument lists for ODEs by bbbales2 · Pull Request #1641 · stan-dev/math · GitHub . I’m about to go through that pull and add docs/tests for everything. I’ll do whatever is left of adjoints afterwards.

For clarity, the checkpointing is the thing where CVODES saves the forward solution in a special place so that it can use it during the reverse pass.
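A toy version of the idea, to show what gets stored and what gets recomputed. This is not how CVODES implements it internally. It is a hand-rolled discrete adjoint of explicit Euler on the assumed scalar ODE u' = -theta * u, and `Checkpoints`, `forward_pass`, and `reverse_pass` are hypothetical names. The forward pass keeps only every `stride`-th state; the reverse pass rebuilds each segment from its checkpoint before stepping the adjoint back through it.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Toy right-hand side f(u, theta) = -theta * u and its partial derivatives.
double f(double u, double theta) { return -theta * u; }
double df_du(double /*u*/, double theta) { return -theta; }
double df_dtheta(double u, double /*theta*/) { return -u; }

struct Checkpoints {
  std::vector<double> u;  // states at steps 0, stride, 2 * stride, ...
  int stride;
  int steps;
  double h;
};

// Forward pass: explicit Euler, storing only every stride-th state.
double forward_pass(double u0, double theta, double T, int steps, int stride,
                    Checkpoints& cp) {
  cp = Checkpoints{{}, stride, steps, T / steps};
  double u = u0;
  for (int n = 0; n < steps; ++n) {
    if (n % stride == 0) cp.u.push_back(u);
    u += cp.h * f(u, theta);
  }
  return u;  // u at time T
}

// Reverse pass: walk the segments last to first. For each one, recompute the
// intermediate states from the checkpoint, then apply the discrete adjoint of
// each Euler step in reverse order.
double reverse_pass(const Checkpoints& cp, double theta, double u_bar) {
  double lambda = u_bar;  // adjoint of the state after the current step
  double theta_bar = 0.0;
  for (int seg = static_cast<int>(cp.u.size()) - 1; seg >= 0; --seg) {
    const int first = seg * cp.stride;
    const int last = std::min(first + cp.stride, cp.steps);
    std::vector<double> u(last - first);
    u[0] = cp.u[seg];
    for (int i = 1; i < last - first; ++i)
      u[i] = u[i - 1] + cp.h * f(u[i - 1], theta);
    for (int i = last - first - 1; i >= 0; --i) {
      theta_bar += lambda * cp.h * df_dtheta(u[i], theta);
      lambda *= 1.0 + cp.h * df_du(u[i], theta);
    }
  }
  return theta_bar;
}
```

Memory drops from one stored state per step to one per `stride` steps, at the cost of roughly one extra forward solve during the reverse pass. That trade is the part CVODES handles for us.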

Go for it. You seem more familiar with the code than I am. I can (help) review the PR.

FYI, Ivan Yashchuk at Aalto is working on Stan + PETSc for PDEs. I asked him to post something about that here.

I exchanged emails with him last week; he’s doing a great job setting up the solver. Looking forward to hearing from him, as I wonder what kind of PDE problem he’s considering.