CoDatMo Liverpool & UNINOVE models (slow ODE implementation and Trapezoidal Solver)

Hey guys! I was discussing a possible trapezoidal ODE solver implementation in Stan with @wds15, @Bob_Carpenter, @bbbales2 and @charlesm93 and the discussion took a detour to “what are you doing that you need a trapezoidal solver?”.

I was explaining Liverpool’s CoDatMo model, which, according to them, takes between 10 and 17 h to finish sampling (4 parallel chains, 2k iterations) with CmdStan. The model uses two time series: daily deaths and daily NHS 111 COVID calls. The UNINOVE CoDatMo model will consist of two models that use the same transmission model as Liverpool’s but rely on Brazil’s data. The first one (the one I am currently working on) is a “stepped-down” version of the model with only one time series, daily deaths. It takes 1h30 to sample on a high-end notebook running Linux, using CmdStan with the trapezoidal solver. I am also running it with an RK45 ODE solver, and it is still warming up (I estimate it will take about 12 days to sample). The second model would be the same as Liverpool’s but using daily Twitter symptom mentions.

As the discussion progressed, we decided to move it to the forums so that everybody could benefit from it as well. As @wds15 noted regarding trapezoidal solvers: “To the point of trapezoidal solvers: Avoiding the use of dedicated ODE solvers and going so simple is tempting, I know. The issue is error control”. With trapezoidal solvers we have very little error control.

Okay, so let me explain the model. The transmission model assumes a single geographical region with a large population of identical (for instance, in terms of age and sex) individuals who come into contact with one another uniformly at random, but do not come into contact with individuals from other areas. It also assumes a closed population, that is, no individuals migrate into or out of the population, for example as a result of births, non-COVID-related deaths, or changes in permanent residence. The model divides the population into six disease states: (S) susceptible, (E) exposed, (I) infectious, (R) recovered, (T) terminally ill, and (D) dead. The exposed, infectious, and terminally ill disease states are each further partitioned into two substates, so that the time a person spends in each of these states follows an Erlang distribution with shape 2. Below is a graphical illustration of the transmission model.

The Erlang distributions arise when we implement the substates using the Linear Chain Trick (LCT; see the Hurtado references below, which also cover its generalizations). “The LCT is a technique used to construct mean field ODE models from continuous-time stochastic state transition models where the time an individual spends in a given state (i.e., the dwell time) is Erlang distributed (i.e., gamma distributed with integer shape parameter).”
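As a quick sanity check of the LCT statement, here is a minimal Python sketch (the function names and the 5-day mean are hypothetical, for illustration only). A cohort that starts in the first of two substates, each left at rate 2/d, drains out of the chain following exactly the Erlang CDF with shape 2 and rate 2/d, i.e. 1 - exp(-2t/d)(1 + 2t/d), whose mean is d. The sketch integrates the two-substate chain with fixed-step RK4 and compares it with that formula.

```python
import math


def chain_outflow(d, t_end, n_steps=2000):
    """Fraction of a cohort that has left a two-substate chain by t_end.

    The cohort enters X1 at t = 0; X1 -> X2 -> out, each substate left at
    rate 2/d (like E1 -> E2 -> I1 in the model). Integrated with
    fixed-step classical RK4.
    """
    rate = 2.0 / d

    def rhs(x):
        x1, x2 = x
        return (-rate * x1, rate * (x1 - x2))

    h = t_end / n_steps
    x = (1.0, 0.0)
    for _ in range(n_steps):
        k1 = rhs(x)
        k2 = rhs(tuple(xi + 0.5 * h * ki for xi, ki in zip(x, k1)))
        k3 = rhs(tuple(xi + 0.5 * h * ki for xi, ki in zip(x, k2)))
        k4 = rhs(tuple(xi + h * ki for xi, ki in zip(x, k3)))
        x = tuple(xi + h / 6.0 * (a + 2.0 * b + 2.0 * c + e)
                  for xi, a, b, c, e in zip(x, k1, k2, k3, k4))
    return 1.0 - sum(x)  # mass that is no longer in X1 or X2


def erlang2_cdf(t, d):
    """CDF of the Erlang distribution with shape 2 and rate 2/d (mean d)."""
    lam = 2.0 / d
    return 1.0 - math.exp(-lam * t) * (1.0 + lam * t)


d_L = 5.0  # hypothetical mean latent period, in days
for t in (1.0, 5.0, 10.0):
    print(t, chain_outflow(d_L, t), erlang2_cdf(t, d_L))
```

The two columns agree up to integration error, which is the point of the trick: no delay equations are needed, only extra linear compartments.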

Through the random mixing of the population, infectious individuals come into contact with and are allowed to transmit the virus to susceptible individuals. A susceptible individual who has become exposed through one of these contacts is not initially infectious. A period of time elapses, while the virus replicates in their body, before they become infectious and can transmit the virus to members of the remaining susceptible population. After being infectious for some time, the individual may recover and become indefinitely immune to reinfection. Should the individual fail to recover, however, they will become terminally ill for a while before, unfortunately, dying of the disease.

The number of individuals in each disease state varies with time according to the following system of ordinary differential equations:

\begin{align} \frac{dS}{dt} &= -\beta \frac{I_1 + I_2}{N} S \\ \frac{dE_1}{dt} &= \beta \frac{I_1 + I_2}{N} S - \frac{2}{d_L} E_1 \\ \frac{dE_2}{dt} &= \frac{2}{d_L} (E_1 - E_2) \\ \frac{dI_1}{dt} &= \frac{2}{d_L} E_2 - \frac{2}{d_I} I_1 \\ \frac{dI_2}{dt} &= \frac{2}{d_I} (I_1 - I_2) \\ \frac{dR}{dt} &= \frac{2}{d_I} I_2 \left(1 - \omega\right) \\ \frac{dT_1}{dt} &= \frac{2}{d_I} I_2 \omega - \frac{2}{d_T} T_1 \\ \frac{dT_2}{dt} &= \frac{2}{d_T} (T_1 - T_2) \\ \frac{dD}{dt} &= \frac{2}{d_T} T_2 \end{align}
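For reference, here is the right-hand side above transcribed into a minimal Python sketch. It assumes a constant β (the Stan models use a time-varying β, which this sketch does not attempt), and the parameter values in the usage lines are made up. Because the population is closed, the nine derivatives sum to zero, which is a handy sanity check for any implementation of the system.

```python
def seeiittd_rhs(y, beta, d_L, d_I, d_T, omega, N):
    """dy/dt for the SEEIITTD system, assuming a constant beta.

    y = [S, E1, E2, I1, I2, R, T1, T2, D]; returns a list in the same order.
    """
    S, E1, E2, I1, I2, R, T1, T2, D = y
    infection = beta * (I1 + I2) / N * S  # flow S -> E1
    return [
        -infection,
        infection - 2.0 / d_L * E1,
        2.0 / d_L * (E1 - E2),
        2.0 / d_L * E2 - 2.0 / d_I * I1,
        2.0 / d_I * (I1 - I2),
        2.0 / d_I * I2 * (1.0 - omega),
        2.0 / d_I * I2 * omega - 2.0 / d_T * T1,
        2.0 / d_T * (T1 - T2),
        2.0 / d_T * T2,
    ]


# Hypothetical parameter values and state, for illustration only.
y = [9_990.0, 5.0, 0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0]
dy = seeiittd_rhs(y, beta=0.5, d_L=4.0, d_I=3.0, d_T=16.0, omega=0.01, N=10_000.0)
print(sum(dy))  # closed population: zero up to rounding
```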

The models can be found in the codatmo/UNINOVE_Sao_Paulo repository on GitHub, under SEIR-model/stan on the main branch. deaths.stan is the deaths-only model with the custom trapezoidal solver, and deaths_rk45.stan is the same model using the old-interface RK45 solver. deaths_new_rk45.stan is an implementation using the new-interface RK45 solver (thank you @bbbales2).

The point is: the model is slow to sample, and we could possibly drop the custom trapezoidal solver if we could somehow make the built-in solvers fast enough. @wds15 suggested “maybe work out a semi-analytic solution which takes advantage of different time-scales”, while also pointing out that “the problem will benefit massively from the adjoint method is my prediction. You have 9 states, but about 130 parameters!!! This would qualify the adjoint ODE method to perform a lot faster.”

EDIT: tagging @s.maskell since he is the Lead Researcher of CoDatMo’s Liverpool and will benefit from this discussion.

References

Hurtado, P. J., & Kirosingh, A. S. (2019). Generalizations of the “Linear Chain Trick”: Incorporating more flexible dwell time distributions into mean field ODE models. Journal of Mathematical Biology, 79(5), 1831–1883. https://doi.org/10.1007/s00285-019-01412-w

Hurtado, P., & Richards, C. (2020). A procedure for deriving new ODE models: Using the generalized linear chain trick to incorporate phase-type distributed delay and dwell time assumptions. Mathematics in Applied Sciences and Engineering, 1(4), 410–422. https://doi.org/10.5206/mase/10857


Yeah, this comes up from time to time, and it seems legit to me. Here’s a previous related post: “Costs/benefits of solving large ODE systems as systems of difference equations/in discrete time” (and I think that one links to an earlier one).

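In this case I agree with @wds15: it’s probably the scaling of the forward sensitivity problem that is messing things up, since forward sensitivities add one copy of the ODE system for every parameter passed into the RHS.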

When you implement the ODE solver in Stan itself with a fixed timestep, you don’t have a sensitivity scaling problem. Similarly, if you use the prototype adjoint sensitivity solver, the scaling should be better, but it could still end up slower if the trapezoidal rule is a more appropriate solver for this problem than the BDF or Adams solvers in CVODES.

There are some solver functions @jtimonen wrote for @spinkney’s helpful Stan functions repo (look under functions/odes). The RK4 method there might be faster than the trapezoid method.


Ok, so here is what I think is really killing the performance: the problem as written uses a piece-wise linear approximation to the infection function. For each of 63 time intervals a slope and an offset are fitted, so we have 2 × 63 parameters just for that… and these are all passed into the ODE RHS. The problem is that for every parameter we get another N ODEs, where N is the number of states, which is 8 in this case… so just for that we end up having to solve about 900 ODEs!!!
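The counting behind this argument, as a small Python sketch: forward sensitivity analysis integrates the N model ODEs plus N sensitivity equations for each parameter that enters the RHS. The sketch counts only the β slope/offset parameters, as the post does, and uses the nine states of the system written out earlier (with the post’s N = 8, the all-at-once figure becomes 1008). Either way, passing everything at once means on the order of a thousand extra equations, while one interval per solve gives the 18 extra equations per solve that the post goes on to mention.

```python
def forward_sensitivity_size(n_states, n_params):
    """Return (total, extra): the n_states model ODEs plus one copy of them
    per parameter, which forward sensitivity analysis has to integrate."""
    extra = n_states * n_params
    return n_states + extra, extra


n_states = 9        # S, E1, E2, I1, I2, R, T1, T2, D
n_intervals = 63
print(forward_sensitivity_size(n_states, 2 * n_intervals))  # (1143, 1134): all pieces at once
print(forward_sensitivity_size(n_states, 2))                # (27, 18): one interval per solve
```

Other parameters that enter the RHS (for example d_L, d_I, d_T and ω) add a few more sensitivity equations in both cases; they are left out here to match the post’s count.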

BUT… we can completely avoid this issue by integrating in steps along the 63 time intervals. By integrating interval by interval (63 separate solves), we can simply pass one offset and one slope to each ODE solve… that reduces the extra number of ODEs per solve from 900 to just 18. It’s a bit of a hassle to get the indexing right (and I messed it up in my first attempt), but hopefully I’ll find the time to come up with a working model.
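The interval-by-interval idea can be sketched in Python. This is a toy, not the SEEIITTD model or @wds15’s Stan code: a two-state S-I model with a hypothetical piecewise-linear β(t), integrated with fixed-step RK4 either in one pass, where every slope and offset is visible to the RHS, or one interval at a time, where each solve sees only that interval’s (offset, slope). The trajectories agree; what changes is that a forward-sensitivity solver would only see two β parameters per solve. In Stan this would correspond to calling the ODE solver (e.g. ode_rk45) once per interval and feeding each end state in as the next initial state. β is made continuous here only so that the two runs agree to rounding error.

```python
import bisect


def rk4_step(f, t, y, h):
    """One classical RK4 step for dy/dt = f(t, y), with y a list."""
    k1 = f(t, y)
    k2 = f(t + 0.5 * h, [a + 0.5 * h * b for a, b in zip(y, k1)])
    k3 = f(t + 0.5 * h, [a + 0.5 * h * b for a, b in zip(y, k2)])
    k4 = f(t + h, [a + h * b for a, b in zip(y, k3)])
    return [a + h / 6.0 * (p + 2.0 * q + 2.0 * r + s)
            for a, p, q, r, s in zip(y, k1, k2, k3, k4)]


def si_rhs(beta, gamma=0.2, N=1000.0):
    """Toy S-I model standing in for SEEIITTD; beta is a function of t."""
    def f(t, y):
        S, I = y
        new = beta(t) * S * I / N
        return [-new, new - gamma * I]
    return f


def integrate(f, y, t_left, t_right, n_steps):
    """Fixed-step RK4 from t_left to t_right."""
    h = (t_right - t_left) / n_steps
    t = t_left
    for _ in range(n_steps):
        y = rk4_step(f, t, y, h)
        t += h
    return y


def solve_piecewise(y0, knots, offsets, slopes, n_steps=20):
    """One solve per interval; each solve only sees (offsets[k], slopes[k])."""
    y = list(y0)
    for k in range(len(offsets)):
        a, b, tl = offsets[k], slopes[k], knots[k]
        beta_k = lambda t, a=a, b=b, tl=tl: a + b * (t - tl)
        y = integrate(si_rhs(beta_k), y, knots[k], knots[k + 1], n_steps)
    return y


def solve_global(y0, knots, offsets, slopes, n_steps=20):
    """Reference: a single RHS that sees every slope and offset at once."""
    def beta(t):
        k = min(max(bisect.bisect_right(knots, t) - 1, 0), len(offsets) - 1)
        return offsets[k] + slopes[k] * (t - knots[k])
    f = si_rhs(beta)
    y = list(y0)
    for k in range(len(offsets)):
        y = integrate(f, y, knots[k], knots[k + 1], n_steps)
    return y


# Hypothetical continuous piecewise-linear beta on weekly intervals.
knots = [0.0, 7.0, 14.0, 21.0]
values = [0.6, 0.4, 0.3, 0.35]
offsets = values[:-1]
slopes = [(values[k + 1] - values[k]) / (knots[k + 1] - knots[k]) for k in range(3)]
print(solve_piecewise([990.0, 10.0], knots, offsets, slopes))
print(solve_global([990.0, 10.0], knots, offsets, slopes))
```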

The bottom line here is: NEVER pass in parameters into the ODE RHS unless you have to and really work hard to avoid it.

Ah… and adjoint ODE will be interesting here to try, of course.


@bbbales2 I’ve taken your suggestion. Here is the custom RK4 solver:

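// Fixed-step classical RK4 (4th order). seeiittd(t, y, theta, a0, integer_data) is the
// model's ODE right-hand side; returns the states at t0, t0 + h, ..., t0 + num_steps * h.
vector[] odeint_rk4(real t0, vector y0, real h, int num_steps,
    real[] a0, vector theta, int[] integer_data){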

  int d = num_elements(y0);
  vector[d] y[num_steps+1];
  vector[d] k1; vector[d] k2; vector[d] k3; vector[d] k4;
  real t = t0;
  y[1] = y0;
  
  // Integrate at grid of time points
  for(i in 1:num_steps){
    k1 = h * seeiittd(t          , y[i]           , theta, a0, integer_data);
    k2 = h * seeiittd(t + 0.5 * h, y[i] + 0.5 * k1, theta, a0, integer_data);
    k3 = h * seeiittd(t + 0.5 * h, y[i] + 0.5 * k2, theta, a0, integer_data);
    k4 = h * seeiittd(t + h      , y[i] + k3      , theta, a0, integer_data);
    // Weighted average of the four stage increments, weights (1, 2, 2, 1) / 6
    y[i+1] = y[i] + (k1 + 2.0 * k2 + 2.0 * k3 + k4) / 6.0;
    t = t + h;
  }
  
  return(y);
}

The file is UNINOVE_Sao_Paulo/deaths_custom_rk4.stan on the main branch of the codatmo/UNINOVE_Sao_Paulo repository on GitHub.
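For anyone who wants to poke at the method outside Stan, here is a Python mirror of odeint_rk4 (a sketch: f(t, y) stands in for seeiittd, and the theta, a0 and integer_data arguments are folded into f). It checks the fourth-order claim on y' = -y: halving h should shrink the error at t = 1 by roughly 2^4 = 16.

```python
import math


def odeint_rk4(f, t0, y0, h, num_steps):
    """Fixed-step classical RK4, mirroring the Stan function: returns the
    states at t0, t0 + h, ..., t0 + num_steps * h (num_steps + 1 entries)."""
    ys = [list(y0)]
    t = t0
    for _ in range(num_steps):
        y = ys[-1]
        k1 = [h * v for v in f(t, y)]
        k2 = [h * v for v in f(t + 0.5 * h, [a + 0.5 * b for a, b in zip(y, k1)])]
        k3 = [h * v for v in f(t + 0.5 * h, [a + 0.5 * b for a, b in zip(y, k2)])]
        k4 = [h * v for v in f(t + h, [a + b for a, b in zip(y, k3)])]
        ys.append([a + (p + 2.0 * q + 2.0 * r + s) / 6.0
                   for a, p, q, r, s in zip(y, k1, k2, k3, k4)])
        t += h
    return ys


def decay(t, y):
    """Test problem y' = -y, with exact solution exp(-t)."""
    return [-y[0]]


errs = []
for n in (10, 20, 40):
    y_end = odeint_rk4(decay, 0.0, [1.0], 1.0 / n, n)[-1][0]
    errs.append(abs(y_end - math.exp(-1.0)))
print([errs[i] / errs[i + 1] for i in range(2)])  # ratios close to 16
```

Like the Stan version, this has no error control: h is fixed, so accuracy has to be checked by re-running with a smaller step and comparing.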

It is running now. I don’t know yet how much faster it will be, but judging from warmup it seems way faster!

@wds15 I understood what you mean, and it makes total sense. The piece-wise beta functions are killing the model by exploding the number of ODEs in the final system. I will probably need more time and effort to code it up.

@s.maskell I am also tagging @alphillips and @remoore so that they can benefit from the discussion. @alphillips and @remoore, check out the deaths_custom_rk4.stan file linked above. It might make Liverpool’s NHS models run faster and even more accurately, since RK4 is an explicit fourth-order method and the trapezoidal rule is an implicit second-order method. I don’t think stiffness (the trapezoidal rule is A-stable) is an issue here (but I might be wrong; 3 months ago I was a total newbie in ODEs).
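The order and stability claims can be checked on the linear test equation y' = λy, where each method multiplies y by a fixed factor per step (with z = hλ). This Python sketch assumes the bespoke solver is the standard implicit trapezoidal rule, as described above; for the nonlinear SEEIITTD RHS the implicit equation would need an iterative solve at every step, which this linear example sidesteps.

```python
import math


def trapezoid_factor(z):
    """Per-step growth factor of the implicit trapezoidal rule on y' = lam * y
    (z = h * lam); the implicit equation is solved exactly since it is linear."""
    return (1.0 + 0.5 * z) / (1.0 - 0.5 * z)


def rk4_factor(z):
    """Per-step growth factor of explicit classical RK4 on y' = lam * y."""
    return 1.0 + z + z**2 / 2.0 + z**3 / 6.0 + z**4 / 24.0


def global_error(factor, lam, n):
    """Error at t = 1 after n equal steps, starting from y(0) = 1."""
    h = 1.0 / n
    return abs(factor(h * lam) ** n - math.exp(lam))


ratios = {}
for name, fac in (("trapezoid", trapezoid_factor), ("rk4", rk4_factor)):
    ratios[name] = global_error(fac, -1.0, 20) / global_error(fac, -1.0, 40)
print(ratios)  # roughly 4 for the trapezoid (2nd order), 16 for RK4 (4th order)

# Stiff regime, e.g. lam = -1000 with h = 0.01, i.e. z = -10.
print(abs(trapezoid_factor(-10.0)))  # below 1: still decays (A-stability)
print(abs(rk4_factor(-10.0)))        # above 1: the explicit method blows up
```

A-stability only pays off if the problem is stiff at the step size used: with hλ = -10, RK4 amplifies the solution while the trapezoidal rule still damps it. If the SEEIITTD dynamics are not stiff at that step size, RK4’s higher order wins.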


@wds15: Thanks; that is a great observation. We’ll certainly address that.

Attached is a version of the model which uses the piece-wise approach. It’s still taking some time, I have to say.

… but there is more to gain here. For example, when we know that the infection rate is a linear function, can’t the ODE dS_dt = -infection_rate; be solved analytically? If you can live with a step-wise constant function, things get even simpler, and possibly more of the equations can be solved analytically. Reducing the number of ODEs is always the best thing to do.

deaths_new_rk45-v3.stan (9.4 KB)
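To make the question concrete, here is what an analytic solution would look like on one interval [t_k, t_{k+1}] under two simplifying assumptions that are illustrative, not the post’s, with a hypothetical offset a_k and slope b_k. The first pair assumes the total flow out of S is itself linear in time; the second pair assumes the per-capita rate is linear in time. Neither holds exactly in the SEEIITTD system, where the flow β(t)(I_1 + I_2)S/N depends on the other states, so this only illustrates the kind of reduction meant.

```latex
\begin{align}
\frac{dS}{dt} &= -\left(a_k + b_k (t - t_k)\right), \qquad t \in [t_k, t_{k+1}] \\
S(t) &= S(t_k) - a_k (t - t_k) - \frac{b_k}{2} (t - t_k)^2 \\
\frac{dS}{dt} &= -\left(a_k + b_k (t - t_k)\right) S \\
S(t) &= S(t_k) \exp\!\left(-a_k (t - t_k) - \frac{b_k}{2} (t - t_k)^2\right)
\end{align}
```

With b_k = 0 (a step-wise constant rate), these reduce to a straight line and a plain exponential decay, respectively.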

@wds15: I suspect you know and I don’t have an easy way to have a go myself quickly. So: how do the runtimes compare?

No, I was too impatient. Probably one should decrease the number of pieces and also look into model details like initialization and priors. At least I still had to use the BDF solver to make any progress, but you seem to expect a non-stiff solution. For this rather detailed work, one should be more familiar with the model and the data than I am. In any case, I hope I have pointed you in a helpful direction.

Just one warning: my code assumes that the initial time is equal to beta_left_t… which it was in the data I saw.

Fair enough. We can do that analysis.

The number of pieces needs to be high, unfortunately. Vaguely similarly, the initialisation and prior have been a focus for us over recent months. So, I’m not sure we can make much headway there.

I do think your pointers have been really helpful, though: we want to extend this model to consider much larger numbers of coupled ODEs (about 100x as many) and need to understand how best to configure the ODE aspects (while we build a modified version of Stan that might be able to do the inference on sub-glacial timescales). Our bespoke trapezoidal solver was used here with that future application in mind. It’ll be really interesting to see how each of the candidate solvers works in the current and future context.

Thank you.

I’ve run the model using Brazil’s data (62 weeks):

  • Bespoke Trapezoidal Solver: 1h33
  • @jtimonen RK4 Solver: 2h56
  • New adjoint solver by @wds15: still running, but I estimate it might take around 14h

The old solution, using plain RK45 with about 900 ODEs, would take an estimated 12 days. So I guess going from 12 days to 14 hours is a nice advance. I expect that analytical solutions that reduce the number of ODEs would cut sampling time even more.
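For scale, the arithmetic behind these comparisons (the 12 days and the 14 hours are both estimates rather than finished runs):

```latex
\begin{align}
12\ \text{days} &= 288\ \text{h}, \qquad \frac{288\ \text{h}}{14\ \text{h}} \approx 20.6 \\
\text{RK4 vs.\ trapezoid:}\quad \frac{2\text{h}56}{1\text{h}33} &= \frac{176\ \text{min}}{93\ \text{min}} \approx 1.89
\end{align}
```

So the adjoint run is roughly a twentyfold speedup over the old RK45 estimate, while the custom RK4 currently takes about 1.9 times as long as the bespoke trapezoidal solver on this data.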

How do the different solvers compare in estimates / lp__ / diagnostics?

The adjoint run with RK45 is still running. But the trapezoidal and RK4 CmdStan CSV files are already in the codatmo/UNINOVE_Sao_Paulo repo under SEIR-model/results/: deaths/ holds the trapezoidal runs and deaths_rk4/ the RK4 runs. Feel free to run analyses on those CSVs.
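For a first comparison of lp__ and divergences between the deaths/ and deaths_rk4/ runs, a minimal standard-library Python sketch is enough (CmdStan’s bin/stansummary, or a dedicated analysis package, gives fuller diagnostics such as R-hat and ESS). It assumes the usual CmdStan CSV layout: comment lines starting with '#', then a header row, then one row per draw. The sample data and the file path in the comment are made up.

```python
import csv
import io
import statistics


def read_stan_csv(text):
    """Parse CmdStan output CSV text into a list of {column: float} draws.

    Lines starting with '#' (config, adaptation and timing info) are skipped;
    the first remaining line is the header, the rest are draws.
    """
    lines = [ln for ln in text.splitlines() if ln and not ln.startswith("#")]
    reader = csv.DictReader(io.StringIO("\n".join(lines)))
    return [{name: float(value) for name, value in row.items()} for row in reader]


def summarize(draws):
    """Mean and sd of lp__, plus the number of divergent transitions."""
    lp = [d["lp__"] for d in draws]
    n_div = int(sum(d["divergent__"] for d in draws))
    return statistics.mean(lp), statistics.stdev(lp), n_div


# Tiny made-up CSV in CmdStan's layout; in practice, read a file, e.g.
# text = open("SEIR-model/results/deaths/<chain>.csv").read()  (hypothetical name)
sample = """# model = deaths
lp__,accept_stat__,divergent__,beta
-10.0,0.9,0,0.5
-12.0,0.8,0,0.6
-11.0,0.85,1,0.55
# Elapsed Time: ...
"""
print(summarize(read_stan_csv(sample)))  # (-11.0, 1.0, 1)
```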


Here is another variant where I am fully exploiting the new adjoint ODE solver. The cool thing is that the many parameters being passed into the adjoint ODE solver really do not hurt as much as it looks. The sampler does make progress on my screen, and this is being compiled with the latest adjoint ODE CmdStan build:

https://github.com/stan-dev/cmdstan/suites/2733052299/artifacts/60343510

deaths_new_rk45-v2.stan (9.8 KB)


I am running deaths_new_rk45-v3.stan with cmdstan-ode-adjoint-v2 on a 2018 Mac Mini (i5-8500B) and deaths_new_rk45-v2.stan with cmdstan-ode-adjoint-v3 on a 2020 Dell G5 (i7-9750H).

Will report back on times for both runs.

You need to put this line into make/local:

CXXFLAGS+=-DSTAN_NO_RANGE_CHECKS

for good performance.

Like this?

echo "CXXFLAGS+=-DSTAN_NO_RANGE_CHECKS" > make/local
make build -j4

Looks good


Ok, so the Dell G5 is running now with the flag

CXXFLAGS+=-DSTAN_NO_RANGE_CHECKS

From Stan 2.27 onwards, this new flag is needed to turn off index checking, which costs about 20% in performance, but you don’t want this flag while developing a model.

BTW, what did you mean above by the “adjoint RK45” approach of mine?

Sorry, I went over the files, and there was one file (or a proposal) where you suggested numerically solving for the betas inside the seeiittd ODE system with some other solver and then feeding that into an RK45 or adjoint solver. I will correct the statements above… I made a whole mess of confusing names…