Adjoint ODE Prototype - RFC - Please Test

When I make changes to the Math source, I wipe everything out with a make clean-all for CmdStan when I want to be sure. It should work automatically, but I don’t trust the automatic rebuild…

(And really, don’t bother with comparisons against the old version; but make sure you are using the new one, of course.)

You wanted a comparison with the old version, or did I misunderstand?

I posted version 2 of the adjoint ODE solver in the hope that it is faster than v1… so yes, I meant a speed comparison between the versions, so that someone else can verify the changes were for the better.

But as I said, never mind that. Don’t put much effort into it or go into depth with it.

Thanks for implementing this. When thinking about how to choose defaults for the tolerances and other arguments, and what to expose to users, I think the following should be taken into account (I would also like to know the answers to these):

  1. Which arguments affect the accuracy of the solution y_hat? If the solution is wrong, all dependent probabilities will be wrong and the user will be sampling from a biased posterior without getting any warning about it. And you cannot diagnose this with ESS, Rhat, or any of the other MCMC diagnostics in Stan.

  2. Which arguments affect the accuracy of the sensitivities of the solution? Their accuracy affects how HMC/NUTS behaves, but you are still sampling from the correct posterior if y_hat is correct (and you can check whether you are by diagnosing it with Stan).

  3. How do the other arguments affect the behaviour of the sampler? Especially, what happens if max_num_steps is reached? And not just in terms of what stan_math reports but what does the sampler do if it occurs? Importantly, can you accidentally sample from the wrong posterior by setting this too low?

The speed comparisons are only relevant if you are sampling from the same posterior (i.e. y_hat doesn’t change and you aren’t allowed to skip relevant parameter regions).


So what do you have to report in terms of results?

The design doc should answer most of your questions; see “Adjoint ode design” by wds15, Pull Request #37 in stan-dev/design-docs on GitHub.

In brief:

  • the forward pass determines y_hat
  • the backward pass calculates the partials of (essentially) the log-lik wrt the states y
  • the backward quadrature problem calculates the partials of the log-lik wrt the parameters
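To make the three pieces concrete, here is a toy sketch in plain Python (not Stan’s or CVODES’ implementation) for the scalar ODE dy/dt = -theta*y with the stand-in “log-lik” g = y(T). The forward pass gives y_hat = y(T); the backward pass integrates the adjoint lam, with dlam/dt = -lam*df/dy and lam(T) = dg/dy(T); the quadrature mu accumulates the integral of lam*df/dtheta, which is dg/dtheta. For brevity the sketch re-integrates y backward alongside lam instead of using checkpointing and interpolation as CVODES does, and it uses fixed-step RK4 rather than adaptive BDF/Adams. The function names are hypothetical.

```python
import math


def rk4_step(rhs, t, z, h):
    """One classical Runge-Kutta step for the system z' = rhs(t, z)."""
    k1 = rhs(t, z)
    k2 = rhs(t + h / 2, [zi + h / 2 * k for zi, k in zip(z, k1)])
    k3 = rhs(t + h / 2, [zi + h / 2 * k for zi, k in zip(z, k2)])
    k4 = rhs(t + h, [zi + h * k for zi, k in zip(z, k3)])
    return [zi + h / 6 * (a + 2 * b + 2 * c + d)
            for zi, a, b, c, d in zip(z, k1, k2, k3, k4)]


def integrate(rhs, t0, t1, z0, n_steps):
    """Fixed-step RK4 from t0 to t1 (t1 < t0 integrates backward in time)."""
    h = (t1 - t0) / n_steps
    t, z = t0, list(z0)
    for _ in range(n_steps):
        z = rk4_step(rhs, t, z, h)
        t += h
    return z


def adjoint_gradient(theta, y0, T, n_steps=1000):
    """Return (y(T), dg/dtheta) for y' = -theta*y, y(0) = y0, g = y(T)."""
    # 1) forward pass: determines y_hat = y(T)
    (y_T,) = integrate(lambda t, z: [-theta * z[0]], 0.0, T, [y0], n_steps)

    # 2) backward pass: lam' = -lam * df/dy = theta * lam, lam(T) = dg/dy(T) = 1
    # 3) quadrature: mu' = lam * y with mu(T) = 0, so
    #    mu(0) = -int_0^T lam*y dt = int_0^T lam * df/dtheta dt  (df/dtheta = -y)
    def backward_rhs(t, z):
        y, lam, mu = z
        return [-theta * y, theta * lam, lam * y]

    _, _, mu_0 = integrate(backward_rhs, T, 0.0, [y_T, 1.0, 0.0], n_steps)
    return y_T, mu_0
```

For theta = 0.5, y0 = 2 and T = 3, adjoint_gradient should reproduce the analytic derivative d/dtheta y0*exp(-theta*T) = -T*y0*exp(-theta*T) ≈ -1.339, while the forward pass alone fixes y_hat.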

What is certainly special about the adjoint ODE method is that each of these components now has its own “numerical life”, in the sense that each is computed with its own technique and potentially with its own tolerances.

I like that document. The natural thing to report in speed comparisons is probably ESS/s, but to make speed comparisons between the forward and adjoint methods informative, I would only create test cases where both methods

  • use the same solver in the forward pass
  • with the same tolerances for the forward pass
  • have the same max_num_steps for the forward pass

because then we know that they compute the same y_hat. One can then play with the other parameters to find good values for them. This might be what you were already doing; I just wasn’t sure what affects y_hat.

Yes, I see this as an advantage: you can play with the numerics of the backward solve without messing up y_hat, whereas with the forward method the same control parameters apply to the whole extended system.

Hm, generally we cannot expect to sample from the “same” posterior if we are using numerical integrators, can we?

I guess we would like to have some scalable and difficult reference problem with known (analytical) posterior. Do we have such a thing? The harmonic or damped oscillator might be just a little bit too easy.

@wds15 explained the parameters, so I won’t.

  • Generally, throughout the test cases, all methods agreed with each other (in the eye-norm). If there were a large bias due to overly loose tolerances, it should show up in settings with very tight priors. We did not observe this. Of course, the appropriate tolerances are problem specific, and maybe we should advise users to always do a preliminary exploration of the parameter space, as you did in your case study. For this, it would be very convenient if we could do all of it using a single Stan file.

  • I believe max_num_steps was set so high (1e5) that it was never reached.

  • We only have a BDF/Adams adjoint method, but “generally” the rk45 forward method performs much better than the BDF/Adams forward method, which is why this is usually the method I compare against. However, even with equal “forward” tolerances, the solution/step size adaptation will not be the same for the forward and adjoint BDF/Adams method, as the sensitivity errors factor into the step size adaptation (if I’m not mistaken).
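The step-size point above can be illustrated with a toy adaptive integrator. This is a sketch that swaps in a simple Heun/Euler embedded pair in place of CVODES’ BDF/Adams, on the assumed test problem y’ = -theta*y. When the sensitivity s = dy/dtheta is included in the error test, as in a forward-sensitivity solve, the controller accepts a different sequence of steps than when only y is controlled, so y_hat itself changes slightly even with identical rtol/atol. All names here are hypothetical.

```python
import math


def adaptive_heun_euler(rhs, z0, t0, t1, rtol, atol, h0=1e-3, max_num_steps=100_000):
    """Integrate z' = rhs(t, z) from t0 to t1 (t1 > t0) with an embedded
    Heun (order 2) / Euler (order 1) pair. Every component of z enters the
    error test. Returns (z(t1), number of accepted steps)."""
    t, z, h = t0, list(z0), h0
    accepted = 0
    for _ in range(max_num_steps):
        h = min(h, t1 - t)
        k1 = rhs(t, z)
        z_euler = [zi + h * a for zi, a in zip(z, k1)]
        k2 = rhs(t + h, z_euler)
        z_heun = [zi + h / 2 * (a + b) for zi, a, b in zip(z, k1, k2)]
        # weighted max-norm of the local error estimate over all components
        err = max(abs(zh - ze) / (atol + rtol * abs(zh))
                  for zh, ze in zip(z_heun, z_euler))
        if err <= 1.0:
            t = t1 if h == t1 - t else t + h
            z = z_heun
            accepted += 1
            if t >= t1:
                return z, accepted
        # step-size update for an O(h^2) error estimate (safety factor 0.9)
        h *= min(5.0, max(0.2, 0.9 * (1.0 / max(err, 1e-12)) ** 0.5))
    raise RuntimeError("max_num_steps reached before t1")


theta, T, rtol, atol = 2.0, 5.0, 1e-6, 1e-8


def states_only(t, z):
    # y' = -theta * y
    return [-theta * z[0]]


def states_and_sens(t, z):
    # y' = -theta * y and s' = -theta * s - y, where s = dy/dtheta, s(0) = 0
    return [-theta * z[0], -theta * z[1] - z[0]]


(y_a,), n_a = adaptive_heun_euler(states_only, [1.0], 0.0, T, rtol, atol)
(y_b, _), n_b = adaptive_heun_euler(states_and_sens, [1.0, 0.0], 0.0, T, rtol, atol)
print(n_a, n_b, y_a - y_b)  # different step counts, slightly different y(T)
```

Both runs stay within tolerance of the exact y(T) = exp(-10), but they are not bit-for-bit the same y_hat, which is why “same solver, same tolerances” does not guarantee an identical forward solution.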

Edit:

  • I don’t think reaching max_num_steps is ever a good idea.
  • Problems where the prior dominates the posterior should not be (majorly) affected by inaccuracies in the gradient computation, while problems with wide priors appear to need special handling to ensure convergence in reasonable time.

Adams performs a lot better when dealing only with the N ODE states and not the entire forward problem at once. So I don’t think I would do it like that.

… but do you have some results to report? That would be great.

I share the desire to “get the same posterior”, and the things you say are correct… but I think we must simplify things in order to make some progress. Working out how the tolerances throw HMC off the rails is extremely difficult.

I would set things up so that HMC (very likely) stays on track and then vary some of the tuning parameters. That alone gives enough information and enough complexity to deal with. I would avoid pushing it too far.

The art here is to do the “right” simplifications, but still get some useful information from the experiments.

Also: by now the comparison to forward is not too useful, I think. We already have clear examples where adjoint blows away forward. What is interesting, though, is whether the simplified interface ode_adjoint_tol already shows gains over forward whenever adjoint should be faster. So adjoint winning over forward should not depend on the exact tuning parameters of the more complex function (that would be my hope).

If for some system now adjoint is better, then we would like to know which tuning knobs matter (and which don’t) - and which tuning knobs we forgot.

My point was that if two methods use the exact same numerical procedure to compute y_hat, then they are sampling the same posterior.

But I now realize you are correct here: y_hat won’t be exactly the same even when using the same method and the same tolerances, so it is hard to do fair comparisons.

If that is the case, wouldn’t it make sense to set max_num_steps to something like 1e30, or infinity?

7 posts were split to a new topic: Reaching max_num_steps in the ODE solver

Hi Sebastian,

This looks very promising, and potentially enormously useful to my research, so thank you so much for implementing this!

I started playing around (very humbly, I’m just a user) with the different solvers on an extended SIR-type model. From my first tries, the adjoint method with the default settings seems much, much faster (about 15-18 times faster), but also less stable. I sometimes get divergent transitions and mixing issues that never happen with rk45 or bdf. I also get a lot of these warnings, which I never encountered before:

Exception: CVODES: CVode Internal t = 293.528 and h = 2.56834e-14 are such that t + h = t on the next step. The solver will continue anyway.
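The warning means the solver has cut the step size h so far that, at the current time t, adding it no longer changes t in double precision, a sign that the solver is struggling at that point. A quick check in Python with the two numbers from the message (math.ulp requires Python 3.9 or later):

```python
import math

t = 293.528        # current time reported in the CVODES message
h = 2.56834e-14    # step size reported in the CVODES message

# spacing between adjacent doubles near t
spacing = math.ulp(t)
print(spacing)          # 2**-44, about 5.68e-14
print(t + h == t)       # True: h is below half the spacing, so t + h rounds back to t
```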

I’m guessing that this has to do with the solver settings. I’m currently testing out different values so I can get some intuition about what matters in my particular case.

I will follow development eagerly!


That’s great to hear!

The error message you show suggests that the solver is struggling with your problem. If you could share some more details of your problem, how you use the old solvers and how you are trying to use the new solver, then maybe we can say more.

Running 15-18x faster sounds really cool… if it gives you a usable result, that is.

We would be interested in the things mentioned on the wiki:

Since Easter is approaching, we would like to settle the details of this as much as possible. So thanks for putting your head into experimental work; feedback is very helpful!

@jriou I tried the adjoint solver (also with ‘default’ tuning parameters) on the example we had in the disease transmission tutorial, and found it to be ~3 times slower. This might be expected, since there were only 4 states in this example. But I’m guessing that if we were to stratify by age groups and increase the number of states, the adjoint method yields more benefits.

The additional warning messages may be due to the fact the backward solve can be less stable.

I sometimes have divergent transitions

How many divergent transitions? Could it be a “fluke”, as in something that can happen with the adjoint method as well as the regular integrator? Do they occur consistently across chains or just one chain?

The tuning parameters are really important. @Funko_Unko found that the backward tolerances are very important for performance: loosening them (i.e. a less accurate backward solve) sped things up a lot for him. That can only be done “so much”, of course.

Also: The adjoint ODE method should work well when you have M >> N (so more parameters than states). At least this is my finding so far.


Hi Charles,

Exactly, this adjoint method could allow for even more stratification!

The model I’m trying at the moment is slightly more complex with a (simple) implementation of time-dependent parameters. So if you remember the SEIR formulation, the beta parameter now varies in time according to smooth switch functions.
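As a reminder of the SEIR formulation, here is a minimal sketch of a right-hand side with a time-varying beta. The logistic switch from beta0 to beta1 around a switch time t_switch with steepness k is an assumed form chosen for illustration; it is not the poster’s actual model (which has 5 states and more parameters), and all names are hypothetical.

```python
import math


def logistic(x):
    """Numerically stable 1 / (1 + exp(-x))."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def beta_t(t, beta0, beta1, t_switch, k):
    """Smooth switch from beta0 (t << t_switch) to beta1 (t >> t_switch)."""
    return beta0 + (beta1 - beta0) * logistic(k * (t - t_switch))


def seir_rhs(t, state, beta0, beta1, t_switch, k, sigma, gamma):
    """dS/dt, dE/dt, dI/dt, dR/dt for a closed population N = S + E + I + R."""
    S, E, I, R = state
    N = S + E + I + R
    infection = beta_t(t, beta0, beta1, t_switch, k) * S * I / N
    return [-infection,
            infection - sigma * E,
            sigma * E - gamma * I,
            gamma * I]
```

A smooth switch keeps the right-hand side differentiable in t; a hard jump in beta would instead force adaptive solvers to shrink their steps around the discontinuity.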

As for the additional questions from @wds15: the total number of states is 5, with 13 varying parameters (declared in the parameters block) and 2 additional fixed parameters, so M > N, as you just said. As for the other questions on the wiki: the initial values are varying (via one of the 13 parameters), but this is implemented with 0 as initial values that are rescaled later (following the advice of @charlesm93 and @bbbales2). The system should not be stiff for usual parameter values.

I’ll try to tidy up the code a bit and share some more if you’re interested!


Whether the initial state varies or not makes a big difference for the forward method (which is all we have had so far). For the new adjoint method it should actually not matter at all whether the initial values vary… but maybe confirm that first.

I had 4 states and 8 parameters.

Looking at the size of the ODE system we solve, with n states and p parameters, the forward method solves n + np equations and the adjoint method 2n + p. So if you have more than 2 parameters and a very large number of states, I still expect the adjoint method to perform better.
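The counting above can be written out as a small helper. This is a sketch with hypothetical function names; it only counts equations and ignores the extra overhead of the adjoint method (checkpointing, interpolation, separate forward and backward solves), so it is a rough guide rather than a performance prediction. It also reflects that, for the forward method, varying initial states act like n additional parameters, while for the adjoint method the gradient with respect to the initial state comes from the adjoint solution at t0 without extra equations.

```python
def forward_sensitivity_size(n, p, varying_initial_state=False):
    """Equations in the forward-sensitivity system: the n states plus
    n sensitivities for each parameter. Varying initial states count
    as n extra parameters for the forward method."""
    p_eff = p + n if varying_initial_state else p
    return n + n * p_eff


def adjoint_size(n, p):
    """Equations in the adjoint approach: n forward states, then n adjoint
    states plus p parameter quadratures in the backward solve."""
    return 2 * n + p
```

For the 4-state, 8-parameter tutorial model this gives 36 versus 16 equations; for 5 states and 13 parameters (if all of them enter the ODE) it gives 70 versus 23.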

Happy to take a look at your code and also share what I’ve been using. We definitely should try the published model with stratified ages (i.e. ~50 states), because I expect the speed-up will be important there, even if we’re not in the M >> N regime; either that, or I’m missing something.

Great idea to fit the old stratified model, I’ll put everything necessary in a clean github repo so we can play around with this!

The adjoint method has a much larger overhead than the forward method, so it is not just about which of these expressions in n and p gives the smaller number. The CVODES manual states that the adjoint method is usually better whenever p is greater than n… in my view it’s a matter of trying.

And really: do try running with a less accurate backward solve. If that still gives you no sampling issues, things should run faster. Moreover, I found that the Adams solver does relatively well (which it never did for my problems with the forward method).