Setup
I will preface this by stating that I am relatively new to Stan, though I have enjoyed it at arm's length via {brms} for some time.
I have gone through the excellent Stan + S(E)IR workbook and it inspired me to try my hand at a similar model for sexually transmitted infections (STIs). However, I suspect my approach may benefit from some pointers on Stan coding and optimization.
The following R + Stan code takes over 3 hours to fit (6 chains on 6 cores, 2000 iterations each), partly because of one pesky slow chain. This may be normal for the amount of data provided, the scope of the priors, and the number of parameters (perhaps it is even a good sign for the posterior simulation?). I can also easily conceive of more complex models, with the compartments further divided by age, and I worry about the time needed to fit such models.
Without going down a rabbit hole about computing resources and the like: is this typical for a model of this size? Are there ways to improve scaling, or would switching to a different fitting procedure help? Any pointers on improving the Stan code and approach would be greatly appreciated.
Please see below for details; for the full model and computer details, skip to the bottom.
STI Math+Prob Model
To model an STI (let's say gonorrhea), it is logical to start with SIS (susceptible-infected-susceptible) compartments. To account for heterogeneous mixing in the transmission of infections, a contact matrix is required. Because the observed information comes from surveillance data and the true number of susceptibles is unknown, measurement error parameters are used to link the two. Likewise, since the data are reported cases (incidence) and the true number of infected is unknown, this also needs to be accounted for.
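Written out, the two-group SIS system that the sis() function implements is the following, where g indexes the low and high risk groups, M is the contact matrix, c_g the average contacts, β the per-contact infection probability and γ the recovery rate (per quarter). Since dS_g/dt + dI_g/dt = 0, each group's size N_g = S_g + I_g stays constant over time.

$$
\begin{aligned}
\frac{dS_g}{dt} &= -\phi_g S_g + \gamma I_g, \qquad
\frac{dI_g}{dt} = \phi_g S_g - \gamma I_g, \\
\phi_g &= \beta \, c_g \sum_{h \in \{\mathrm{lo},\, \mathrm{hi}\}} M_{gh} \, \frac{I_h}{S_h + I_h},
\qquad g \in \{\mathrm{lo},\, \mathrm{hi}\}.
\end{aligned}
$$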
Data
We will use fake data consisting of 3 years of reported cases (by quarter) and estimates of the region's entire population (as a proxy for the possible susceptibles).
Although I would prefer to estimate it as a parameter, I fixed the contact matrix to favor mixing within the same risk group (high risk with high risk = 75%); this makes sense based upon core theory of STI spread.
data_sis <- list(ntime = 12,
new_cases = c(1000, 1050, 1100, 1300, 1350, 1250, 1200, 1220, 1250, 1270, 1400, 1380),
pop_sus = c(rep(4e6, 5), rep(4.2e6, 4), rep(4.4e6, 4)), # 13 values: t0 plus 12 quarters
ts = 1:12,
contact_matrix = matrix(c(0.75, 0.25, 0.25, 0.75), nrow = 2))
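To make the role of the contact matrix concrete, the Python sketch below (illustration only; the values of beta, c and the prevalences are hypothetical) computes the force of infection the way sis() does: each group's FOI is beta times its average contacts times a contact-matrix-weighted average of the prevalences I/N.

```python
# Illustration of the force of infection (FOI) computed in sis():
#   phi[g] = beta * c[g] * sum_h contact[g][h] * I[h] / N[h]
# The contact matrix matches data_sis; beta, c and prevalences are hypothetical.


def force_of_infection(beta, c, contact, prevalence):
    """Return the FOI for each group (index 0 = lo, 1 = hi)."""
    phi = []
    for g in range(len(c)):
        mixed = sum(contact[g][h] * prevalence[h] for h in range(len(prevalence)))
        phi.append(beta * c[g] * mixed)
    return phi


contact = [[0.75, 0.25],   # row lo: 75% of contacts with lo, 25% with hi
           [0.25, 0.75]]   # row hi: 25% with lo, 75% with hi
phi = force_of_infection(beta=0.5, c=[0.25, 4.0], contact=contact,
                         prevalence=[0.001, 0.05])  # I/N for lo, hi
print(phi)
```

Because this matrix is symmetric, R's column-wise fill in matrix(c(0.75, 0.25, 0.25, 0.75), nrow = 2) gives the same rows either way; an asymmetric mixing matrix would need byrow = TRUE or care about which index is the "from" group.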
This is then provided to Stan in the chunks below…
data {
int<lower=0> ntime;
array[ntime] real<lower=0> new_cases; // New infections (reported cases)
array[ntime + 1] real<lower=0> pop_sus; // Observed possible susceptibles
array[ntime] real<lower=0> ts;
matrix[2,2] contact_matrix;
}
transformed data {
int<lower=0> ntime_w0;
ntime_w0 = ntime + 1; // For initial suscep at t0
}
Parameters
There are 11 parameters to estimate (counting each vector element separately) before we get to the SIS model. A few of note:
- p_i: proportion of cases actually captured through reporting (0-1)
- p_s: magnitude of the entire population compared to those actually susceptible (>0)
- frac: fraction of the susceptibles in the high STI risk group (0-1)
- c: average contacts in the low and high risk groups
parameters {
vector<lower=0>[2] Nstate0; // Initial Sus and Inf
real<lower=0> s_sigma; // Overall var S
real<lower=0> i_sigma; // Overall var I
real<lower=0,upper=1> beta; // Inf prob (foi = beta * contact rate)
real<lower=0> gamma; // Recovery rate
real<lower=0, upper=1> frac; // fraction of high cont
vector<lower=0>[2] c; // 1 low, 2 high
real<lower=0> p_s; // Scale of population estimate relative to true susceptibles
real<lower=0,upper=1> p_i; // Surveillance prop of cases
}
SIS Model Functions
Three functions are used. The first is the SIS compartmental model, which allows heterogeneous mixing through a matrix of contacts between the high and low risk groups. The other two calculate the incidence and the recoveries, as I did not think it was straightforward to return these from the SIS ODE function.
functions {
vector sis(real t, vector state, matrix probM, vector c, real beta, real gamma) {
vector[4] dydt;
// Initial predicted states
vector[2] S = state[1:2]; // Lo, Hi
vector[2] I = state[3:4]; // Lo, Hi
vector[2] N = S + I;
// FOI, one for LO, other for HI (phi[1] = FOI for lo, phi[2] = FOI for hi)
vector[2] phi = (beta * c) .* (probM * (I ./ N));
vector[2] recov = gamma * I;
// S states (lo, hi)
dydt[1] = -(phi[1] *S[1]) + recov[1];
dydt[2] = -(phi[2] *S[2]) + recov[2];
// I states (lo, hi)
dydt[3] = (phi[1]*S[1]) - recov[1];
dydt[4] = (phi[2]*S[2]) - recov[2];
return dydt;
}
// Calc incidence from states and recovered (since the ODE won't return that value)
vector incidR(vector It1, vector It2, vector recov) {
vector[2] incid;
incid[1] = (It2[1] - It1[1]) + recov[1]; //lo
incid[2] = (It2[2] - It1[2]) + recov[2]; //hi
return incid;
}
// Calc recovery
vector recovR(vector I, real gamma){
return gamma * I;
}
}
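Because the sis() right-hand side cannot return incidence directly, a common alternative to the incidR()/recovR() post-processing is to augment the ODE state with cumulative-incidence compartments whose derivative is the new-infection flow phi_g * S_g; the per-quarter incidence is then the difference of that compartment between output times. The Python sketch below shows the idea for the two-group model. It is an illustration only: the fixed-step RK4 solver, the simulate() helper and all parameter values are hypothetical, and frac is treated as the high-risk share, as described in the text. In Stan, the same idea would mean returning a vector[6] from sis() and differencing states 5:6.

```python
# Illustration only: two-group SIS model augmented with cumulative-incidence
# states C_lo and C_hi (dC/dt = phi * S). Incidence over a quarter is then
# C(end) - C(start). All parameter values below are hypothetical.


def sis_aug(y, beta, gamma, c, contact):
    """Derivatives of [S_lo, S_hi, I_lo, I_hi, C_lo, C_hi]."""
    S, I = y[0:2], y[2:4]
    prev = [I[g] / (S[g] + I[g]) for g in range(2)]            # I/N per group
    phi = [beta * c[g] * sum(contact[g][h] * prev[h] for h in range(2))
           for g in range(2)]                                   # force of infection
    inf = [phi[g] * S[g] for g in range(2)]                     # new infections
    rec = [gamma * I[g] for g in range(2)]                      # recoveries
    return [rec[0] - inf[0], rec[1] - inf[1],                   # dS_lo, dS_hi
            inf[0] - rec[0], inf[1] - rec[1],                   # dI_lo, dI_hi
            inf[0], inf[1]]                                     # dC_lo, dC_hi


def rk4_step(f, y, h, *args):
    """One classical fixed-step fourth-order Runge-Kutta step."""
    k1 = f(y, *args)
    k2 = f([a + 0.5 * h * k for a, k in zip(y, k1)], *args)
    k3 = f([a + 0.5 * h * k for a, k in zip(y, k2)], *args)
    k4 = f([a + h * k for a, k in zip(y, k3)], *args)
    return [a + h / 6 * (p + 2 * q + 2 * r + s)
            for a, p, q, r, s in zip(y, k1, k2, k3, k4)]


def simulate(s0, i0, frac, beta, gamma, c, contact, n_quarters=12, steps=200):
    """Return the S/I states at t = 0..n_quarters and per-quarter incidence."""
    # frac is treated as the high-risk share (index 1 = hi)
    y = [(1 - frac) * s0, frac * s0, (1 - frac) * i0, frac * i0, 0.0, 0.0]
    states, incidence = [y[:4]], []
    for _ in range(n_quarters):
        c_start = y[4:6]
        for _ in range(steps):
            y = rk4_step(sis_aug, y, 1.0 / steps, beta, gamma, c, contact)
        states.append(y[:4])
        incidence.append([y[4] - c_start[0], y[5] - c_start[1]])
    return states, incidence


contact = [[0.75, 0.25], [0.25, 0.75]]
states, incidence = simulate(s0=1e6, i0=1e4, frac=0.03, beta=0.5, gamma=4.0,
                             c=[0.25, 4.0], contact=contact)
print([round(lo + hi) for lo, hi in incidence])  # total incidence per quarter
```

The cumulative states are non-decreasing by construction, so their differences cannot go negative, and the incidence no longer depends on how recoveries are approximated within a quarter.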
ODE Params
There are four (4) states: low or high STI risk for each of the S and I compartments. As the initial numbers of susceptible and infected individuals are unknown, they are treated as parameters, and frac is used to split them into the low and high risk groups. The ODE solver then populates the remaining states across the time period. Incidence, which the likelihood needs, is calculated from the ODE states and the recoveries in each period.
transformed parameters {
array[ntime_w0] vector<lower=0>[4] y; // Sus and Inf states after 0 time
array[ntime] vector<lower=0>[4] rates; // Incidence (1:2) and recoveries (3:4) per period
// Initial conditions
y[1,1] = (1 - frac) * Nstate0[1]; // S0lo
y[1,2] = frac * Nstate0[1]; // S0hi (frac = high-risk share)
y[1,3] = (1 - frac) * Nstate0[2]; // I0lo
y[1,4] = frac * Nstate0[2]; // I0hi
y[2:ntime_w0, 1:4] = ode_rk45(sis, y[1,1:4], 0, ts, contact_matrix, c, beta, gamma);
// One step less (b/c incidence)
for(t in 1:ntime) {
rates[t, 3:4] = recovR(y[t, 3:4], gamma); // ntime vector
rates[t, 1:2] = incidR(y[t, 3:4], y[t+1, 3:4], rates[t, 3:4]); // Provide I across states using recovR
}
}
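A quick way to gauge how accurate incidR() is: with beta = 0 there are no new infections, so the true incidence is 0, yet I(t+1) - I(t) + gamma * I(t) is not. The Python sketch below (illustration only; incid_approx() is a hypothetical stand-in for incidR()) evaluates the formula against the exact pure-recovery solution I(t+1) = I(t) * exp(-gamma). Because recovR() treats gamma * I(t) as the recoveries for the whole quarter, the spurious incidence I(t) * (gamma - 1 + exp(-gamma)) grows with gamma, and the prior gamma ~ normal(4, 1.5) per quarter sits where it is several times I(t).

```python
import math

# Illustration: error of the one-step incidence formula used by incidR(),
#   incid = I(t+1) - I(t) + gamma * I(t),
# in the pure-recovery case (beta = 0), where the true incidence is exactly 0
# and dI/dt = -gamma * I gives I(t+1) = I(t) * exp(-gamma).


def incid_approx(i_t, i_next, gamma):
    """Hypothetical stand-in for incidR() with recov = gamma * I(t)."""
    return (i_next - i_t) + gamma * i_t


i_t = 1e4  # infected at the start of the quarter (illustrative)
for gamma in (0.1, 1.0, 4.0):                 # recovery rate per quarter
    i_next = i_t * math.exp(-gamma)            # exact solution, no new infections
    spurious = incid_approx(i_t, i_next, gamma)
    print(f"gamma={gamma}: approx incidence {spurious:.1f} (true value 0)")
```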
Priors & Likelihood Model
Although I'm sure other likelihoods would also work (e.g., negative binomial), I decided to use the lognormal for its positivity constraint and simplicity. For example, the observed cases are modeled as:
$$ i_t \sim \mathrm{LogNormal}\left(\log(p_i \cdot \mathrm{Incid}_t),\ \sigma_i\right) $$
The priors for the initial states differ by orders of magnitude, reflecting what we observe for the cases and the population size. The transmission probability per contact is assumed to be high. Average contacts in the high risk group are assumed to be several times larger than in the low risk group. The fraction of susceptibles at high risk of STIs is expected to be under 10%.
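Before worrying about run time, it can help to check what the priors imply. The stdlib-only Python sketch below (illustration only; sample size and seed are arbitrary) draws from a few of the priors in the model block and reports the implied probability that frac is below 10%, the mean reporting proportion p_i, and the median infectious period 1/gamma in days (taking a quarter as 365.25/4 days). The <lower=0> bound on gamma is mimicked by rejection sampling.

```python
import random
import statistics

# Prior-only Monte Carlo check (no data), using the priors from the model block.
# Sample size and seed are arbitrary choices for illustration.
random.seed(1294)
n = 20_000

frac = [random.betavariate(3, 100) for _ in range(n)]    # frac ~ beta(3, 100)
p_i = [random.betavariate(40, 200) for _ in range(n)]    # p_i ~ beta(40, 200)

# gamma ~ normal(4, 1.5) with <lower=0>: rejection-sample the truncated normal
gamma = []
while len(gamma) < n:
    g = random.gauss(4, 1.5)
    if g > 0:
        gamma.append(g)

days_per_quarter = 365.25 / 4
share_frac_under_10 = sum(f < 0.10 for f in frac) / n
mean_p_i = statistics.mean(p_i)
median_infectious_days = statistics.median([days_per_quarter / g for g in gamma])

print(f"P(frac < 0.10) = {share_frac_under_10:.3f}")
print(f"mean reporting proportion p_i = {mean_p_i:.3f}")
print(f"median infectious period = {median_infectious_days:.1f} days")
```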
The log-likelihoods were calculated in the generated quantities block using lognormal_lpdf for both the s_t and i_t fits, and then appended.
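For readers less familiar with this generated quantities pattern, the Python sketch below (illustration only) reproduces what the log_likS/log_likI/append_row code computes: one lognormal log density per observation, using Stan's lognormal(mu, sigma) parameterization, stacked as the 13 population terms followed by the 12 case terms. The data are those in data_sis; the latent susceptible and incidence values and the parameter values are hypothetical placeholders, not fitted estimates.

```python
import math

# Illustration of the pointwise log-likelihood built in generated quantities:
# one lognormal_lpdf term per observation, S terms first, then I terms.


def lognormal_lpdf(y, mu, sigma):
    """Log density of LogNormal(mu, sigma) at y > 0 (Stan's parameterization)."""
    z = (math.log(y) - mu) / sigma
    return -math.log(y) - math.log(sigma) - 0.5 * math.log(2 * math.pi) - 0.5 * z * z


# Observed data, as in data_sis
new_cases = [1000, 1050, 1100, 1300, 1350, 1250, 1200, 1220, 1250, 1270, 1400, 1380]
pop_sus = [4e6] * 5 + [4.2e6] * 4 + [4.4e6] * 4   # t0 plus 12 quarters

# Hypothetical latent values and parameters (NOT fitted estimates)
S_model = [2.0e6] * 13      # modeled susceptibles at each time point
inc_model = [7000.0] * 12   # modeled incidence per quarter
p_s, p_i, s_sigma, i_sigma = 2.0, 0.17, 0.1, 0.2

log_lik_s = [lognormal_lpdf(o, math.log(s * p_s), s_sigma)
             for o, s in zip(pop_sus, S_model)]
log_lik_i = [lognormal_lpdf(o, math.log(x * p_i), i_sigma)
             for o, x in zip(new_cases, inc_model)]
log_lik = log_lik_s + log_lik_i   # same order as append_row(log_likS, log_likI)
print(len(log_lik), round(sum(log_lik), 2))
```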
model {
// Priors
Nstate0[1] ~ lognormal(log(1e6), 0.5);
Nstate0[2] ~ lognormal(log(1e4), 1);
s_sigma ~ exponential(1);
i_sigma ~ exponential(1);
beta ~ beta(5, 2.5); // Recall that dt is one quarter
gamma ~ normal(4, 1.5); // Recovery rate, time period by Q
frac ~ beta(3, 100); // Most likely under 10%
c[1]~ normal(.25, 2); // Lo
c[2]~ normal(4, 2); // Hi
p_s ~ normal(2, 2); // Truncated normal
p_i ~ beta(40, 200);
// Likelihood (only worked as a loop)
for (t in 1:ntime_w0) {
pop_sus[t] ~ lognormal(log((y[t,1] + y[t,2]) * p_s), s_sigma); // total S (lo + hi)
// change in incid is 1 less in size than all suscep
if (t < ntime_w0) {
new_cases[t] ~ lognormal(log((rates[t,1] + rates[t,2]) * p_i), i_sigma); // total incidence
}
}
}
Run Model
Because earlier runs were difficult to fit, adapt_delta and max_treedepth were increased. Fewer iterations and more cores were used to help speed things up.
gono_model <- rstan::stan_model('mystanmodel.stan')
fit <- rstan::sampling(gono_model,
data = data_sis,
seed = 1294, cores = 6, chains = 6, iter = 2000,
control = list(adapt_delta = 0.99,
max_treedepth = 15))
Computer Details
- R: 4.3.1
- RStan version: 2.32.6
- RStudio: 2023.06.1
- OS: Windows 10
- CPU: Intel i9-10900KF @ 3.7 GHz
- RAM: 60 GB
Full Model
functions {
vector sis(real t, vector state, matrix probM, vector c, real beta, real gamma) {
vector[4] dydt;
// Initial predicted states
vector[2] S = state[1:2]; // Lo, Hi
vector[2] I = state[3:4]; // Lo, Hi
vector[2] N = S + I;
// FOI, one for LO, other for HI (phi[1] = FOI for lo, phi[2] = FOI for hi)
vector[2] phi = (beta * c) .* (probM * (I ./ N));
vector[2] recov = gamma * I;
// S states (lo, hi)
dydt[1] = -(phi[1] *S[1]) + recov[1];
dydt[2] = -(phi[2] *S[2]) + recov[2];
// I states (lo, hi)
dydt[3] = (phi[1]*S[1]) - recov[1];
dydt[4] = (phi[2]*S[2]) - recov[2];
return dydt;
}
// Calc incidence from states and recovered (since the ODE won't return that value)
vector incidR(vector It1, vector It2, vector recov) {
vector[2] incid;
incid[1] = (It2[1] - It1[1]) + recov[1]; //lo
incid[2] = (It2[2] - It1[2]) + recov[2]; //hi
return incid;
}
// Calc recovery
vector recovR(vector I, real gamma){
return gamma * I;
}
}
data {
int<lower=0> ntime;
array[ntime] real<lower=0> new_cases; // New infections (reported cases)
array[ntime + 1] real<lower=0> pop_sus; // Observed possible susceptibles
array[ntime] real<lower=0> ts;
matrix[2,2] contact_matrix;
}
transformed data {
int<lower=0> ntime_w0;
ntime_w0 = ntime + 1;
}
parameters {
vector<lower=0>[2] Nstate0; // Initial Sus and Inf
real<lower=0> s_sigma; // Overall var S
real<lower=0> i_sigma; // Overall var I
real<lower=0,upper=1> beta; // Inf prob (foi = beta * contact rate)
real<lower=0> gamma; // Recovery rate
real<lower=0, upper=1> frac; // fraction of high cont
vector<lower=0>[2] c; // 1 low, 2 high
real<lower=0> p_s; // Scale of population estimate relative to true susceptibles
real<lower=0,upper=1> p_i; // Surveillance prop of cases
}
transformed parameters {
array[ntime_w0] vector<lower=0>[4] y; // Sus and Inf states after 0 time
array[ntime] vector<lower=0>[4] rates; // Incidence (1:2) and recoveries (3:4) per period
// Initial conditions, from predicted totals...
y[1,1] = (1 - frac) * Nstate0[1]; // S0lo
y[1,2] = frac * Nstate0[1]; // S0hi (frac = high-risk share)
y[1,3] = (1 - frac) * Nstate0[2]; // I0lo
y[1,4] = frac * Nstate0[2]; // I0hi
y[2:ntime_w0, 1:4] = ode_rk45(sis, y[1,1:4], 0, ts, contact_matrix, c, beta, gamma);
for(t in 1:ntime) {
rates[t, 3:4] = recovR(y[t, 3:4], gamma);
rates[t, 1:2] = incidR(y[t, 3:4], y[t+1, 3:4], rates[t, 3:4]);
}
}
model {
// Priors
Nstate0[1] ~ lognormal(log(1e6), 0.5);
Nstate0[2] ~ lognormal(log(1e4), 1);
s_sigma ~ exponential(1);
i_sigma ~ exponential(1);
beta ~ beta(5, 2.5); // Recall that dt is one quarter
gamma ~ normal(4, 1.5); // Weigh on more than a few days recovery (dt is quarter... ~1/0.25)
frac ~ beta(3, 100); // Most likely under 10%
c[1]~ normal(.25, 2); // Lo
c[2]~ normal(4, 2); // Hi
p_s ~ normal(2, 2); // Truncated normal
p_i ~ beta(40, 200);
// Likelihood (a loop; otherwise Stan doesn't know how to multiply)
for (t in 1:ntime_w0) {
pop_sus[t] ~ lognormal(log((y[t,1] + y[t,2]) * p_s), s_sigma); // total S (lo + hi)
// change in incid is 1 less in size than all suscep
if (t < ntime_w0) {
new_cases[t] ~ lognormal(log((rates[t,1] + rates[t,2]) * p_i), i_sigma); // total incidence
}
}
}
generated quantities {
real recov_time = 1 / gamma;
array[ntime_w0] real<lower=0> y_pred_s;
array[ntime] real<lower=0> y_pred_i;
vector[ntime_w0] log_likS;
vector[ntime] log_likI;
vector[ntime+ntime_w0] log_lik;
for (t in 1:ntime_w0) {
y_pred_s[t] = lognormal_rng(log((y[t,1] + y[t,2]) * p_s), s_sigma);
log_likS[t] = lognormal_lpdf(pop_sus[t] | log((y[t,1] + y[t,2]) * p_s), s_sigma);
if (t < ntime_w0) {
y_pred_i[t] = lognormal_rng(log((rates[t,1] + rates[t,2]) * p_i), i_sigma);
log_likI[t] = lognormal_lpdf(new_cases[t] | log((rates[t,1] + rates[t,2]) * p_i), i_sigma);
}
}
log_lik = append_row(log_likS, log_likI); // Combined loglik
}