Adjoint sensitivities

Stan uses forward sensitivities in the ODE solver, right? So if you have N states and P parameters, you tack on an extra N * P ODEs and solve the bigger system to get the solution together with its sensitivities.
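If I have the forward sensitivity equations right (just the standard setup, not pulled from the Stan code), for states y and parameters p it's

y' = f(y, p, t), \qquad s_i = \frac{\partial y}{\partial p_i}, \qquad s_i' = \frac{\partial f}{\partial y} s_i + \frac{\partial f}{\partial p_i}, \quad i = 1, \dots, P,

so the augmented system has N + N * P states.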

Is there a reason to not do adjoint sensitivity analysis? I see it in the CVODES docs, section 2.7 (https://computation.llnl.gov/sites/default/files/public/cvs_guide.pdf). It sounds like it would be useful in the same way reverse mode autodiff is nicer than forward mode autodiff (derivative of one thing with respect to a lot of parameters vs. derivatives of many things with respect to a few parameters).

I’m not sure about the technical details of doing that reverse time integration for the adjoint equation. Sounds non-trivial, but since it’s in the CVODES manual I’d hope someone has thought this through.

Was just thinking about ODEs after I saw @wds15's ODE stuff on the blog (http://andrewgelman.com/2017/06/16/speed-parallelizing-stan-using-message-passing-interface-mpi/). (Looks cool by the way! I could very much use this.)


Hi!

Yeah, I have been looking at the adjoint stuff myself. One problem is that you would have to provide the cost function directly, which one could possibly work around (squared loss or whatever). The main problem I saw is that it doesn’t quite calculate what we need: the method calculates the integral of that cost function over time (and its gradient), but we need that per observation, not marginalized over time. At least this was my take after staring at the equations for a while; I could well be wrong.
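To spell out what I mean (this is my reading of section 2.7 of the CVODES manual): the adjoint module gives you the gradient of a functional of the form

G(p) = \int_{t_0}^{t_F} g(t, y, p) \, dt

(or of g evaluated at the final time), whereas we would need the gradient of the per-observation terms at each measurement time, not just the time integral.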

The adjoint approach would be a huge step forward if it would work.

Just to emphasize, the MPI story is well suited for ODEs, but it should speed up non-ODE problems in the very same way, just maybe not as efficiently.

I think the equivalent of the ‘cost function’ or whatever in our case would just be the log-likelihood, right?

Look at Eq. 3.1 in this paper: https://engineering.ucsb.edu/~cse/Files/adjointI01.pdf

Superscript * in this paper is the conjugate transpose. Subscript x is a partial derivative with respect to the states (so f_x^* is the transpose of the Jacobian, I think).

I think adjoint sensitivity analysis goes something like: integrate forward from 0 to T, saving states, then follow that with a backwards integration of the adjoint equations. Looking at Eq. 3.1 makes me think that backward-in-time solve would fit nicely into the reverse mode autodiff chain.

If g is the log likelihood, then when the autodiff chains its way back to the output of the ODE, (dg/dx)(T) is what the autodiff chain has (if that makes sense: the derivative of the log likelihood with respect to the states at the output of the ODE).

dg/dx is the initial condition of the adjoint equation at time T; integrate that back to time zero, and then you can compute dg/dp from the state of the adjoint ODE (where p are the input parameters) and chain that along.
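Writing out what I think the adjoint system looks like (the same flavor as their Eq. 3.1, with * the conjugate transpose as above), for x' = f(x, p, t) and a cost g(x(T)):

\lambda' = -f_x^* \lambda, \qquad \lambda(T) = g_x^*(T),

\frac{dg}{dp} = \lambda(0)^* \frac{\partial x_0}{\partial p} + \int_0^T \lambda^* f_p \, dt,

so the backward solve for lambda starts from the adjoint that reverse mode autodiff hands us at time T, and the parameter gradient comes out of the quadrature term plus the initial-condition term.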

I think the downside of the adjoint sensitivity stuff is that you need to save all your intermediate timesteps from the ODE (because you need to be able to evaluate the Jacobian on your way back through with the adjoint equations). And maybe your reverse integration with the adjoints uses different time points than your forward integration? The CVODES manual talked about interpolation or something.

Does what I said make sense at all? I’m honestly not that familiar with this stuff. I’ve looked at adjoint sensitivity stuff before, but usually the math confuses me and I wander away.

That paper also cites this one: http://twister.caps.ou.edu/OBAN2016/Errico_BAMS_1997.pdf

A more appropriate paper is likely https://arxiv.org/abs/1606.04406. My limited understanding is that you don’t get the full Jacobian (which we need for arbitrary autodiff) but rather certain projections of it. For something like the states going directly into a Gaussian likelihood, you only need to compute a quadratic form, which can be written as a projection, and hence the method is applicable. But it wouldn’t be in general, so if you wanted to use it then you’d need a special Gaussian_ODE function.
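For example, if the states at the measurement times feed directly into a Gaussian likelihood, the quantity to differentiate is the single scalar

\log p(d \mid p) = -\tfrac{1}{2} \sum_k \big(y(t_k) - d_k\big)^T \Sigma^{-1} \big(y(t_k) - d_k\big) + \text{const},

and the adjoint method can return its gradient in one pass. A general Stan program, however, can do arbitrary things with the ODE states, which is why we otherwise need the full Jacobian.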

@betanalpha Oooooh, good find. I’ve Googled around a lot on adjoint stuff but I always hit weird optimization problems from like the 80s-90s.

I think an issue with the approach I pointed out is that you’d have to solve a different backward ODE for every time point at which you have state information.

This paper looks like it’s addressing that (rightly so) with one big fancy adjoint computation.

That Eq. 15 looks weird, but they say they can handle the little deltas coming along so I’ll believe them.

If they had a different ‘d’ at every time point, wouldn’t that be general enough? Was there an assumption somewhere that locked that down?

Oh, I guess a Stan program could mix up outputs of the ODEs at different time points and that wouldn’t fit in there… I’ll have to look at this more closely to see what I’m getting wrong.

My intuition is that to recover the full Jacobian you’d have to run the adjoint method multiple times and end up with the same complexity that we have right now. Again, we need the full Jacobian for general autodiff and if you wanted to limit the use of the ODE output to certain cases (like input to a Gaussian likelihood) then you’d have to define a special function.

This is in line with my understanding: you either have to define a bunch of special functions like gaussian_ode, or, even better, make that cost function general via some functional. From what I understood, the ODE integrator would then have to be an integrate_ode_lpmf, as it would efficiently give you the gradient of the cost function with respect to all the parameters. However, even then I am not sure it works, since you get the integral over time of the cost function, which is only defined at those measurement time points.

The backward solve should not be a problem. CVODES does this for you (no idea how, but the manual has the details).

I did some more thinking about this. I agree that with the current way vars work (all the autodiff is embedded in scalar elements) the full Jacobian across the ODE system is required.

The way the adjoint sensitivity analysis works, it doesn’t just chain (for lack of a better word) each output vari one at a time. It takes the adjoints attached to all the outputs, integrates back simultaneously, and computes the contributions to the adjoints at its inputs in one go. What doesn’t conceptually work here with the scalar autodiff variables is that there’s no guarantee in general that all the adjoints at the output of the ODE solver are ready to be integrated back at the same time, and doing each of them individually is probably no more efficient than the forward sensitivity analysis.

However, because of the way the Math library uses a stack, and because all the varis created in the ODE solver get dumped on the stack at the same time, I think this condition is satisfied. It’s my understanding that the library assumes that by the time it gets to a point in the stack, all of the things attached to that output have already been processed. Because all the ODE varis go onto the stack at the same place relative to everything else, when any single one of their chain methods gets called, all of them are ready to be called.

So couldn’t each of the little varis (each corresponding to one output from the ODE) initiate the chaining for everything? The trick there is to just design the varis in a way that only the first one called has any effect.
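To make that concrete, here’s a toy sketch of the idea (my own illustration with made-up names, not the actual Stan vari machinery):

#include <vector>

// Shared by every output vari of a single ODE solve.
struct adjoint_coordinator {
  bool chained = false;                 // has the backward solve run yet?
  std::vector<double> output_adjoints;  // adjoints of all the ODE outputs
  std::vector<double> input_adjoints;   // adjoints to push back to y0 / theta

  void chain_all() {
    if (chained) return;  // only the first caller does any work
    chained = true;
    // ... run the single backward adjoint solve here, reading
    // output_adjoints and accumulating into input_adjoints ...
  }
};

// Stand-in for one ODE output vari.
struct ode_output_vari {
  adjoint_coordinator* coord;
  double val_;
  double adj_;
  void chain() { coord->chain_all(); }  // whichever output chains first wins
};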

I coded this up here: https://github.com/bbbales2/adjoint/blob/master/stan_test.cpp

It’s just a linear 2D ODE that runs around in a circle (I tested a non-linear 1D adjoint as well outside of this, for what it’s worth). If we call the ODE y' = f, I hard coded the necessary Jacobians of f. The output is ten evenly spaced samples of the first state. There is one parameter (the angular frequency squared). I just used forward Euler and wrote the integrator myself to keep the code simple.

The Melicher paper @betanalpha found was useful. I also liked this one (esp. for the clarity of eq. 2.4): http://www.mcs.anl.gov/~hongzh/papers/SISC_FATODE_final.pdf .

There’s the detail that we’re taking output at many different time points of the ODE. This still only requires one backwards ODE solve (though Melicher points out it’s a little finicky). If you look at how I did the reverse integration you might find the lines:

// As the backward integration passes an output time, absorb that output's
// adjoint into the running adjoint state (explained below).
if((i + 1) % (N / P) == 0) {
  l[0] += neighbors[p]->adj_;  // adjoint of the output at this time point
  p -= 1;                      // move to the previous output time point
}

a little weird. What is happening there is that as we integrate back in time and pass output time points, we absorb their adjoints into the bigger adjoint ODE solve.

This is how the Melicher paper describes the backwards integration, though they may have implemented it a little differently. It’s a little weird to think about and reconcile with the adjoint equations as they are given in 2.4 of the Zhang paper. But think of a long ODE solve with many intermittent outputs as a sequence of small ODE solves, where you extract the output at intermediate steps and reinitialize a new ODE solver before moving on. Then think of solving the big adjoint problem as solving a bunch of smaller adjoint problems in reverse, where the initial condition of each little adjoint ODE problem is the sum of the contributions from calculations on its output plus the contributions from any ODE solves that followed it.
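In pseudocode, the backward sweep I’m describing looks roughly like this (output_adjoint and integrate_adjoint_segment are hypothetical placeholders, not functions in my branch):

// K output times t[1] < ... < t[K], N states, P parameters.
std::vector<double> lambda(N, 0.0);  // adjoint of the states
std::vector<double> mu(P, 0.0);      // adjoint of the parameters (quadrature)
for (int k = K; k >= 1; --k) {
  // Absorb the autodiff adjoints of the outputs reported at t[k] into
  // the adjoint ODE state (the if-block above is doing exactly this).
  for (int i = 0; i < N; ++i)
    lambda[i] += output_adjoint(k, i);
  // Integrate lambda' = -(df/dy)^T lambda backwards over [t[k-1], t[k]],
  // accumulating the quadrature of (df/dp)^T lambda into mu along the way.
  integrate_adjoint_segment(lambda, mu, t[k], t[k - 1]);
}
// lambda now carries the adjoints for the initial conditions y0 and
// mu the adjoints for the parameters; both get chained along from there.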

Anyway, I’m probably still missing something about how this could be used in Stan, but I’ve got answers to my questions so I’m happy either way :D. Figured this might be useful to someone looking around the forums.

Yes, you can do that; this is what we do in some of the matrix operations to cut down on virtual function calls. The underlying memory manager maintains two stacks: one for varis whose chain() method gets called and one for varis whose chain() method doesn’t.

So adjoint sensitivity stuff is up. Here is a pull req. on my math repo (that makes it easy to look at the changes): https://github.com/bbbales2/math/pull/2

The main magic is in stan/math/rev/mat/fun/build_ode_varis.hpp (which takes the output of the ODE solver and wraps it up in vars, similar to the decouple_ode_state stuff that was there before) and stan/math/rev/mat/fun/ode_vari.hpp (which is the vari that coordinates the adjoint calculation). The vari in ode_vari is the thing that needs to have a destructor or something (to clean up the CVODES memory).

Edit: The math behind this is eq. 2.4 of http://www.mcs.anl.gov/~hongzh/papers/SISC_FATODE_final.pdf

Something to keep in mind is that even though the ODE produces number_of_time_steps * number_of_states outputs (and the same number of vars and varis), the adjoint solve is only hooked up to one of these varis. The rest are allocated on the non-chaining stack; this is how all the adjoints are done together.

This is the little example code I was using to test: https://gist.github.com/bbbales2/5049c99674e34c3796738ff11f3890ca

Here’s the output of timing it for the adjoint sensitivity analysis (those numbers are the derivatives of the sum of all the ODE output states with respect to the different diffusion constants):

bbales2@frog:~/math-adjoint$ source build2.sh && time ./test2
...
-23681.4 6228.88 3667.17 1196.06 246.698 34.8322 3.559 0.274326 0.0164524 0.000761025 -5.89514e-17 -2.77336e-05 2.80233e-05 0.00114006 0.0185708 0.231901 2.21294 15.7028 80.4073 289.296 722.426 

real	0m0.591s

If I time it with the forward mode sensitivity analysis:

bbales2@frog:~/math-adjoint-ref$ source build2.sh && time ./test2
...
-23681.4 6228.88 3667.17 1196.06 246.698 34.8322 3.559 0.274326 0.0164524 0.000761025 1.72543e-15 -2.77336e-05 2.80233e-05 0.00114006 0.0185708 0.231901 2.21294 15.7028 80.4073 289.296 722.426 

real	0m2.194s

Performance takes a significant hit if we need output at intermediate time steps. With 10 outputs, adjoint sensitivity takes ~1.2s and forward still takes 2.1s.

All the tests except these three in test/unit/math/rev/mat/functor/integrate_ode_bdf_rev_test.cpp work: https://github.com/bbbales2/math/pull/2/files#diff-8f1971e1a585b016e51b051c4f473acbR125

These are the hacks I use to keep the var_stack vector from reallocating and causing segfaults when I embed reverse mode autodiff inside the chain() of a reverse mode autodiff pass: https://github.com/bbbales2/math/pull/2/files#diff-130c5a75cc427d7d41715e9fca8281f4

Here is a Google perf tools sampling of the example program from above running:

Most of the time is spent computing derivatives of the forward ODE RHS with respect to the state and theta. I did some experiments; at least for this problem, using fvars to compute the Jacobians is faster (I did these in isolation from the actual ODE and still haven’t incorporated it back in).
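For reference, this is roughly what I mean by using fvars to compute the Jacobian (a stripped-down sketch of the approach, not the code I actually benchmarked):

#include <vector>
#include <stan/math/fwd/core.hpp>

// Jacobian of an ODE right-hand side f with respect to the states:
// one forward-mode sweep per column, seed direction j, read off the tangents.
template <typename F>
std::vector<std::vector<double> > jacobian_fwd(const F& f,
                                               const std::vector<double>& y) {
  using stan::math::fvar;
  size_t N = y.size();
  std::vector<std::vector<double> > J(N, std::vector<double>(N, 0.0));
  for (size_t j = 0; j < N; ++j) {
    std::vector<fvar<double> > yf(N);
    for (size_t i = 0; i < N; ++i)
      yf[i] = fvar<double>(y[i], i == j ? 1.0 : 0.0);  // seed the j-th direction
    std::vector<fvar<double> > fy = f(yf);             // one RHS evaluation
    for (size_t i = 0; i < N; ++i)
      J[i][j] = fy[i].d_;                              // j-th column of df/dy
  }
  return J;
}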

For the example problem above (one output, timing the whole program w/ 100 ode solves), these are some timing results. There are N states and N + 1 parameters.

N = 20
adjoint: 0.591s, forward: 2.194s

N = 40
adjoint: 2.933s, forward: 13.479s

N = 80 (which is probably bigger than you’d put in a Stan model)
adjoint: 18.040s, forward: 94.428s

Stuff seems to scale similarly, which I don’t think you’d expect. It’s like forward is scaling better than it should and adjoint is scaling worse (edit: forward has one ODE of size N * (P + 1) to solve; adjoint has one forward ODE of size N plus one backward ODE of size N + P).

Hi!

I think there is huge potential in the adjoint stuff to seriously speed up ODE models. I haven’t looked into the code yet, but since you mention that lots of time is spent autodiffing the ODE system… there is a straightforward way to give the system the Jacobian in its analytic form, see here:

Essentially that code defines a partial specialization of the ode_system which allows you to sneak an analytic Jacobian of the ODE into stan-math. Would be interesting to see your performance numbers with that.

What is a bit worrisome is that you say that more outputs lead to worse performance…

In a nutshell, what I understood is that the best thing to do with this adjoint stuff is to calculate the gradient with respect to the log-prob contributions which the ODE solution is part of. So everything should go in a single sweep. Exactly this is the difficulty in the Stan world, since everything needs to be munged together rather than kept modular. However, with C++11 and functional elements in Stan this could possibly be managed at some point.

However, honestly… if your Stan program is amenable to MPI parallelism, then you can just throw hardware at your problem. Whenever you have a natural way to slice your problem (a hierarchical model, for example), you can solve it with MPI and sufficient hardware. For ODE problems, the speedup is almost guaranteed to be linear in the number of CPUs: a 10x speedup? No problem, just take 10 CPUs and you are done. 20x? Also no issue. Of course, this assumes that you have such massive machines around.

Cool work! I will certainly dive into it at some point.

Sebastian


Yeah, it’s pretty easy for me to write out the Jacobians in this case, but then I feel like I’m getting too far away from the abstractions Stan is supposed to provide for general-purposeness. I very well might end up doing that or some custom C++ stuff in the end, though.

In a nutshell, what I understood is that the best thing to do with this adjoint stuff is to calculate the gradient with respect to the log-prob contributions which the ODE solution is part of. So everything should go in a single sweep.

Yeah this is exactly what’s happening.

The Jacobian numbers (doing the calculations with forward mode autodiff for both the adjoint and forward sensitivities):

N = 20
adjoint: 0.440s, forward: 2.151s

N = 40
adjoint: 1.906s, forward: 13.556s

N = 80 (which is probably bigger than you’d put in a Stan model)
adjoint: 8.818s, forward: 93.736s

For these things, I removed the is_nan guards in stan/math/fwd/core/fvar.hpp. They’re quite slow, and I don’t think they’re necessary (if an fvar’s val is NaN, we should be ignoring whatever the d_ is anyway). It’s a fairly substantial performance improvement with them out.

I didn’t push the code for this, because it’s not that well organized. But it’s basically this code (Forward mode Jacobian · GitHub) replacing the current reverse mode stuff in stan/math/rev/mat/functor/ode_system.hpp (you’d have to include the forward mode headers to make it work).

New timings are here (still most time spent in Jacobian calculation):

I completely agree they’re not necessary. One of the things I’ve been meaning to get to with forward mode is ripping out the NaN hack. The only reason it ever got in is that we had zero performance tests in place.

How’d you measure performance?

This is what I use: https://github.com/gperftools/gperftools (it’s really easy to use in Ubuntu: https://github.com/ahorn/benchmarks/wiki/Profiling-with-google-perftools)

It’s a sampling-based profiler and works pretty well. You just load its libraries into your program and run it; it doesn’t seem to really affect performance. The documentation is pretty sparse though, so it can sometimes be annoying to remember how to get it to work again.

When I ran it with the NaN checks in, they showed up with the most samples, so I just removed them and things sped up quite a bit.

Ironically, profilers mess with performance profiling in terms of timing. Try it without the profiler and with -O3 if you really want to see what’s going on.

All the numbers above are without the profiler and with -O3; there aren’t any other special compiler flags. I think the profiler works by just interrupting the process and writing down the stack trace X times per second (100 by default).

N = 80, adjoint sensitivity, NaNs removed
8.74s

N = 80, adjoint sensitivity, NaN checks in place
10.59s

This is the file I used to do the benchmarking of fwd vs. reverse mode Jacobian calculations: https://gist.github.com/bbbales2/e3f790a3f5496c0115cd5c64d8daaa3a

It basically computes the Jacobian from above 10000 times in each of forward and reverse mode and compares the time per run.

w/ profiler attached, w/ nan checks:

bbales2@frog:~/math-adjoint$ LD_PRELOAD=/usr/lib/libprofiler.so CPUPROFILE=cpu_profile ./test_jacobian 
Reverse mode: 6.23723e-05
Forward mode: 3.69096e-05

w/o profiler attached, w/ nan checks:

bbales2@frog:~/math-adjoint$ ./test_jacobian 
Reverse mode: 6.21455e-05
Forward mode: 3.68801e-05

w/ profiler, w/o nan checks:

bbales2@frog:~/math-adjoint$ LD_PRELOAD=/usr/lib/libprofiler.so CPUPROFILE=cpu_profile ./test_jacobian 
Reverse mode: 6.18991e-05
Forward mode: 2.63117e-05

w/o profiler, w/o nan checks:

bbales2@frog:~/math-adjoint$ ./test_jacobian 
Reverse mode: 6.29211e-05
Forward mode: 2.61457e-05

I tried to collect a perf graph that showed the is_nan calls, but it didn’t seem to work. Maybe I was running with -O1 the first time I did it? Anyway, it looks like they’re being inlined now, so they’re not showing up in the stack traces.

Thanks—looks like the profiler’s not getting in the way.

Most likely, which is why I was suspicious. We’ve seen those numbers go from like 50% to like 1% in some cases.

In summary, I agree the NaN thing has to go. Not only for speed, but for simplicity and clarity about what’s going on.

We’ve seen those numbers go from like 50% to like 1% in some cases.

Yeah, I have difficulty keeping all these little benchmark things coherent (and if you ask me a week from now…). The garden of forking benchmarks is treacherous :/. What would be really funny is if this problem turned out to be faster with integrate_ode_rk45. I’ve just taken for granted that since it’s a diffusion problem it’ll lead to stiff ODEs, but I didn’t actually check that assumption.