# Inference of ODE parameters - Performance

Hi all,

Can I please get some advice on how I might scale the attached `growth_ode_est.stan` model?

Following the excellent documentation, I’ve plugged in a system of ODEs whose parameters I would like to infer. I can correctly recover simulated values in a short time when there are few equations (~2 min for 2 equations), but I quickly get bogged down as the number of observations or equations grows (e.g. ~2–4 hrs for 20 equations).

If I simulate the same system in `growth_ode_sim.stan` with fixed parameters, the integrator will complete 1,000 iterations of 20 equations in ~6 sec. To me this suggests I am missing some understanding of the complexity involved in estimation, and I would appreciate any suggestions that might make it more efficient.

Simulation and models can be called from the helper script `simple_growth_ode.R`.

Andrew

growth_ode_est.stan (2.1 KB)
growth_ode_sim.stan (1.9 KB)
simple_growth_function.R (1.9 KB)
simple_growth_ode.R (2.2 KB)

Credit where it's due:
 http://onlinelibrary.wiley.com/doi/10.1111/1365-2745.12557/abstract

There are a few things that could be happening here.

1. For each posterior sample, Stan (by default) will use anywhere between 2 and 1024 evaluations of the model (and so 2 to 1024 evaluations of your ODE). This is a function of your model and the no-U-turn criterion. To figure out how many ODE evaluations are happening per sample, check the treedepth diagnostics (2^treedepth is roughly how many evaluations you used).

2. In the course of doing an ODE solve, Stan needs to compute the ODE sensitivities (the derivatives of the output with respect to parameters/initial conditions). The ODE sensitivity problem scales as number_of_ODE_states * (number_of_parameters + 1). You might try switching from the rk45 ODE solver to the bdf one. They’re implemented quite differently, and in the (not so unlikely) case that you have a stiff ODE, the bdf one should be faster.
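If your model currently calls `integrate_ode_rk45`, the switch is a one-line change. A sketch (the system function name `growth` and the variable names are placeholders for whatever `growth_ode_est.stan` actually uses):

```stan
transformed parameters {
  real y_hat[T, N];
  // swap integrate_ode_rk45 for the stiff (backward-differentiation) solver
  y_hat = integrate_ode_bdf(growth, y0, t0, ts, theta, x_r, x_i);
}
```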

3. The ODE solvers use a variable step size. Seems like you've discovered this. If you don't put a cap on the number of steps in the solver, or put a lot of effort into keeping Stan out of regions that require very fine time steps, ODE problems can really bog down. Getting good priors on the ODE parameters is really important to avoid wasting tons of time.
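Both solvers take optional control arguments, so you can set the tolerances and cap the step count explicitly. Another sketch (the tolerance values here are just illustrative; you would tune them for your problem):

```stan
// relative tolerance, absolute tolerance, and max number of steps
// are appended after the usual solver arguments
y_hat = integrate_ode_bdf(growth, y0, t0, ts, theta, x_r, x_i,
                          1e-6, 1e-6, 10000);
```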

4. How do the interactions in the ODE scale? It seems like they could be N^2 (every state interacting with every other state), at which point you might expect 10x the states to take 100x the time to work with.
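For reference, an all-pairs interaction term typically shows up in the system function as a double loop, which is where the N^2 cost comes from. A hypothetical sketch, purely to illustrate the scaling (your state names and how the interaction matrix is packed into `theta` will differ):

```stan
functions {
  real[] growth(real t, real[] y, real[] theta, real[] x_r, int[] x_i) {
    int N;
    real dydt[num_elements(y)];
    N = num_elements(y);
    for (n in 1:N) {
      real interaction;
      interaction = 0;
      // every state affects every other state: N^2 terms in total
      for (m in 1:N)
        interaction = interaction + theta[(n - 1) * N + m] * y[m];
      dydt[n] = y[n] * (theta[N * N + n] + interaction);
    }
    return dydt;
  }
}
```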

Hopefully some of that is useful haha.


> try switching from the rk45 ODE solver to the bdf one

+1. Using the bdf solver reduced gradient evaluation time by a factor of 10.

The N^2 scaling and the number of evaluations per sample also make a lot of sense; I've adjusted my expectations accordingly.

Thanks!

A question: are there any fundamental issues with limiting treedepth?

From my understanding, shorter trees mean NUTS won't be able to explore quite as efficiently, but it should still generate valid samples. If Stan is choking on the number of ODE solves, can I limit the number of transitions per sample and generate more samples in total?

> fundamental issues with limiting treedepth?

Yeah, unfortunately :D. Treedepth comes from the No-U-Turn sampler. It determines how far the Hamiltonian dynamics are integrated before (very roughly) the sampler finds itself turning around and coming back to where it started.

If you don’t let the sampler go all the way to where it finds its U-turn, you can compromise the quality of the samples. You'll be able to generate more samples, but there will be more correlation from sample to sample, which is bad.

If you want to interpret the error in the estimates you're making, don't look at the raw number of samples but at the n_eff (effective sample size) output; n_eff is really how we're supposed to decide how many samples to generate.

If you're having treedepth issues (5-6 is the normal range), the thing to do is look for reparameterizations. Check the pair plots of all your parameters. If you have parameters that are tightly correlated and live on a little ridge (instead of a nice round blob), try to think up a way to reparameterize them so they do live in a blob. It's ridges that are hard to explore (and require large treedepths).
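The classic example of this kind of fix is the non-centered parameterization for hierarchical parameters. A sketch, assuming (purely hypothetically) that your model has per-equation growth rates `r` drawn from a shared distribution:

```stan
parameters {
  real mu;
  real<lower=0> sigma;
  real r_raw[N];                      // standardized: lives in a round blob
}
transformed parameters {
  real r[N];
  for (n in 1:N)
    r[n] = mu + sigma * r_raw[n];     // shift/scale recovers the actual rates
}
model {
  r_raw ~ normal(0, 1);               // instead of r ~ normal(mu, sigma)
}
```

This moves the mu/sigma/r correlation out of the geometry the sampler has to explore.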
