Speed up sampling of forced ODE model

Hi,

I am trying to fit an electrochemistry ODE model using Stan. The gradient of the ODE depends on an exogenous variable that changes over time (the voltage ‘V’), so I followed the advice from the thread “Forced ODEs, a start for a case study?” when writing the model.

I have attempted to run the model, but it is really slow; I estimate it will take about a week to run 200 iterations of NUTS. When I tried a simpler version of the model, it took around two days and still wasn’t close to convergence, so I am a little loath to try a week.

Does anyone have any ideas as to how to speed up the sampling? I’m guessing a lot of time is being spent finding the required value of the voltage at each time step using the find_interval_elem function, but I can’t really see another way around this.

Best,

Ben

functions{
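  // Binary search: for start_ind = 1, returns the index i such that
  // sorted[i] < x <= sorted[i+1], with 0 returned if x <= sorted[1]
  // and N if x > sorted[N]. This is called at every right-hand-side
  // evaluation of the ODE, which is the suspected bottleneck.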
  int find_interval_elem(real x, vector sorted, int start_ind){
    int res;
    int N;
    int max_iter;
    real left;
    real right;
    int left_ind;
    int right_ind;
    int iter;

    N = num_elements(sorted);

    if(N == 0) return(0);

    left_ind  = start_ind;
    right_ind = N;

    max_iter = 100 * N;
    left  = sorted[left_ind ] - x;
    right = sorted[right_ind] - x;

    if(0 <= left)  return(left_ind-1);
    if(0 == right) return(N-1);
    if(0 >  right) return(N);

    iter = 1;
    while((right_ind - left_ind) > 1  && iter != max_iter) {
      int mid_ind;
      real mid;
      // integer division is intended here; is there a controlled way
      // to do this without being yelled at with a warning?
      mid_ind = (left_ind + right_ind) / 2;
      mid = sorted[mid_ind] - x;
      if (mid == 0) return(mid_ind-1);
      if (left  * mid < 0) { right = mid; right_ind = mid_ind; }
      if (right * mid < 0) { left  = mid; left_ind  = mid_ind; }
      iter = iter + 1;
    }
    if(iter == max_iter)
      print("Maximum number of iterations reached.");
    return(left_ind);
  }
  
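  // ODE right-hand side for the gating variable: the forcing voltage V(t)
  // enters through x_r as a step function, looked up via find_interval_elem.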
  real[] deriv_aslanidi(real t, real[] I, real[] theta, real[] x_r, int[] x_i){
    
    int aLen = x_i[1];
    vector[aLen] ts = to_vector(x_r[1:aLen]);
    vector[aLen] V = to_vector(x_r[(aLen+1):(2*aLen)]);
    int aT = find_interval_elem(t, ts, 1);
    real aV = (aT==0) ? V[1] : V[aT];
    
    real xtau = theta[1] / (1 + exp(aV / theta[2])) + theta[3];
    real xinf = 1 / (1 + exp(-(aV + theta[4]) / theta[5]));
    
    real dydt[1];
    dydt[1] = (xinf - I[1]) / xtau;
    return dydt;
  }
  
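  // Integrate the gating variable X_Kr over ts and map the solution to the
  // measured current I, using the voltage at each output time.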
  vector solve_aslanidi_forced_ode(real[] ts, real X0, real[] theta, real[] V, real t0){
    int x_i[1];
    real X_Kr[size(V),1];
    vector[size(V)] I;
    x_i[1] = size(V);
    
    X_Kr = integrate_ode_rk45(deriv_aslanidi, rep_array(X0, 1), t0, ts, theta,
                              to_array_1d(append_row(to_vector(ts), to_vector(V))),
                              x_i);

    for(i in 1:x_i[1]){
      real t = ts[i];
      int aT = find_interval_elem(t, to_vector(ts), 1);
      real aV = (aT==0) ? V[1] : V[aT];
      real rInf = 1 / (1 + exp((aV + theta[6]) / theta[7]));
      I[i] = theta[8] * X_Kr[i,1] * rInf * (aV + 85);
    }
    

    return(I);
  }
  
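  // Normal log likelihood of the observed current around the ODE solution.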
  real calculateLogLikelihood(real[] I, real[] ts, real X0, real[] theta,
                              real[] V, real t0, real sigma, int N){

    vector[N] I_int;
    vector[N] lLogProb;

    // solve ODE (integrate_ode_rk45 is the non-stiff solver)
    I_int = solve_aslanidi_forced_ode(ts, X0, theta, V, t0);

    // likelihood
    for(i in 1:N){
      lLogProb[i] = normal_lpdf(I[i] | I_int[i], sigma);
    }
    return(sum(lLogProb));
  }
}

data{
  int N;
  real V[N];
  real I[N];
  real ts[N];
  real t0;
}

transformed data {
  int x_i[0];
}

parameters{
  real<lower=0> p1;     // ms
  real<lower=0> p2;     // mV
  real<lower=0> p3;     // ms
  real<lower=0> p4;     // mV
  real<lower=0> p5;     // mV
  real p6;              // mV
  real<lower=0> p7;     // mV
  real<lower=0> p8;
  real<lower=0,upper=1> X0;
  real<lower=0> sigma;
}

transformed parameters{
  real theta[8];
  theta[1] = p1;
  theta[2] = p2;
  theta[3] = p3;
  theta[4] = p4;
  theta[5] = p5;
  theta[6] = p6;
  theta[7] = p7;
  theta[8] = p8;
}

model{
  
  target += calculateLogLikelihood(I, ts, X0, theta, V, t0, sigma, N);
  
  //priors
  p1 ~ normal(900,500);
  p2 ~ normal(5,1);
  p3 ~ normal(100,10);
  p4 ~ normal(0.1,0.02);
  p5 ~ normal(12.25,3);
  p6 ~ normal(-5.6,1);
  p7 ~ normal(20.4,3);
  p8 ~ normal(0.01,0.001);
  sigma ~ normal(1,0.1);
}

Hi!

Have you tried the stiff solver already?

Currently you are using the defaults for the tolerances, so have a look at my post here:

The default tolerances are quite conservative, so it is important to adapt the absolute tolerance to a meaningful value (the relative tolerance may also be increased, but, from experience, to a lesser extent).
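For example, against the rk45 call in your model the control arguments go at the end, something like this (the tolerance values below are just placeholders for you to tune, not a recommendation):

X_Kr = integrate_ode_rk45(deriv_aslanidi, rep_array(X0, 1), t0, ts, theta,
                          to_array_1d(append_row(to_vector(ts), to_vector(V))),
                          x_i,
                          1e-6,    // relative tolerance (the default)
                          1e-4,    // absolute tolerance, loosened from 1e-6
                          10000);  // maximum number of steps

The same three trailing arguments work for integrate_ode_bdf.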

Also, if I see that correctly, you can just vectorize your for loop for the likelihood. If you write normal_lpdf(I | I_int, sigma) you get exactly the sum that you are forming afterwards.
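The end of calculateLogLikelihood would then shrink to something like:

    // vectorized likelihood: one call replaces the loop and the sum
    I_int = solve_aslanidi_forced_ode(ts, X0, theta, V, t0);
    return normal_lpdf(I | I_int, sigma);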

I hope this helps.

Sebastian

Hi Sebastian,

Thanks for your message. I have tried the stiff solver and it doesn’t seem to make much difference (unfortunately I am having issues benchmarking it, because expose_stan_functions doesn’t seem to work in this case).

Loosening the error tolerances does make a difference to runtime – I see a factor-of-three decrease when using 10^-3 compared to the default (10^-6 for both the absolute and relative tolerances). However, with looser tolerances I do see a visible difference in the solution; there are the wavy lines that are often indicative of numerical instability. Perhaps around 10^-5 is best for my case (which roughly halves the runtime).

The vectorising of the likelihood I had missed, thank you. However, I don’t think this is really the bottleneck here.

If there are any other ideas here, let me know!

Best,

Ben

Sorry – I have just tried the bdf solver and it makes a hell of a difference! About a factor of 1000 difference in the estimated time to do 1000 transitions!

However, I am still quite keen to check that the error tolerances I am using produce reasonable solutions. Is there a way to expose this function (other than via expose_stan_functions) that I can use?

Best,

Ben


The stiff solver cannot be exposed to R at the current state of affairs for technical reasons. You can only benchmark it from a running Stan program.

From my experience you should keep the relative tolerance quite low (1E-7 or so), and you can go up with the absolute one to 1E-4; but, as you rightly note, the solution ought not to change.

The vectorization won’t solve your problem, true, but it’s not a bad idea either.

Sebastian

Hi!

Yup, if your problem is stiff then you should use a stiff solver. I have seen such magic speedups myself. It’s awesome when it comes out like this.

Just do runs with varying tolerances and compare the results. Also watch the stepsize of the sampler, which you can get with get_sampler_params.

Sebastian

Yes, it’s great – I have sometimes found a reasonable speed-up with stiff ODE solvers, but never quite this big!

Is an equivalent of the bdf integrator available in any Python or Matlab packages that you know of? As I say, it would be good to do a few test runs to see whether I get a visibly different result.

Thanks again!

Best,

Ben

You can work around the problem in a simple way: just use the “Fixed_param” sampler. Then your initial values become the parameters being used. Do that with warmup=0 and sampling=1 as iteration numbers and you get essentially the same output (and do place the desired output in the generated quantities block). It’s just a bit slower and more tedious to work with.
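A minimal sketch of the generated quantities part, reusing the functions from your model (just an assumption of how you would wire it up):

generated quantities {
  // with the Fixed_param sampler the parameters stay at their initial
  // values, so warmup=0 and a single iteration evaluates the ODE
  // solution exactly once at those values
  vector[N] I_sim = solve_aslanidi_forced_ode(ts, X0, theta, V, t0);
}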

I always like to hear these stiff ODE success stories, best motivation to improve that thing even further.

The solver behind that is the CVODES solver (from the SUNDIALS suite), but I don’t think you need to look for other packages.

Sebastian