Chains not mixing, exceeding treedepth in epidemiological model

Hi all

I am trying to fit some epidemiological data to a compartmental model. However, even with only 400 iterations (4 chains), which already takes 12 hours, I run into the following problems:

  • 13 divergent transitions
  • more than 50% of the iterations exceed the maximum treedepth
  • the chains do not mix
  • in a prior predictive check (PPC), all parameters are sampled as they should be; however, the integrate_ode_bdf integrator sometimes throws the error “CVode failed with error flag -1.”, which apparently means it could not reach an accurate solution within the allowed number of steps
  • during the PPC, it also sometimes computes negative counts in comp_diffM, which obviously does not make sense
  • for most parameters, the n_eff is only 2 or 3

A simpler version of the model worked perfectly well. Since I increased its complexity, however, these problems appear, and I cannot find any coding errors. Coming from biology and being new to Stan, I do not really know how to interpret or solve this issue. The code is below; adapt_delta is currently 0.8.

The traceplots look the same for all parameters: the chains never seem to move away from their initial values…

If anyone has seen a similar problem, I would be delighted if you could help.

functions {
  //function for a new behaviour reproductive number
  real modulate_Re(real t, real t1, real t2,real nu1, real nu2,real eta1, real eta2, real xi1, real xi2) {
    real sigup;
    real sigdown;
    sigup=eta1+(1-eta1)/(1+exp(-xi1*(t-t1-nu1)));
    sigdown=(1-eta2)/(1+exp(-xi2*(t-t2-nu2)));
    
    return(sigup-sigdown);
  }
  
  real[] discretize(int max_days, real shape, real rate) { // discretized gamma distribution of the delay between symptom onset and test
    real disc[max_days];
    real normconstant;
    
    disc[1]=gamma_cdf(0.5,shape,rate); // delay of 0 days: only half a day, [0, 0.5], to integrate
    
    for(i in 2:max_days){
      disc[i]=gamma_cdf(i-0.5,shape,rate)-gamma_cdf(i-1.5,shape,rate); // delay of i-1 days: integrate from (i-1)-0.5 to (i-1)+0.5, so no interval is skipped
    }
    normconstant=sum(disc);
    for(i in 1:max_days){
      disc[i]=disc[i]/normconstant; // normalize so the truncated delay distribution sums to 1
    }
    return (disc);
  }
  
  real[] SEIR(real t,
              real[] y,
              real[] theta,
              real[] x_r,
              int[] x_i
  ) {
    int K = x_i[1];
    
    //real tswitch = x_r[1]; // time of control measures
    real dydt[(6*K)]; // SEPIAR (ignoring R) then C 
    real nI; // total infectious
    
    real beta; // transmission rate
    real eta1; // level of Re in August
    real eta2; // level of Re in December
    real t1; // time of increase
    real t2; // time of decrease
    real nu1;//delays in increase and decrease
    real nu2;
    real k1; // level of peak in Re
    real k2; // level of peak subtracted from max
    real xi1; // slope of rise of Re
    real xi2; // slope of decrease of Re
    real tau_1; // infection to preclinical
    real tau_2; // preclinical to symptoms (tau_1+tau_2 = incubation)
    real q_P; // contribution of presymptomatics to transmission
    real gt; // generation time
    real mu; // infectious duration for symptomatics
    real psi; // probability of symptoms
    real kappa; // reduced transmissibility of preclinical and asymptomatics
    real p_tswitch; // switch function
    real contact[K*K]; // contact matrix, first K values, corresponds to number of contact between age class 1 and other classes, etc
    real f_inf[K]; // force of infection
    real init[K*2]; // initial values
    real age_dist[K]; // age distribution of the general population
    real pi; // proportion of the population in the incubating (E) compartment at t0
    
    // Estimated parameters
    beta = theta[1];
    xi1 = theta[2];
    xi2=theta[3];
    pi = theta[4];
    psi = theta[5];
    eta1=theta[6];
    eta2=theta[7];
    nu1=theta[8];
    nu2=theta[9];
    
    // Fixed parameters
    tau_1 = x_r[1];
    tau_2 = x_r[2];
    q_P = x_r[3];
    gt = x_r[4];
    t1=x_r[5];
    t2=x_r[6];
    
    // Composite parameters
    mu = (1-q_P)/(gt-1/tau_1-1/tau_2);
    kappa = (q_P*tau_2*psi)/((1-q_P)*mu-(1-psi)*q_P*tau_2);
    
    // Contact matrix
    contact = x_r[7:(6+K*K)];
    
    // Initial conditions
    for(k in 1:K){
      age_dist[k] = x_r[6+K*K + k];
      init[k] = age_dist[k] * (1-pi);
      init[K+k] = age_dist[k] * pi;
    }
    
    // Time-varying modulation of transmission (p_tswitch)
    p_tswitch = modulate_Re(t,t1,t2,nu1,nu2,eta1,eta2,xi1,xi2);
    
    // Force of infection by age classes: beta * p_tswitch * sum((number of infected people by age + kappa*number of preclinical by age + kappa*number of asympto) / (total number of people by age) * (number of contact by age))
    for(k in 1:K) {
      f_inf[k] = beta * p_tswitch * sum((to_vector(y[(3*K+1):(4*K)])+kappa*to_vector(y[(2*K+1):(3*K)])+kappa*to_vector(y[(4*K+1):(5*K)]))./ to_vector(age_dist) .* to_vector(contact[(K*(k-1)+1):(k*K)])); 
    }
    
    // Compartments
    for (k in 1:K) {
      // S: susceptible
      dydt[k] = - f_inf[k] * (y[k]+init[k]); 
      // E: incubating (not yet infectious)
      dydt[K+k] = f_inf[k] * (y[k]+init[k]) - tau_1 * (y[K+k]+init[K+k]);
      // P: presymptomatic (incubating and infectious)
      dydt[2*K+k] = tau_1 * (y[K+k]+init[K+k]) - tau_2 * y[2*K+k];
      // I: symptomatic
      dydt[3*K+k] = psi * tau_2 * y[2*K+k] - mu * y[3*K+k];
      // A: asymptomatic
      dydt[4*K+k] = (1-psi) * tau_2 * y[2*K+k] - mu * y[4*K+k];
      // C: cumulative number of infections by date of disease onset
      dydt[5*K+k] = psi * tau_2 * y[2*K+k];
    }
    return(dydt);
  }
}

data {
  // Structure
  int K; // number of age classes
  vector[K] age_dist; // age distribution of the population
  int pop_t; // total population
  //real tswitch; // time of introduction of control measures
  // Controls
  real t0; //starting time
  int t_data; //time of first data
  int S;
  real ts[S]; // time bins
  int inference; // 0: simulating from priors; 1: fit to data
  int doprint;
  // Data to fit
  int D; // number of days with reported incidence
  int incidence_cases[D]; // overall daily case incidence over D days
  int incidence_deaths[D]; // overall daily death incidence over D days
  int agedistr_cases[K]; // number of cases at tmax for the K age classes
  int agedistr_deaths[K]; // mortality at tmax for the K age classes
  int tests_age_groups[D,K];
  // Priors
  real p_beta;
  real p_eta[2];
  real p_pi[2];
  real p_epsilon[2];
  //real p_rho[2]; replaced by lambda
  real p_phi;
  real p_xi;
  real p_nu;
  real p_psi[2];
  real p_lambda;//prior for initial slope of saturation function
  // Fixed parameters
  real contact[K*K]; // contact matrix
  real p_q_P; // proportion of transmission that is caused by presymptomatics
  real p_incubation; // incubation period
  real p_preclinical; // preclinical period (part of the incubation with possible transmission)
  real p_generation_time; 
  real p_children_trans; // relative transmissibility in children 1-10
  // Fixed corrections
  real p_report_80plus; // fixed ascertainment proportion for ages 80+
    
  // Fixed delays
  int G;
  int D_max;
  real p_gamma[G]; // from onset to death
  real mean_shape;//mean for prior for shape of gamma which captures testing delay
  real mean_rate;//mean for prior for rate of gamma which captures testing delay
  real p_nu1nu2[2]; // rates of the exponential priors on the delays nu1 and nu2
  real p_eta1eta2[4];//prior parameters for the end and start levels of the sigmoids
  int t1;//estimation of increase time onset
  int t2;// estimation of start of decrease
  
}

transformed data {
  real tau_1 = 1.0 / (p_incubation - p_preclinical);
  real tau_2 = 1.0 / p_preclinical;
  real q_P = p_q_P;
  real gt = p_generation_time;
  real x_r[6+K*K+K]; // 6 parameters + K*K contact matrix parameters + K age_dist parameters
  int x_i[1] = {K};
  real init[K*6] = rep_array(0.0, K*6); // initial values
  real contact2[K*K] = contact;
  for(i in 1:(2*K)) contact2[i] = contact[i] * p_children_trans; // apply lower transmissibility in children
  x_r[1] = tau_1;
  x_r[2] = tau_2;
  x_r[3] = q_P;
  x_r[4] = gt;
  x_r[5] = t1;//time of increase in Re
  x_r[6] = t2;//time of decrease in Re
  x_r[7:(6+K*K)] = contact2;
  for(k in 1:K) {
    x_r[6+K*K+k] = age_dist[k];
  }
}

parameters{
  real<lower=0,upper=1> beta; // base transmission rate
  real<lower=0,upper=1> eta1; // reduction in transmission rate after increased relaxation
  real<lower=0,upper=1> eta2; // reduction in transmission rate after quarantine measures
  real<lower=0,upper=30> nu1; // delay of the increase relative to the manually chosen time t1
  real<lower=0,upper=30> nu2; // delay of the decrease relative to the manually chosen time t2
  
  
  vector<lower=0,upper=1> [K] epsilon; // age-dependent mortality probability
  real<lower=0, upper=1> pi; // proportion of the population in the incubating (E) compartment at t0
  real<lower=0> phi[2]; // overdispersion parameters of the negative binomials (cases, deaths)
  real<lower=0,upper=1> xi1; // slope of Re_increase  
  real<lower=0,upper=1> xi2; // slope of Re_decrease
  
 
  real<lower=0,upper=1> psi; // proportion of symptomatics
  real<lower=0,upper=10> shape; // shape of the gamma testing delay
  real<lower=0,upper=1> rate; // rate of the gamma testing delay
  real<lower=0,upper=1> lambda; //initial slope of the saturation function
}
transformed parameters {
  // parameters packed into an array for the ODE solver
  real theta[9]; // vector of parameters
  real y[S,K*6]; // raw ODE output
  vector[K] comp_C[S+G];
  vector[K] comp_diffC[S+G];
  //vector[K] comp_T[S+D_max];//JE
  vector[K] comp_diffT[S+D_max];//JE
  vector[K] comp_M[S+G];
  vector[K] comp_diffM[S+G];
  vector[D_max] p_delay;
  // outcomes
  vector[K] output_incidence_cases_age[D]; // overall case incidence by day and age group
  vector[D] output_incidence_cases; // overall case incidence by day
  vector[K] output_cum_cases_age[D]; //cumulative number of cases by age
  vector[D] output_incidence_deaths; // overal mortality incidence by day 
  simplex[K] output_agedistr_cases; // final age distribution of cases
  simplex[K] output_agedistr_deaths; // final age distribution of deaths
  
  // pack parameters for integrate_ode_bdf
  theta[1:9] = {beta, xi1,xi2,pi,psi,eta1,eta2,nu1,nu2};
  // run ODE solver
  y = integrate_ode_bdf(
    SEIR, // ODE function
    init, // initial states
    t0, // t0
    ts, // evaluation dates (ts)
    theta, // parameters
    x_r, // real data
    x_i, // integer data
    1.0E-10, 1.0E-10, 1.0E3); // tolerances and maximum steps
    
  // extract and format ODE results (1.0E-9 correction to avoid negative values due to imprecise estimates of zeros, as the tolerance is 1.0E-10)
  for(i in 1:S) {
    comp_C[i] = (to_vector(y[i,(5*K+1):(6*K)]) + 1.0E-9) * pop_t;
    comp_diffC[i] = i==1 ? comp_C[i,] : 1.0E-9*pop_t + comp_C[i,] - comp_C[i-1,]; // lagged difference of cumulative incidence of symptomatics
  }
  
  p_delay = to_vector(discretize(D_max, shape, rate)); // discretized testing-delay distribution, computed once
  
  // Incidence and cumulative incidence after S
  for(g in 1:G) {
    comp_C[S+g] = comp_C[S];
    comp_diffC[S+g] = rep_vector(1.0E-9,K);
  }
  for(i in 1:(S+D_max)){
    comp_diffT[i] = rep_vector(1.0E-9,K);
  }
  for(i in 1:S){
    for(d in 0:(D_max-1)){
      comp_diffT[i+d] += comp_diffC[i] * p_delay[d+1]; // people with onset on day i are tested d days later with probability p_delay[d+1]
    }
  }
  
  // Mortality
  for(i in 1:(S+G)){
    comp_diffM[i] = rep_vector(1.0E-9,K);
  }
  for(i in 1:S){
    for(g in 1:G){
      comp_diffM[i+g] += comp_diffC[i] .* epsilon * p_gamma[g]; //new deaths at i+g are new cases at i*mortality*delay_prob for g
    }
  }
  for(i in 1:(S+G)){
    for(k in 1:K){
      comp_M[i,k] = sum(comp_diffM[1:i,k]);// Compute outcomes
    }
  }
  
  for(i in t_data:S){
    output_incidence_cases_age[i-t_data+1] = to_vector(tests_age_groups[i-t_data+1]) ./ ( to_vector(tests_age_groups[i-t_data+1]) ./ to_vector(comp_diffT[i]) + 1/lambda );
    output_incidence_cases[i-t_data+1] = sum(output_incidence_cases_age[i-t_data+1]);
    for(k in 1:K){
      output_cum_cases_age[i-t_data+1,k] = sum(output_incidence_cases_age[1:(i-t_data+1),k]);// Compute outcomes
    }
    output_incidence_deaths[i-t_data+1] = sum(comp_diffM[i]);
  }
  
  output_agedistr_cases = output_cum_cases_age[D,] ./ sum(output_cum_cases_age[D,]);
  output_agedistr_deaths = comp_M[S] ./ sum(comp_M[S]); // cumulative deaths by age at the last simulated day S

}


model {
  // priors
  beta ~ beta(p_beta,p_beta);
  eta1 ~ beta(p_eta1eta2[1],p_eta1eta2[2]); // levels of the sigmoids; p_eta1eta2 holds the prior parameters for both
  eta2 ~ beta(p_eta1eta2[3],p_eta1eta2[4]);// draw the start and end levels of the sigmoids

  for(k in 1:K){
    epsilon[k] ~ beta(p_epsilon[1],p_epsilon[2]);
  }

  pi ~ beta(p_pi[1],p_pi[2]); // p_pi = c(1, 999)
  phi ~ exponential(p_phi);
  xi1 ~ beta(1,1); //draw for slopes xi1 and xi2 of the modulate_Re function
  xi2 ~ beta(1,1); 
  nu1 ~ exponential(p_nu1nu2[1]); // p_nu1nu2 = c(1/20, 1/30)
  nu2 ~ exponential(p_nu1nu2[2]);
  lambda ~ beta(1,1);
  

  psi ~ beta(p_psi[1],p_psi[2]);
  shape ~ exponential(1/mean_shape);
  rate ~ exponential(1/mean_rate);
  

  // likelihood
  if (inference!=0) {
    for(i in 1:D) {
      target += neg_binomial_2_lpmf( incidence_cases[i] | output_incidence_cases[i], output_incidence_cases[i]/phi[1]);
      target += neg_binomial_2_lpmf( incidence_deaths[i] | output_incidence_deaths[i], output_incidence_deaths[i]/phi[2]);
    }
    target += multinomial_lpmf(agedistr_cases | output_agedistr_cases);
    target += multinomial_lpmf(agedistr_deaths | output_agedistr_deaths);
  }
}
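For checking the testing-delay discretization outside Stan, here is a minimal Python sketch of the same idea. It assumes delay d (in days) should cover [d - 0.5, d + 0.5), with delay 0 covering only [0, 0.5), which is how p_delay[d+1] is used in the comp_diffT loop; in Stan's 1-based indexing, entry i then corresponds to [i - 1.5, i - 0.5). To stay dependency-free, the hypothetical helper takes any CDF as an argument and uses an exponential CDF (a gamma with shape 1) as a stand-in; for the model you would pass a gamma CDF with the fitted shape and rate instead.

```python
import math


def discretize(max_days, cdf):
    """Delay probabilities for delays of 0, 1, ..., max_days - 1 days.

    Delay 0 covers [0, 0.5) and delay d >= 1 covers [d - 0.5, d + 0.5),
    so consecutive intervals meet without gaps. The result is renormalized
    to sum to 1 over the truncated range, as in the Stan function.
    """
    disc = [cdf(0.5)]
    for d in range(1, max_days):
        disc.append(cdf(d + 0.5) - cdf(d - 0.5))
    total = sum(disc)
    return [p / total for p in disc]


def exponential_cdf(rate):
    """Stand-in CDF for the demo (a gamma with shape 1 is an exponential)."""
    return lambda x: 1.0 - math.exp(-rate * x) if x > 0 else 0.0


p_delay = discretize(10, exponential_cdf(0.3))
print([round(p, 3) for p in p_delay])
```

Because the intervals tile [0, max_days - 0.5) without gaps, the unnormalized sum telescopes to cdf(max_days - 0.5), which is an easy sanity check on any implementation.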


I don’t have a lot of time to go through the code, unfortunately, but I have been working with some larger ODE systems recently, and you really need to:

  1. Start with simulated data and try to recover the parameter values. Simulate enough data for the model to pick up the dynamics of the ODE system, so that the parameters are identifiable.

  2. Start simple and build up slowly: add one parameter to the model, or estimate one additional parameter, at a time.

  3. Ensure that all parameters are estimated on a scale that is easy for the HMC algorithm to explore. I usually estimate all parameters on the order of unity with a normal(0, 1) prior and transform them back to the appropriate order of magnitude in the transformed parameters block. For instance, if I have a per-capita transmission rate of a = 0.000001, I’d put the prior a_raw ~ normal(0, 1) on the raw parameter and pass a = a_raw / 1e6, or something similar, to the integrator.
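A toy illustration of point 3, not Stan's actual sampler (Stan uses NUTS with an adapted step size and metric): a plain leapfrog integrator with a unit metric, run on an independent Gaussian target where one parameter has a standard deviation of 1e-6, like the per-capita rate a above. A step size that is fine for the unit-scale parameter makes the energy of the tiny-scale one explode; Stan flags a transition as divergent when this energy error becomes very large. Rescaling to a_raw = a * 1e6 gives both parameters unit scale and removes the problem. The function names, step size 0.5 and 5 steps are arbitrary choices for the example.

```python
import numpy as np


def leapfrog(q, p, grad_u, eps, n_steps):
    """Leapfrog integration of Hamiltonian dynamics with a unit (identity) metric."""
    q, p = q.copy(), p.copy()
    p -= 0.5 * eps * grad_u(q)
    for step in range(n_steps):
        q += eps * p
        if step < n_steps - 1:
            p -= eps * grad_u(q)
    p -= 0.5 * eps * grad_u(q)
    return q, p


def energy_error(sds, eps=0.5, n_steps=5):
    """|H(end) - H(start)| for one trajectory on an independent Gaussian with the given sds."""
    sds = np.asarray(sds, dtype=float)

    def potential(q):
        return 0.5 * np.sum((q / sds) ** 2)

    def grad_potential(q):
        return q / sds**2

    q0 = sds.copy()          # start one sd from the mode in every dimension
    p0 = np.ones_like(sds)   # fixed momentum so the result is deterministic
    q1, p1 = leapfrog(q0, p0, grad_potential, eps, n_steps)
    h0 = potential(q0) + 0.5 * np.sum(p0**2)
    h1 = potential(q1) + 0.5 * np.sum(p1**2)
    return abs(h1 - h0)


# a on its natural scale (sd 1e-6) next to a unit-scale parameter
print(energy_error([1.0, 1e-6]))
# the same model with a_raw = a * 1e6, so both parameters have sd 1
print(energy_error([1.0, 1.0]))
```

Stan's warmup adapts a diagonal metric that can compensate for pure scale differences like this one, but it first has to estimate those scales, and it cannot fix correlations or scales that change across the posterior, so putting parameters on unit scale by hand still helps.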


+1 to everything @cgoold wrote. We tried to summarize many of those strategies for handling misbehaving models at Divergent transitions - a primer

ODE models can be hard to get right, so best of luck!

Thanks to both of you! That raises a general question:
If in a prior predictive check the whole parameter space is sampled as it should be, but the chains get stuck at their initial values when I fit to data, this does not necessarily mean that the model is wrong or that there is a coding error. It could also be that the posterior is too complex to explore, right? And if so, what would be the solution? Reparameterizing? Or would it make sense to use different step sizes and so on?

Unfortunately, there are no general answers here. Yes, it only tells you that something is wrong, but understanding what and where often requires delving deep into the math and structure of the model… or a lot of experimentation, e.g. by adding one parameter at a time to the unknowns, as suggested above.


I agree with @martinmodrak. In my experience, the fact that the prior predictive check runs without issue is a good sign in general (many prior predictive checks return divergent transitions themselves), but if the chains get stuck, the posterior distribution is still hard to sample from.

As an example, I have been working on coupling some food systems models to disease transmission models. The SIR epidemic-type model obviously has a disease outbreak when $\mathcal{R}_0 > 1$, but the disease inevitably dies out due to permanent immunity. In this case, the model produces an epidemic peak and a sharp transient change in certain food system outputs in response to the epidemic peak. When trying to fit this model to simulated data, the chains can become stuck because characterising this transient, sharp change in the food system output is difficult. When I include a lot of simulated data, however, the behaviour of the food system model when not in an epidemic dominates the system (i.e. the epidemic only lasts a short period of time relative to the length of simulated data), and the most plausible food system trajectory is one unaffected by the epidemic.
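To make the transient concrete, here is a minimal sketch of a closed SIR model (no births, permanent immunity) integrated with a fixed-step RK4 scheme in plain Python. The recovery rate gamma = 0.2 per day, the initial infectious fraction of 1e-4 and the 300-day horizon are illustrative assumptions, not values from this thread. With R_0 = 2.5 the infectious fraction peaks early and then decays to essentially zero, so most of a long simulated series carries little information about the epidemic itself; with R_0 < 1 there is no outbreak at all.

```python
def simulate_sir(r0, gamma=0.2, i0=1e-4, days=300, dt=0.1):
    """Daily infectious fraction in a closed SIR model (no births, permanent immunity).

    Uses fixed-step RK4; beta is set so that R0 = beta / gamma.
    """
    beta = r0 * gamma

    def deriv(s, i):
        new_infections = beta * s * i
        return -new_infections, new_infections - gamma * i

    s, i = 1.0 - i0, i0
    trajectory = [i]
    steps_per_day = round(1 / dt)
    for _ in range(days):
        for _ in range(steps_per_day):
            k1 = deriv(s, i)
            k2 = deriv(s + 0.5 * dt * k1[0], i + 0.5 * dt * k1[1])
            k3 = deriv(s + 0.5 * dt * k2[0], i + 0.5 * dt * k2[1])
            k4 = deriv(s + dt * k3[0], i + dt * k3[1])
            s += dt / 6 * (k1[0] + 2 * k2[0] + 2 * k3[0] + k4[0])
            i += dt / 6 * (k1[1] + 2 * k2[1] + 2 * k3[1] + k4[1])
        trajectory.append(i)
    return trajectory


traj = simulate_sir(2.5)
peak_day = traj.index(max(traj))
print(peak_day, max(traj), traj[-1])
```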


Thanks to both of you. After hours of searching, I found that my R_0 can become less than or equal to zero for some unlucky parameter combinations, which obviously should not happen and which generated the exceptions in the PPC. However, I don't really understand how this prevents the algorithm from finding better values, since negative case numbers obviously won't produce a good likelihood…

In any case, the model no longer exceeds the maximum treedepth and is therefore much faster. It does not reproduce the data yet, but I think that is now a problem with my modelling rather than with Stan!
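A minimal Python sketch for hunting down such parameter combinations outside Stan: it ports modulate_Re and the composite parameters mu and kappa from the SEIR function. If I read the force of infection correctly, each infection contributes infectiousness kappa/tau_2 + (psi + (1 - psi)*kappa)/mu, and the kappa formula is chosen so that the presymptomatic share of that total is q_P. The total then equals kappa/(q_P*tau_2), so the reproduction number at time t can only be positive if both p_tswitch(t) and kappa are positive (and mu must be positive for the I and A compartments to drain at all). All numeric values below, and the helper names rates, mu_of, kappa_of and psi_lower_bound, are hypothetical and chosen only for illustration.

```python
import math


def modulate_re(t, t1, t2, nu1, nu2, eta1, eta2, xi1, xi2):
    """Python port of the Stan modulate_Re function (the factor p_tswitch)."""
    sigup = eta1 + (1 - eta1) / (1 + math.exp(-xi1 * (t - t1 - nu1)))
    sigdown = (1 - eta2) / (1 + math.exp(-xi2 * (t - t2 - nu2)))
    return sigup - sigdown


def rates(incubation, preclinical):
    """tau_1 and tau_2 as computed in the transformed data block."""
    return 1.0 / (incubation - preclinical), 1.0 / preclinical


def mu_of(q_p, incubation, preclinical, gt):
    """mu as computed in the SEIR function; negative whenever gt < incubation."""
    tau_1, tau_2 = rates(incubation, preclinical)
    return (1 - q_p) / (gt - 1 / tau_1 - 1 / tau_2)


def kappa_of(q_p, incubation, preclinical, gt, psi):
    """kappa as computed in the SEIR function."""
    _, tau_2 = rates(incubation, preclinical)
    mu = mu_of(q_p, incubation, preclinical, gt)
    return (q_p * tau_2 * psi) / ((1 - q_p) * mu - (1 - psi) * q_p * tau_2)


def psi_lower_bound(q_p, incubation, preclinical, gt):
    """For 0 < q_p < 1 and psi > 0, kappa > 0 exactly when psi exceeds this value."""
    _, tau_2 = rates(incubation, preclinical)
    mu = mu_of(q_p, incubation, preclinical, gt)
    return 1 - (1 - q_p) * mu / (q_p * tau_2)


# Hypothetical values: incubation 5 days, preclinical 2.3 days, q_P = 0.5, gt = 8 days
print(psi_lower_bound(0.5, 5.0, 2.3, 8.0))
print(kappa_of(0.5, 5.0, 2.3, 8.0, psi=0.5))
# Shallow downward sigmoid (xi2 = 0.01, inside the model's bounds) already acting at t = 0,
# long before the hypothetical t2 = 100, while the steep upward sigmoid has not started yet
print(modulate_re(0, 50, 100, 0, 0, 0.1, 0.1, 1.0, 0.01))
```

With these made-up values, any psi below roughly 0.617 gives a negative kappa, and at the bound itself the denominator is zero and kappa is infinite. If a scan like this confirms the problem for the real fixed values, one option is to constrain psi (or the other inputs) so the composite parameters stay positive.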
