Hello! I am new to Stan and this is my first time posting here. I am analyzing data from an epidemiological study using a catalytic epidemiological model. My input is a four-dimensional array of infected cases, indexed by age, municipality, year, and sex. Because the data are split across so many dimensions, many combinations have 0 cases (so, to some extent, the data are sparse). I am using cmdstanr on a cluster. The model works well on 400 municipalities, but the run time is about 7 hours on the HPC (4 chains in total, one core per chain). However, I've had a lot of difficulty fitting the model to the whole dataset (5,570 municipalities in total): it cannot finish in a reasonable amount of time. I am looking for advice on how to make the model more efficient. Thank you so much for any suggestions!
//----- Time-dependent catalytic model -----//
data {
  int<lower=1> nA; // number of age groups
  int<lower=1> nT; // number of time points (transformed parameters always writes t = 1)
  int<lower=0> nL; // number of locations
  int nCasePoints; // NOTE(review): not referenced anywhere else in this program
  int<lower=0> nPoints; // length of log_lik in generated quantities
  array[2,nL,nT,nA] int<lower=0> cases; // observed cases by sex, location, time, age group
  array[2,nL] matrix<lower=0>[nT,nA] pop; // population by sex and location; rows = time, cols = age group
  array[2,nL,nT,nA] int<lower=1, upper=nPoints> pointIndex; // slot in log_lik for each cell
  // Age-group ranges in single-year ages; the upper bound of 100 must match the
  // 100 age columns used in transformed parameters.
  // NOTE(review): aMin[a] <= aMax[a] is assumed but not enforced here.
  array[nA] int<lower=1, upper=100> aMin; // first single-year age in each group
  array[nA] int<lower=1, upper=100> aMax; // last single-year age in each group
}
parameters {
  // Force of infection on the log scale, one value per location and time point.
  // The upper bound keeps lambda = exp(log_lambda) below exp(-1).
  array[nL, nT] real<upper=-1> log_lambda;

  // Reporting rates on the logit scale, indexed by sex and age group.
  array[2, nA] real logit_rho;

  // Negative-binomial overdispersion, bounded below by 1.
  real<lower=1> phi;
}
transformed parameters {
  // Deterministic catalytic-model states and the expected case counts used by
  // the likelihood. The S/I/R recursion and the age-group aggregation depend
  // only on lambda, not on sex, so each location's trajectory is computed once
  // into the s = 1 slot and copied to s = 2. Sex enters only through pop and rho.
  // NOTE(review): every variable declared at the top of this block is written
  // to the output on every draw. S, I and R alone are 3 * 2 * nL * nT * 100
  // values. With thousands of locations this output is very large. If only
  // pCases is needed downstream, consider moving S/I/R/Sg/Ig/Rg/ITot into a
  // local { } scope so they are no longer saved.
  array[2,nL] matrix<lower=0, upper=1>[nT,100] S; // susceptible, by single year of age
  array[2,nL] matrix<lower=0, upper=1>[nT,100] I; // infected in this time step
  array[2,nL] matrix<lower=0, upper=1>[nT,100] R; // immune (1 - S)
  array[2,nL] matrix<lower=0, upper=1>[nT,nA] Sg; // S averaged over ages aMin[a]:aMax[a]
  array[2,nL] matrix<lower=0, upper=1>[nT,nA] Ig; // I averaged over each age group
  array[2,nL] matrix<lower=0, upper=1>[nT,nA] Rg; // R averaged over each age group
  array[2,nL,nT,nA] real pCases; // predicted (expected) reported cases
  array[nL,nT] real lambda = exp(log_lambda); // time-varying FOI (linear scale)
  array[2,nA] real rho = inv_logit(logit_rho); // sex- and age-dependent reporting rates (linear scale)
  array[2,nL,nT,nA] real ITot; // expected number of infections

  for (l in 1:nL) {
    // First time point: all 100 single-year ages share one state. The
    // infected proportion is set equal to the FOI.
    S[1,l,1] = rep_row_vector(1 - lambda[l,1], 100);
    I[1,l,1] = rep_row_vector(lambda[l,1], 100);
    R[1,l,1] = I[1,l,1];

    for (t in 2:nT) {
      // Age 1 (newborns): fully susceptible, no infection or immunity.
      S[1,l,t,1] = 1;
      I[1,l,t,1] = 0;
      R[1,l,t,1] = 0;
      // Ages 2..100: last year's age a-1 cohort moves up one year, and a
      // fraction lambda[l,t] of its susceptibles becomes infected.
      I[1,l,t,2:100] = lambda[l,t] * S[1,l,t-1,1:99];
      S[1,l,t,2:100] = S[1,l,t-1,1:99] - I[1,l,t,2:100];
      R[1,l,t,2:100] = 1 - S[1,l,t,2:100];
    }

    // Average the single-year ages within each observed age group.
    for (t in 1:nT) {
      for (a in 1:nA) {
        Sg[1,l,t,a] = mean(S[1,l,t,aMin[a]:aMax[a]]);
        Ig[1,l,t,a] = mean(I[1,l,t,aMin[a]:aMax[a]]);
        Rg[1,l,t,a] = mean(R[1,l,t,aMin[a]:aMax[a]]);
      }
    }

    // The second sex has identical states; copying needs no extra arithmetic.
    S[2,l] = S[1,l];
    I[2,l] = I[1,l];
    R[2,l] = R[1,l];
    Sg[2,l] = Sg[1,l];
    Ig[2,l] = Ig[1,l];
    Rg[2,l] = Rg[1,l];

    // Apply the sex-specific population and reporting rate.
    for (s in 1:2) {
      for (t in 1:nT) {
        for (a in 1:nA) {
          ITot[s,l,t,a] = Ig[s,l,t,a] * pop[s,l,t,a];
          pCases[s,l,t,a] = rho[s,a] * ITot[s,l,t,a];
        }
      }
    }
  }
}
model {
  //--- Priors ---//
  // to_array_1d flattens every argument in the same row-major order. These
  // vectorized statements are therefore equivalent to per-row loops, but they
  // build far fewer autodiff nodes.
  to_array_1d(log_lambda) ~ normal(-6, 1);   // FOI, log scale
  phi ~ normal(4, 1);                        // overdispersion
  to_array_1d(logit_rho) ~ normal(-5, 0.5);  // reporting rates, logit scale

  //--- Likelihood ---//
  // One vectorized call covers every (sex, location, time, age-group) cell,
  // instead of 2 * nL * nT separate calls. cases and pCases have identical
  // dimensions, so the flattened elements line up one-to-one.
  // NOTE(review): neg_binomial_2 requires strictly positive means. A cell with
  // pop = 0, or an age group made only of age 1 (I = 0 for t >= 2), would
  // give pCases = 0 and reject the draw. Confirm the data never produce this.
  // For more speed on many cores, this sum could be split with reduce_sum
  // (within-chain threading).
  to_array_1d(cases) ~ neg_binomial_2(to_array_1d(pCases), phi);
}
generated quantities {
  // Pointwise log-likelihood per observed cell (stored at pointIndex), its
  // total, and the root-mean-square error between predicted and observed
  // counts.
  // NOTE(review): this assumes pointIndex maps each (s, l, t, a) cell to a
  // distinct slot in 1:nPoints, and that nPoints equals the number of cells.
  // Otherwise some log_lik entries stay unset and the RMSE divisor does not
  // match the number of squared residuals. Confirm against the data prep.
  array[nPoints] real log_lik;
  real log_lik_sum = 0.0;
  real rmse = 0.0;
  for (s in 1:2) {
    for (l in 1:nL) {
      for (t in 1:nT) {
        for (a in 1:nA) {
          int idx = pointIndex[s,l,t,a];
          real resid = pCases[s,l,t,a] - cases[s,l,t,a];
          log_lik[idx] = neg_binomial_2_lpmf(cases[s,l,t,a] | pCases[s,l,t,a], phi);
          log_lik_sum += log_lik[idx];
          rmse += resid ^ 2;
        }
      }
    }
  }
  // Convert the sum of squared residuals into a root-mean-square.
  rmse = (rmse / nPoints) ^ 0.5;
}