Improve specification and sampling of SIS compartmental model

Setup

I will preface this by stating that I am relatively new to Stan, though I have enjoyed it at arm's length via {brms} for some time.

I have gone through the excellent Stan + S(E)IR workbook, and it inspired me to try my hand at a similar model for sexually transmitted infections (STIs). However, I suspect my approach would benefit from some pointers on Stan coding and optimization.

The following R + Stan code takes over 3 hours to fit (6 chains/cores, 2000 iterations), in part due to one pesky slow chain. This may be normal for the amount of data provided, the scope of the priors, and the number of parameters (perhaps it is actually excellent for this posterior?), but I can easily conceive of more complex models in which the compartments are further divided by age, and I worry about the time needed to fit them.

Without going down a rabbit hole about computing resources and the like: is this typical for a model of this size? Are there ways to improve scaling further, or perhaps a different fitting procedure to swap in? Any pointers on improving the Stan code and approach would be greatly appreciated.

Please see below for details; for the full model and computer details, skip to the bottom.

STI Math+Prob Model

To model an STI (let’s say gonorrhea), it is logical to start with SIS compartments. To account for heterogeneous mixing in transmission, a contact matrix is required. Because the observed information comes from surveillance data and the true number of susceptibles is unknown, both are accounted for with measurement-error parameters. Since the data are reported cases (incidence), this also needs to be accounted for, as the true infected population is unknown.

Data

We will use some fake data consisting of 3 years of reported case counts (by quarter) and the region’s total population estimates (as an estimate of the possible susceptibles).

Although I would prefer to treat it as a parameter, I set the contact matrix to preferentially mix within the same category (high risk with high risk = 75%); this is consistent with core-group theory of STI spread.

data_sis <- list(ntime = 12, 
                 new_cases = c(1000, 1050, 1100, 1300, 1350, 1250, 1200, 1220, 1250, 1270, 1400, 1380),
                 pop_sus = c(rep(4e6, 5), rep(4.2e6, 4), rep(4.4e6, 4)), # 13 values: ntime + 1 for the initial condition at t0
                 ts =  1:12, 
                 contact_matrix = matrix(c(0.75, 0.25, 0.25, 0.75), nrow = 2))

This is then provided to Stan in the chunks below…

data {
  int<lower=0> ntime;
  array[ntime] real<lower=0> new_cases; // New infections (reported)
  array[ntime + 1] real<lower=0> pop_sus; // Observed possible susceptibles
  array[ntime] real<lower=0> ts;
  matrix[2,2] contact_matrix;
}

transformed data {
  int<lower=0> ntime_w0 = ntime + 1; // For the initial susceptibles at t0
}

Parameters

There are 11 parameters to estimate before we get to the SIS model. A few of note:

  • p_i: proportion of cases actually captured through reporting (0-1)
  • p_s: ratio of the entire population to those actually susceptible (>0)
  • frac: fraction of the susceptibles in the low- or high-risk STI groups
  • c: average number of contacts in the low- and high-risk groups
parameters {
  vector<lower=0>[2] Nstate0; // Initial Sus and Inf
  real<lower=0> s_sigma; // Overall var S
  real<lower=0> i_sigma; // Overall var I
  real<lower=0,upper=1> beta; // Inf prob (foi = beta * contact rate)
  real<lower=0> gamma; // Recovery rate
  real<lower=0, upper=1> frac; // fraction of high cont
  vector<lower=0>[2] c; // 1 low, 2 high
  real<lower=0> p_s; // Prop of population actually susceptible
  real<lower=0,upper=1> p_i; // Surveillance prop of cases
}

SIS Model Functions

Three functions are used. The first is the SIS compartmental model, which allows for heterogeneous mixing through a matrix of contacts (high- and low-risk groups). The other two calculate the incidence and recoveries, as I did not believe it was straightforward to return these from the SIS ODE function.

functions {
  
  vector sis(real t, vector state, matrix probM, vector c, real beta, real gamma) {
    
    vector[4] dydt;
    
    // Initial predicted states
    vector[2] S = state[1:2]; // Lo, Hi
    vector[2] I = state[3:4]; // Lo, Hi
    vector[2] N = S + I;
    
    // FOI, one for LO, the other for HI (phi[1] = FOI for lo, phi[2] = FOI for hi)
    vector[2] phi = (beta * c) .* to_vector(probM * to_matrix(I ./ N)); 
    vector[2] recov = gamma .* I;
    
    // S states (lo, hi)
    dydt[1] = -(phi[1] *S[1]) + recov[1];
    dydt[2] = -(phi[2] *S[2]) + recov[2];
    
    // I states (lo, hi)
    dydt[3] = (phi[1]*S[1]) - recov[1]; 
    dydt[4] = (phi[2]*S[2]) - recov[2];
    
    return dydt;
  }
  
  // Calc incidence from states and recoveries (since the ODE solver won't return that value)
  vector incidR(vector It1, vector It2, vector recov) {
    vector[2] incid;
    incid[1] = (It2[1] - It1[1]) + recov[1]; //lo
    incid[2] = (It2[2] - It1[2]) + recov[2]; //hi
    return incid;
  }
  
  // Calc recovery
  vector recovR(vector I, real gamma){
    return gamma .* I;
  }  
}

ODE Params

There are four (4) possible states: low or high STI risk for each of the S and I compartments. As the initial at-risk and infected totals are unknown, they are treated as parameters, and the frac parameter is used to split them into low/high risk groups. The ODE solver then populates the remaining states across the time period. The incidence for each step is computed from the ODE states and the recovery rate, as it is needed for the likelihood.

transformed parameters {
  array[ntime_w0] vector<lower=0>[4] y; // S and I states, including t0
  array[ntime] vector<lower=0>[4] rates; // [1:2] incidence, [3:4] recoveries
  
  // Initial conditions
  y[1,1] = frac * Nstate0[1]; // S0lo
  y[1,2] = (1- frac) * Nstate0[1]; // S0hi
  y[1,3] = frac * Nstate0[2]; // I0lo
  y[1,4] = (1- frac) * Nstate0[2]; // I0hi
  
  y[2:ntime_w0, 1:4] = ode_rk45(sis, y[1,1:4], 0, ts, contact_matrix, c, beta, gamma); 
  
  // One step shorter (because incidence is a difference)
  for(t in 1:ntime) {
    rates[t, 3:4] = recovR(y[t, 3:4], gamma); // recoveries at time t
    rates[t, 1:2] = incidR(y[t, 3:4], y[t+1, 3:4], rates[t, 3:4]); // incidence between t and t+1
  }  
}

Priors & Likelihood Model

I’m sure other likelihood models would also work (e.g., the negative binomial), but I decided to use the lognormal for its positivity constraint and simplicity. For example, the observed cases are modeled as:

\text{new\_cases}_t \sim \text{LogNormal}(\log(p_i \cdot \text{Incid}_t), \sigma_i)
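
For comparison, the negative binomial alternative mentioned above would look something like this in the model block (a sketch only: it assumes the case counts are declared as integers, say array[ntime] int<lower=0> new_cases_int, with an extra dispersion parameter phi_nb; neither appears in the model below):

for (t in 1:ntime) {
  // hypothetical names: new_cases_int and phi_nb are not part of the model below
  new_cases_int[t] ~ neg_binomial_2(rates[t,1] * p_i, phi_nb);
}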

The priors for the initial states differ by a large order of magnitude, based upon what we observe for cases and population size. The transmission probability is assumed to be high when contacts do occur. The average number of contacts in the high-risk group is assumed to be several times larger than in the low-risk group. The fraction of susceptibles expected to be at high risk of STIs is under 10%.

The log-likelihoods are calculated in the generated quantities block using lognormal_lpdf for both the s_t and i_t fits, and then appended together.

model {
  // Priors
  Nstate0[1] ~ lognormal(log(1e6), 0.5);
  Nstate0[2] ~ lognormal(log(1e4), 1);
  s_sigma ~ exponential(1);
  i_sigma ~ exponential(1);
  beta ~ beta(5, 2.5); // Recall dt is per quarter
  gamma ~ normal(4, 1.5); // Recovery rate, time period by quarter
  frac ~ beta(3, 100); // Most likely under 10%
  c[1] ~ normal(0.25, 2); // Lo
  c[2] ~ normal(4, 2); // Hi
  p_s ~ normal(2, 2); // Truncated normal
  p_i ~ beta(40, 200);

  // Likelihood (only worked as a loop)
  for (t in 1:ntime_w0) {
    pop_sus[t] ~ lognormal(log(y[t,1] * p_s), s_sigma);
    // the incidence series is one step shorter than the susceptible series
    if (t < ntime_w0) {
      new_cases[t] ~ lognormal(log(rates[t,1] * p_i), i_sigma);
    }
  }
}

Run Model

Because earlier runs were difficult to fit, adapt_delta and max_treedepth were increased. Fewer iterations and more cores were used to help with speed.

gono_model <- rstan::stan_model('mystanmodel.stan')
rstan::sampling(gono_model,
                data = data_sis,
                seed = 1294, cores = 6, chains = 6, iter = 2000,
                control = list(adapt_delta = 0.99,
                               max_treedepth = 15))

Computer Details

  • R: 4.3.1
  • RStan version: 2.32.6
  • RStudio: 2023.06.1
  • OS: Windows 10
  • CPU: Intel i9-10900KF @ 3.7 GHz
  • RAM: 60 GB

Full Model

functions {
  
  vector sis(real t, vector state, matrix probM, vector c, real beta, real gamma) {
    
    vector[4] dydt;
    
    // Initial predicted states
    vector[2] S = state[1:2]; // Lo, Hi
    vector[2] I = state[3:4]; // Lo, Hi
    vector[2] N = S + I;
    
    // FOI, one for LO, the other for HI (phi[1] = FOI for lo, phi[2] = FOI for hi)
    vector[2] phi = (beta * c) .* to_vector(probM * to_matrix(I ./ N)); 
    vector[2] recov = gamma .* I;
    
    // S states (lo, hi)
    dydt[1] = -(phi[1] *S[1]) + recov[1];
    dydt[2] = -(phi[2] *S[2]) + recov[2];
    
    // I states (lo, hi)
    dydt[3] = (phi[1]*S[1]) - recov[1]; 
    dydt[4] = (phi[2]*S[2]) - recov[2];
    
    return dydt;
  }
  
  // Calc incidence from states and recoveries (since the ODE solver won't return that value)
  vector incidR(vector It1, vector It2, vector recov) {
    vector[2] incid;
    incid[1] = (It2[1] - It1[1]) + recov[1]; //lo
    incid[2] = (It2[2] - It1[2]) + recov[2]; //hi
    return incid;
  }
  
  // Calc recovery
  vector recovR(vector I, real gamma){
    return gamma .* I;
  }  
}

data {
  int<lower=0> ntime;
  array[ntime] real<lower=0> new_cases; // New infections (reported)
  array[ntime + 1] real<lower=0> pop_sus; // Observed possible susceptibles
  array[ntime] real<lower=0> ts;
  matrix[2,2] contact_matrix;
}

transformed data {
  int<lower=0> ntime_w0 = ntime + 1;
}

parameters {
  vector<lower=0>[2] Nstate0; // Initial Sus and Inf
  real<lower=0> s_sigma; // Overall var S
  real<lower=0> i_sigma; // Overall var I
  real<lower=0,upper=1> beta; // Inf prob (foi = beta * contact rate)
  real<lower=0> gamma; // Recovery rate
  real<lower=0, upper=1> frac; // fraction of high cont
  vector<lower=0>[2] c; // 1 low, 2 high
  real<lower=0> p_s; // Prop of population actually susceptible
  real<lower=0,upper=1> p_i; // Surveillance prop of cases
}

transformed parameters {
  array[ntime_w0] vector<lower=0>[4] y; // S and I states, including t0
  array[ntime] vector<lower=0>[4] rates; // [1:2] incidence, [3:4] recoveries
  
  // Initial conditions, from predicted totals...
  y[1,1] = frac * Nstate0[1]; // S0lo
  y[1,2] = (1- frac) * Nstate0[1]; // S0hi
  y[1,3] = frac * Nstate0[2]; // I0lo
  y[1,4] = (1- frac) * Nstate0[2]; // I0hi
  
  y[2:ntime_w0, 1:4] = ode_rk45(sis, y[1,1:4], 0, ts, contact_matrix, c, beta, gamma); 
  
  for(t in 1:ntime) {
    rates[t, 3:4] = recovR(y[t, 3:4], gamma); // recoveries at time t
    rates[t, 1:2] = incidR(y[t, 3:4], y[t+1, 3:4], rates[t, 3:4]); // incidence between t and t+1
  }  
}

model {
  // Priors
  Nstate0[1] ~ lognormal(log(1e6), 0.5);
  Nstate0[2] ~ lognormal(log(1e4), 1);
  s_sigma ~ exponential(1);
  i_sigma ~ exponential(1);
  beta ~ beta(5, 2.5); // Recall dt is per quarter
  gamma ~ normal(4, 1.5); // Weight toward more than a few days' recovery (dt is a quarter, so ~1/0.25)
  frac ~ beta(3, 100); // Most likely under 10%
  c[1] ~ normal(0.25, 2); // Lo
  c[2] ~ normal(4, 2); // Hi
  p_s ~ normal(2, 2); // Truncated normal
  p_i ~ beta(40, 200);

  // Likelihood (loop; otherwise Stan doesn't know how to multiply)
  for (t in 1:ntime_w0) {
    pop_sus[t] ~ lognormal(log(y[t,1] * p_s), s_sigma);
    // the incidence series is one step shorter than the susceptible series
    if (t < ntime_w0) {
      new_cases[t] ~ lognormal(log(rates[t,1] * p_i), i_sigma);
    }
  }
}

generated quantities {
  real recov_time = 1 / gamma;  
  array[ntime_w0] real<lower=0> y_pred_s;
  array[ntime] real<lower=0> y_pred_i;  
  vector[ntime_w0] log_likS;
  vector[ntime] log_likI;
  vector[ntime+ntime_w0] log_lik;
  
  for (t in 1:ntime_w0) { 
    y_pred_s[t] = lognormal_rng(log(y[t,1]* p_s), s_sigma);
    log_likS[t] = lognormal_lpdf(pop_sus[t] | log(y[t,1]*p_s), s_sigma);
    if (t < ntime_w0) {
      y_pred_i[t] = lognormal_rng(log(rates[t,1]* p_i), i_sigma);
      log_likI[t] = lognormal_lpdf(new_cases[t] | log(rates[t,1]*p_i), i_sigma);
    }
  }  
  log_lik = append_row(log_likS, log_likI); // Combined loglik  
}

A quick update: although the model still takes a long time to run (over 2 hours instead of over 3), further refining the priors has helped reduce the sampling time. Furthermore, I believe I had to sum the two risk categories in the likelihood: sum(rates[t,1:2]).
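
In other words, the case likelihood line in the model block becomes:

new_cases[t] ~ lognormal(log(sum(rates[t,1:2]) * p_i), i_sigma);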

Thanks for reporting back, and sorry that nobody answered. The question is well posed and thorough, but answering this kind of involved question is time consuming.

If you have to set max_treedepth = 15, there’s probably an issue with your model parameterization leading to extreme curvature in some regions. That can often arise from hierarchical priors, but you don’t seem to be using centered parameterizations anywhere.

The other issue that can come up is that the ODE will be stiff in the initialization region. Sometimes we’ve found moving to our stiff ODE integrators can help with that and cut down on overall fit time, even though they’re slower than the non-stiff integrators for the bulk of the probability mass.
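
For instance, a minimal sketch of swapping in the stiff BDF solver in the transformed parameters block (ode_bdf takes the same arguments as ode_rk45):

y[2:ntime_w0, 1:4] = ode_bdf(sis, y[1,1:4], 0, ts, contact_matrix, c, beta, gamma);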

Something like this is going to be generally problematic:

Nstate0[1] ~ lognormal(log(1e6), 0.5);

The mean on the unconstrained scale is around 14, so it can help to add multipliers to the parameters. But the problem here is that we can’t combine offset/multiplier with the lower=0 constraint, so it requires something like this:

parameters{
  vector<offset=log(1e6), multiplier=0.5>[2] log_Nstate0;
  ...
transformed parameters {
  vector<lower=0>[2] Nstate0 = exp(log_Nstate0);
  ...   
model {
  log_Nstate0 ~ normal(log(1e6), 0.5);
  ...

That will make sure it’s initialized in a sensible place, too. Otherwise, if you’re not initializing explicitly, this can be a huge help in models that otherwise present problems when initialized randomly. One thing you can do is run the model, extract where it ended up, and use that as the init going forward (see the sketch below).
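
A minimal R sketch of that warm-start idea, using a short pilot run and posterior means as inits (parameter names are from the model above; medians would work just as well):

# short pilot run to find a sensible region
fit0 <- rstan::sampling(gono_model, data = data_sis,
                        chains = 1, iter = 500, seed = 1294)
draws <- rstan::extract(fit0)

# posterior means as initial values for the full run
init_fun <- function() {
  list(Nstate0 = colMeans(draws$Nstate0),
       s_sigma = mean(draws$s_sigma), i_sigma = mean(draws$i_sigma),
       beta = mean(draws$beta), gamma = mean(draws$gamma),
       frac = mean(draws$frac), c = colMeans(draws$c),
       p_s = mean(draws$p_s), p_i = mean(draws$p_i))
}

rstan::sampling(gono_model, data = data_sis, init = init_fun,
                seed = 1294, cores = 6, chains = 6, iter = 2000)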

Some of the functions you’re writing down can be done much more easily, as in

dydt[1:2] = -phi .* S + recov;
dydt[3:4] = -dydt[1:2];

Another thing that can really speed up ODE models is within-chain parallelization. If you have to solve a bunch of diff eqs, parallelizing can be really effective.
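
With only 12 observations the likelihood here won’t gain much, but for reference, a reduce_sum sketch over the case likelihood (the case_ll helper is hypothetical, and threading must be enabled when compiling):

functions {
  // partial log-likelihood over a slice of time points (hypothetical helper)
  real case_ll(array[] real cases_slice, int start, int end,
               array[] vector rates, real p_i, real i_sigma) {
    real lp = 0;
    for (t in start:end)
      lp += lognormal_lpdf(cases_slice[t - start + 1] | log(rates[t,1] * p_i), i_sigma);
    return lp;
  }
}
model {
  // replaces the case loop in the model block above
  target += reduce_sum(case_ll, new_cases, 1, rates, p_i, i_sigma);
}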


Thank you for your detailed reply and helpful input, @Bob_Carpenter; evidently I have more to learn.

I will try these suggestions and report back on any improvements.

A couple of observations and actions I have taken since my last post:

  • Tinkering with init values did not appear to help considerably (I will need to get timings).
  • Simplifying by reducing the number of parameters; hard-coding frac, which determines the fraction of people in the two risk categories, nearly cut the run time in half (see the sketch after this list).
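
A minimal sketch of that hard-coding, moving frac into transformed data (the fixed value of 0.05 is illustrative only):

transformed data {
  int<lower=0> ntime_w0 = ntime + 1;
  real<lower=0, upper=1> frac = 0.05; // fixed instead of estimated (illustrative value)
}

with frac then removed from the parameters block and its prior dropped from the model block.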

A couple of follow-up questions and things to try:

  • Would trying a centered parameterization be worthwhile? This is not something I have tried manually in Stan before.
  • Would a different distribution besides the lognormal/normal help, such as the negative binomial?

You can speed up the ODE computation by using looser (higher) tolerances, and then use the Pareto-k diagnostic to check whether the change of tolerance affects the results. See our paper, An importance sampling approach for reliable and efficient inference in Bayesian ordinary differential equation models.
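
For example, a sketch of the solver call with explicit, looser tolerances (ode_rk45_tol takes rel_tol, abs_tol, and max_num_steps before the extra arguments; the values here are illustrative, not a recommendation):

y[2:ntime_w0, 1:4] = ode_rk45_tol(sis, y[1,1:4], 0, ts,
                                  1e-4, 1e-4, 1000, // rel_tol, abs_tol, max_num_steps (illustrative)
                                  contact_matrix, c, beta, gamma);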