Speeding up an ODE model

Could someone please help me speed up this code? It is extremely slow, even with a single chain.

functions{
  vector PK3cmt(real t, vector states, vector theta){
    real ka = theta[1];
    real CL = theta[2];
    real Vc = theta[3]; 
    real Q = theta[4]; 
    real Vp = theta[5]; 
    real gut = states[1];
    real cent = states[2]; 
    real peri = states[3];
    vector[3] dif; 
    dif[1] = - ka * gut;
    dif[2] = ka*gut - ((CL / Vc) + (Q / Vc))*cent + (Q * peri) / Vp; 
    dif[3] = (Q*cent)/ Vc - (Q*peri)/Vp; 
    return dif; 
  }
}
data{
  int N;  // N = 1320 total number of events 
  int nObs; // nObs = 1020 the number of measurements
  int ID; // ID = 20 the number of patients
  array[N] int id;
  array[nObs] int iObs; 
  vector<lower = 0>[N] times; 
  array[N] int evid; 
  array[N] int cmt; 
  vector<lower = 0>[N] amt; 
  vector<lower=0>[nObs] cObs; 
  vector<lower=0>[ID] weight; 
}
transformed data {
  int nTheta = 5;
  int nCmt = 3;
}
parameters{
  vector<lower=0>[nTheta] theta_bar; //{ka, CL, Vc, Q, Vp},  population parameters
  vector<lower=0>[nTheta] sigma_bar;    
  real<lower = 0> sigma;
  matrix[ID, 5] myz;                // standard normal noise: between-patient variation in each parameter
}
model{
  matrix[ID, nTheta] theta; 
  for(j in 1:5) theta[, j] = theta_bar[j] + sigma_bar[j] * myz[, j];    // individual parameters (additive noise, so these can become negative)
  
  theta[, 2] = theta[, 2] .* exp(0.75 * log(weight/70));    // adjusting CL for the weight of the patient 
  theta[, 3] = theta[, 3] .* weight/70;                     // adjusting Vc for the weight
  theta[, 4] = theta[, 4] .* exp(0.75 * log(weight/70));    // adjusting Q for the weight 
  theta[, 5] = theta[, 5] .* weight/70;                     // adjusting Vp for the weight
  
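  vector[N] traj;                              // only entries at observation records are used
  int i = 1;
  // assumes records are sorted by id, then by time, and every patient has at least one record
  for(k in 1:ID){
    vector[nCmt] U = rep_vector(0, nCmt);      // gut, central, peripheral amounts
    real t_last = times[i];                    // time that the state U refers to
    while (i <= N && id[i] == k){              // i <= N stops reading past the last record
      if (times[i] > t_last) {                 // advance U to this event's time first
        U = ode_rk45(PK3cmt, U, t_last, {times[i]}, to_vector(theta[k]))[1];
        t_last = times[i];
      }
      if (evid[i] == 1) U[cmt[i]] += amt[i];   // dosing record: add the dose
      else traj[i] = U[2] / theta[k, 3];       // observation record: central concentration
      i += 1;
    }
  }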
  vector[nObs] concentrationHat = traj[iObs];
  // priors:
  theta_bar[1] ~ lognormal(log(2.5), 1);
  theta_bar[2] ~ lognormal(log(10), 0.25); 
  theta_bar[3] ~ lognormal(log(35), 0.25);
  theta_bar[4] ~ lognormal(log(15), 0.5);
  theta_bar[5] ~ lognormal(log(105), 0.5);
  sigma ~ normal(0, 0.5);
  to_vector(myz) ~ normal(0, 1);
  sigma_bar ~ normal(0, 0.5); 
  // likelihood:
  cObs ~ lognormal(log(concentrationHat), sigma);
}

library(rstan)

H3cmtpop.m = stan_model(file = "H3cmtpop.stan")

init <- function() {
  list(theta_bar = c(exp(rnorm(1, log(2), 0.2)),
                     exp(rnorm(1, log(10), 0.2)),
                     exp(rnorm(1, log(35), 0.2)),
                     exp(rnorm(1, log(15), 0.2)),
                     exp(rnorm(1, log(105), 0.2))),
       sigma_bar = abs(rnorm(5, 0, 0.5)),
       sigma = abs(rnorm(1, 0, 0.5)),
       myz = matrix(rnorm(100), 20, 5))
}

test.f = sampling(H3cmtpop.m, data = H3cmtpop.dat, iter = 2000, cores = 1, chains = 1,
                  control = list(adapt_delta = 0.85), init = init)

Is a 3-compartment model really needed? Just making sure that simpler models have been tried…
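As a point of comparison, the simplest candidate (one compartment with first-order absorption) has a closed-form concentration after a single oral dose. A minimal sketch; the function name is hypothetical, and bioavailability of 1 and ka != CL/V are assumed:

```stan
functions {
  // Hypothetical helper: central concentration at time tau after a single
  // oral dose, 1-compartment model with first-order absorption.
  // Assumes bioavailability F = 1 and ka != CL / V (otherwise 0/0).
  real conc_1cmt_oral(real tau, real dose, real ka, real CL, real V) {
    real k = CL / V;  // first-order elimination rate constant
    return dose * ka / (V * (ka - k)) * (exp(-k * tau) - exp(-ka * tau));
  }
}
```

Because the system is linear, repeated doses can be handled by summing this over the doses given before each observation.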

To start, can you fix sigma_bar to known values, i.e. pass it in as data rather than estimating it?
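Roughly, that means moving sigma_bar from the parameters block into the data block. This is a fragment showing only the declarations that would change, not a complete program; where the fixed values come from (e.g. earlier fits) is up to you:

```stan
data {
  // ... keep the existing declarations (N, nObs, ID, id, times, ...) ...
  vector<lower=0>[5] sigma_bar;  // between-patient SDs, now fixed and supplied as data
}
parameters {
  vector<lower=0>[5] theta_bar;  // {ka, CL, Vc, Q, Vp}, population parameters
  real<lower=0> sigma;
  matrix[ID, 5] myz;
}
// in the model block, drop the line: sigma_bar ~ normal(0, 0.5);
```

The R side would then add sigma_bar to H3cmtpop.dat and remove it from the init function.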

What should also speed things up is to integrate over all observation times of a given patient in a single ODE solver call, and to move the dosing into the ODE right-hand-side function. Ideally you represent the dosing as a continuous input function (a sum of normal densities weighted by dose, each with a small standard deviation, is something I have used in the past). As your absorption is first order, the gut equation can be solved analytically, and the input to the central compartment is then ka times that solution.
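A minimal sketch of that idea, assuming oral doses only and observation times sorted and strictly after the initial time t0 (ode_rk45 requires output times greater than t0). Instead of the smoothed normal-density input, it plugs in the exact first-order gut solution summed over doses, so only the central and peripheral compartments are integrated and each patient needs one solver call. The function names and the dose_times/dose_amts arguments are hypothetical:

```stan
functions {
  // Hypothetical RHS: the gut compartment is solved analytically and summed
  // over doses, so only central (y[1]) and peripheral (y[2]) are integrated.
  // theta = {ka, CL, Vc, Q, Vp}; dose_times/dose_amts: one patient's oral doses.
  vector pk2cmt_oral_rhs(real t, vector y, vector theta,
                         array[] real dose_times, vector dose_amts) {
    real ka = theta[1];
    real CL = theta[2];
    real Vc = theta[3];
    real Q = theta[4];
    real Vp = theta[5];
    real gut = 0;
    for (d in 1:num_elements(dose_times)) {
      if (dose_times[d] <= t) {
        gut += dose_amts[d] * exp(-ka * (t - dose_times[d]));
      }
    }
    vector[2] dydt;
    dydt[1] = ka * gut - (CL + Q) / Vc * y[1] + Q / Vp * y[2];
    dydt[2] = Q / Vc * y[1] - Q / Vp * y[2];
    return dydt;
  }

  // One solver call per patient: central concentrations at all observation times.
  // Assumes obs_times are sorted and all strictly greater than t0.
  vector patient_conc(vector theta, real t0, array[] real obs_times,
                      array[] real dose_times, vector dose_amts) {
    int n = num_elements(obs_times);
    array[n] vector[2] y = ode_rk45(pk2cmt_oral_rhs, rep_vector(0, 2), t0,
                                    obs_times, theta, dose_times, dose_amts);
    vector[n] conc;
    for (j in 1:n) {
      conc[j] = y[j, 1] / theta[3];
    }
    return conc;
  }
}
```

The input ka * gut jumps at each dose time, which forces the adaptive solver to take small steps there; the smoothed dosing input described above is one way to avoid that kink.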

BTW… technically I would call this a 2-compartment model with first-order absorption rather than a 3-compartment model.

Glossing over the ODE, this looks analytically solvable. So unless you want to move to non-linear elimination, solve it! Multiple doses can then be handled by linear superposition.
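A sketch of the closed-form route, assuming oral doses only, bioavailability of 1, and ka different from both disposition rate constants. With k10 = CL/Vc, k12 = Q/Vc, k21 = Q/Vp, and lam1, lam2 the roots of x^2 - (k10 + k12 + k21) x + k10 k21 = 0, the Laplace transform of the central amount after a unit dose is ka (s + k21) / ((s + ka)(s + lam1)(s + lam2)); partial fractions give the three exponentials below, and superposition sums them over all doses given before each observation. Function names are hypothetical:

```stan
functions {
  // Hypothetical helper: amount in the central compartment at time tau after
  // a unit oral dose (2-compartment model, first-order absorption, F = 1).
  // Assumes ka differs from both lam1 and lam2 (otherwise a 0/0 occurs).
  real cent_unit_dose(real tau, real ka, real CL, real Vc, real Q, real Vp) {
    real k10 = CL / Vc;
    real k12 = Q / Vc;
    real k21 = Q / Vp;
    real S = k10 + k12 + k21;
    real disc = sqrt(square(S) - 4 * k10 * k21);
    real lam1 = 0.5 * (S + disc);  // fast disposition rate
    real lam2 = 0.5 * (S - disc);  // slow (terminal) disposition rate
    return ka * ((k21 - lam1) / ((ka - lam1) * (lam2 - lam1)) * exp(-lam1 * tau)
               + (k21 - lam2) / ((ka - lam2) * (lam1 - lam2)) * exp(-lam2 * tau)
               + (k21 - ka) / ((lam1 - ka) * (lam2 - ka)) * exp(-ka * tau));
  }

  // Central concentrations at every observation time of one patient:
  // the system is linear, so each dose's contribution is simply added up.
  vector conc_superposition(vector theta, array[] real obs_times,
                            array[] real dose_times, vector dose_amts) {
    int n = num_elements(obs_times);
    vector[n] conc = rep_vector(0, n);
    for (j in 1:n) {
      for (d in 1:num_elements(dose_times)) {
        if (dose_times[d] <= obs_times[j]) {
          conc[j] += dose_amts[d]
                     * cent_unit_dose(obs_times[j] - dose_times[d], theta[1],
                                      theta[2], theta[3], theta[4], theta[5]);
        }
      }
    }
    return conc / theta[3];  // amount -> concentration via Vc
  }
}
```

No ODE solver is involved, so each gradient evaluation should be much cheaper; this only holds while elimination stays linear.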

I’d strongly recommend parallelising this model across patients with reduce_sum. Also set your relative and absolute tolerances (via ode_rk45_tol) to the targets you really need rather than using the strict defaults, which are expensive.
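A sketch combining both suggestions, under several assumptions: oral doses only, data restructured into per-patient index ranges (obs_start/obs_end and dose_start/dose_end are hypothetical names), observation times sorted and strictly after each patient's first dose, and the gut compartment solved analytically inside the RHS. The tolerances and step cap are placeholders to tune, not recommended values:

```stan
functions {
  // 2-cmt RHS with the gut solved analytically; th = {ka, CL, Vc, Q, Vp}.
  vector rhs(real t, vector y, vector th,
             array[] real dose_times, vector dose_amts) {
    real gut = 0;
    for (d in 1:num_elements(dose_times))
      if (dose_times[d] <= t) gut += dose_amts[d] * exp(-th[1] * (t - dose_times[d]));
    vector[2] dydt;
    dydt[1] = th[1] * gut - (th[2] + th[4]) / th[3] * y[1] + th[4] / th[5] * y[2];
    dydt[2] = th[4] / th[3] * y[1] - th[4] / th[5] * y[2];
    return dydt;
  }

  // Log-likelihood of the patients in one slice (reduce_sum partial sum).
  real partial_loglik(array[] int pats, int start, int end, matrix theta,
                      real sigma, array[] real obs_time, vector cObs,
                      array[] int obs_start, array[] int obs_end,
                      array[] real dose_time, vector dose_amt,
                      array[] int dose_start, array[] int dose_end) {
    real lp = 0;
    for (k in pats) {
      int o1 = obs_start[k];
      int o2 = obs_end[k];
      int d1 = dose_start[k];
      int d2 = dose_end[k];
      // placeholder tolerances: rel_tol = abs_tol = 1e-4, max 10000 steps
      array[o2 - o1 + 1] vector[2] y =
        ode_rk45_tol(rhs, rep_vector(0, 2), dose_time[d1], obs_time[o1:o2],
                     1e-4, 1e-4, 10000, to_vector(theta[k]),
                     dose_time[d1:d2], dose_amt[d1:d2]);
      vector[o2 - o1 + 1] conc;
      for (j in 1:(o2 - o1 + 1)) conc[j] = y[j, 1] / theta[k, 3];
      lp += lognormal_lpdf(cObs[o1:o2] | log(conc), sigma);
    }
    return lp;
  }
}
data {
  int<lower=1> ID;
  int<lower=1> nObs;
  int<lower=1> nDose;
  array[nObs] real obs_time;        // sorted within each patient
  vector<lower=0>[nObs] cObs;
  array[ID] int obs_start;          // hypothetical per-patient index ranges
  array[ID] int obs_end;
  array[nDose] real dose_time;
  vector<lower=0>[nDose] dose_amt;
  array[ID] int dose_start;
  array[ID] int dose_end;
  vector<lower=0>[ID] weight;
}
transformed data {
  array[ID] int pats = linspaced_int_array(ID, 1, ID);  // patient indices to slice over
}
parameters {
  vector<lower=0>[5] theta_bar;
  vector<lower=0>[5] sigma_bar;
  real<lower=0> sigma;
  matrix[ID, 5] myz;
}
model {
  matrix[ID, 5] theta;
  for (j in 1:5) theta[, j] = theta_bar[j] + sigma_bar[j] * myz[, j];
  theta[, 2] = theta[, 2] .* exp(0.75 * log(weight / 70));
  theta[, 3] = theta[, 3] .* weight / 70;
  theta[, 4] = theta[, 4] .* exp(0.75 * log(weight / 70));
  theta[, 5] = theta[, 5] .* weight / 70;
  // same priors as the original model, written in vectorised form
  theta_bar ~ lognormal(log([2.5, 10, 35, 15, 105]'), [1, 0.25, 0.25, 0.5, 0.5]');
  sigma ~ normal(0, 0.5);
  to_vector(myz) ~ normal(0, 1);
  sigma_bar ~ normal(0, 0.5);
  target += reduce_sum(partial_loglik, pats, 1, theta, sigma, obs_time, cObs,
                       obs_start, obs_end, dose_time, dose_amt, dose_start, dose_end);
}
```

With cmdstanr, threading has to be enabled at compile time (cpp_options = list(stan_threads = TRUE)) and the thread count set with threads_per_chain in $sample(). It is worth checking on a few fits that tightening the tolerances does not change the posterior.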
