Slow hierarchical ODE model

Hello,

I’ve been trying to implement a hierarchy for a compartmental ODE model in Stan. Previously, my model ran fine when I had only 2 levels (global + 1 level of grouping). I have been trying to add another level, but the run time has become very long (e.g. 500 iterations took 150 hours). I understand that possible fixes include reparameterisation, tightening the priors and transforming the parameters, but honestly I’m having trouble adapting these to my application. I’m not sure how to make things faster and I’m hoping for some guidance. My Stan code is below, thank you!

functions {
  real[] OneSys(real t, real[] y, real[] params, real[] x_r, int[] x_i) {
    // States: y[1] = T, y[2] = I, y[3] = V, y[4] = immune response
    real dydt[4];

    dydt[1] = -params[1]*y[1]*y[3];
    dydt[2] = params[1]*y[1]*y[3] - params[2]*y[2];
    dydt[3] = 100*y[2] - 20*y[3] - params[3]*y[4]*y[3];
    dydt[4] = params[4]*y[2]*y[4];

    return dydt;
  }
}

data {
  int<lower = 0> L; //no. of rows in the dataset
  int<lower = 0> G; //no. of vacgr
  int<lower = 0> AG; //total no. of agegr across all vacgr
  int N_agpvg[G]; //vector of no. of agegr in each vacgr
  int l_tsolve; //length of vector of times to solve
  real logvirus[L]; //swab data
  real t0; //initial value for t
  int ts[L]; //all days (arranged by individual)
  real t_solve[l_tsolve]; //vector of times to solve
  int pos_vgofag[AG]; //vector of which vg each ag is in, indexing for tinc
  int count_agegr[AG]; //vector of no. of swabs in each agegr
  int pos_agegr[AG]; //vector of indices of first swab for each agegr
}

transformed data {
  real x_r[0];
  int x_i[0];
}

parameters {
  real<lower = 0, upper= 10^-2> beta_par;
  real<lower = 0, upper = 100> delta_par;
  real<lower = 0, upper = 100> gamma_par; 
  real<lower = 0, upper = 1> omega_par[AG];
  real<lower = 0, upper = 10^-2> omega_mu;
  real<lower = 0, upper = 0.1> omega_sigma;
  real<lower = 0, upper = 100> sigma;
  real<lower = 0, upper = 15> inc_per;
  real<lower = 0, upper = 10^8> imm_resp[G];
}

transformed parameters {
  real predVal[L];
  real params[4];
  real y0[4];

  {
    real t_inc[l_tsolve];

    // incubation period is the same across all vaccine groups
    for (l in 1:l_tsolve) {
       t_inc[l] = t_solve[l] + inc_per;
    }

    y0[1] = 10^10; //T0
    y0[2] = 0; //I0
    y0[3] = 10; //V0
  
    params[1] = beta_par;
    params[2] = delta_par;
    params[3] = gamma_par;

    print("beta: ", params[1]);
    print("delta: ", params[2]);

    // iterate over total age groups
    for (ag in 1:AG) {
      real yout[l_tsolve, 4];
      
      y0[4] = imm_resp[pos_vgofag[ag]];
      params[4] = omega_par[ag];
    
      yout = integrate_ode_bdf(OneSys, y0, t0, t_inc, params, x_r, x_i);
      predVal[pos_agegr[ag] : (pos_agegr[ag] + count_agegr[ag] - 1)] = yout[ts[pos_agegr[ag] : (pos_agegr[ag] + count_agegr[ag] - 1)], 3];
     
    }
  }
  
}


model {
  beta_par ~ beta(0.001,5);
  omega_sigma ~ cauchy(0,2);
  omega_par ~ normal(omega_mu, omega_sigma);
  sigma ~ cauchy(0,4);
  inc_per ~ normal(4,2);
  
  for (i in 1:L) {
    logvirus[i] ~ normal(log(predVal[i]),sigma);
  }
}
  1. What is the data dimensionality?
  2. Why are you using a stiff-ODE solver (bdf)?

Depending on the size of the data, the adjoint method (see section 11.2, "Ordinary differential equation (ODE) solvers", in the Stan Functions Reference) might give you a speed-up.
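
As a rough illustration of what switching to the adjoint solver could look like, here is a minimal sketch of a single solve using `ode_adjoint_tol_ctl`. It assumes a Stan release recent enough to provide the vector-based ODE interface and the `array[]` syntax. The function name `one_sys`, the single `omega_par` parameter, passing `y0` as data, and all tolerance, step and checkpoint values are placeholders rather than tuned recommendations.

```stan
functions {
  // Same right-hand side as OneSys, rewritten for the vector-based ODE interface.
  // Parameters are passed as trailing arguments instead of a params array.
  vector one_sys(real t, vector y,
                 real beta_par, real delta_par, real gamma_par, real omega_par) {
    vector[4] dydt;
    dydt[1] = -beta_par * y[1] * y[3];
    dydt[2] = beta_par * y[1] * y[3] - delta_par * y[2];
    dydt[3] = 100 * y[2] - 20 * y[3] - gamma_par * y[4] * y[3];
    dydt[4] = omega_par * y[2] * y[4];
    return dydt;
  }
}

data {
  int<lower = 1> l_tsolve;          // number of output times
  real t0;                          // initial time, must be below every t_solve
  array[l_tsolve] real t_solve;     // output times
  vector[4] y0;                     // initial state (assumed fixed here)
}

parameters {
  real<lower = 0, upper = 0.01> beta_par;
  real<lower = 0, upper = 100> delta_par;
  real<lower = 0, upper = 100> gamma_par;
  real<lower = 0, upper = 1> omega_par;
}

transformed parameters {
  // All control values below are placeholders, not tuned settings.
  array[l_tsolve] vector[4] yout
    = ode_adjoint_tol_ctl(one_sys, y0, t0, t_solve,
                          1e-6, rep_vector(1e-6, 4),  // forward rel/abs tolerance
                          1e-6, rep_vector(1e-6, 4),  // backward rel/abs tolerance
                          1e-6, 1e-6,                 // quadrature rel/abs tolerance
                          10000,                      // max_num_steps
                          150,                        // num_checkpoints
                          1,                          // interpolation: 1 = Hermite
                          2, 2,                       // forward/backward solver: 2 = BDF
                          beta_par, delta_par, gamma_par, omega_par);
}
```

Whether this is actually faster than `integrate_ode_bdf` depends on the number of parameters relative to the number of states, so it is worth timing on a small subset of the data first.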

  1. G = 15 while AG = 60 or 64, depending on the dataset that I am using. I tried running the fit with a subset of the data (such that G = 2 and AG = 4) and run time was okay for the number of iterations I used. So I think the long run time issue I was facing was because of the large number of groups when using the whole dataset.
  2. Previously I was using a non-stiff solver, which gave slower run times. I switched to a stiff solver after reading some forum posts; this improved speed a bit, and the parameter estimates showed no issues when I ran it on simulated data.

I’ve not used an adjoint solver before, how would switching to this possibly speed up the fit?
Thanks!

Just to update, I looked up the adjoint solver and unfortunately will not be able to use it, as the HPC I am using does not have a sufficiently up-to-date version of rstan.

Instead, I’ve recently tried a non-centred parameterisation for my omega_par parameter, and this did speed up my model quite a lot. However, the traceplots show that the chains aren’t mixing for any of the parameters except the omega_par values.


Would you have any advice on this?
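
For reference, the non-centred parameterisation mentioned above can be sketched roughly as follows. This is an assumption about the general technique, not the exact code that was run: `omega_raw` is a hypothetical name, only the omega-related parts of the model are shown, and the hard `<lower = 0, upper = 1>` bounds on `omega_par` are dropped because the shifted and scaled value is no longer constrained by a declaration. The older array syntax is used to match the rstan version discussed in the thread.

```stan
data {
  int<lower = 0> AG;                       // total no. of agegr across all vacgr
}

parameters {
  real<lower = 0, upper = 0.01> omega_mu;  // population-level mean
  real<lower = 0, upper = 0.1> omega_sigma; // population-level spread
  real omega_raw[AG];                      // standardised group-level offsets
}

transformed parameters {
  real omega_par[AG];

  // Shift and scale the standardised offsets to recover omega_par.
  for (ag in 1:AG) {
    omega_par[ag] = omega_mu + omega_sigma * omega_raw[ag];
  }
}

model {
  omega_sigma ~ cauchy(0, 2);
  // Implies omega_par ~ normal(omega_mu, omega_sigma) without sampling omega_par directly.
  omega_raw ~ std_normal();
}
```

Because `omega_par` is now unbounded, negative values can reach the ODE right-hand side unless the priors keep them away from zero, which may be worth checking when diagnosing the remaining parameters.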