Could someone please help me to speed up this code? It is extremely slow even for a single chain.
functions{
vector PK3cmt(real t, vector states, vector theta){
real ka = theta[1];
real CL = theta[2];
real Vc = theta[3];
real Q = theta[4];
real Vp = theta[5];
real gut = states[1];
real cent = states[2];
real peri = states[3];
vector[3] dif;
dif[1] = - ka * gut;
dif[2] = ka*gut - ((CL / Vc) + (Q / Vc))*cent + (Q * peri) / Vp;
dif[3] = (Q*cent)/ Vc - (Q*peri)/Vp;
return dif;
}
}
data{
int N; // N = 1320 total number of events
int nObs; // nObs = 1020 the number of measurements
int ID; // ID = 20 the number of patients
int id[N];
int iObs[nObs];
vector<lower = 0>[N] times;
int evid[N];
int cmt[N];
vector<lower = 0>[N] amt;
vector<lower=0>[nObs] cObs;
vector<lower=0>[ID] weight;
}
transformed data {
int nTheta = 5;
int nCmt = 3;
}
parameters{
vector<lower=0>[nTheta] theta_bar; //{ka, CL, Vc, Q, Vp}, population parameters
vector<lower=0>[nTheta] sigma_bar;
real<lower = 0> sigma;
matrix[ID, 5] myz; // standard normal noise, inter-individuals and inter-parameters variation
}
model{
matrix[ID, nTheta] theta;
for(j in 1:5) theta[, j] = theta_bar[j] + sigma_bar[j] * myz[, j]; // individual parameters
theta[, 2] = theta[, 2] .* exp(0.75 * log(weight/70)); // adjusting CL for the weight of the patient
theta[, 3] = theta[, 3] .* weight/70; // adjusting Vc for the weight
theta[, 4] = theta[, 4] .* exp(0.75 * log(weight/70)); // adjusting Q for the weight
theta[, 5] = theta[, 5] .* weight/70; // adjusting Vp for the weight
vector[N] traj;
int i = 1;
for(k in 1:ID){
vector[nCmt] U = to_vector([0.0, 0, 0]);
while (id[i] == k){
if (evid[i] == 1) U[cmt[i]] += amt[i];
else {
U = ode_rk45(PK3cmt, U, times[i-1], {times[i]}, to_vector(theta[k]))[1];
traj[i] = U[2] / theta[k, 3];
}
i +=1;
}
}
vector[nObs] concentrationHat = traj[iObs];
// priors:
theta_bar[1] ~ lognormal(log(2.5), 1);
theta_bar[2] ~ lognormal(log(10), 0.25);
theta_bar[3] ~ lognormal(log(35), 0.25);
theta_bar[4] ~ lognormal(log(15), 0.5);
theta_bar[5] ~ lognormal(log(105), 0.5);
sigma ~ normal(0, 0.5);
to_vector(myz) ~ normal(0, 1);
sigma_bar ~ normal(0, 0.5);
//likelihood:
cObs ~ lognormal(log(concentrationHat), sigma);
}
H3cmtpop.m = stan_model(file = "H3cmtpop.stan")
init <- function() {
list(theta_bar = c(exp(rnorm(1, log(2), 0.2)),
exp(rnorm(1, log(10), 0.2)),
exp(rnorm(1, log(35), 0.2)),
exp(rnorm(1, log(15), 0.2)),
exp(rnorm(1, log(105), 0.2))),
sigma_bar = abs(rnorm(5, 0, 0.5)),
sigma = abs(rnorm(1, 0, 0.5)),
myz = matrix(rnorm(100), 20, 5))}
test.f = sampling(H3cmtpop.m, data=H3cmtpop.dat, iter = 2000, cores = 1, chains = 1, control=list(adapt_delta=0.85),
init = init)