Code review of slow hierarchical (ODE) model

Hi all,

I’m having trouble getting reasonable speed from an ODE model I’m working on, and it’s throwing a lot of divergent-transition warnings.
I have a version that fits a single ODE curve and runs at an acceptable speed (~10 min for 500 samples). I’ve since tried expanding it to fit multiple ODE curves (usually 4-5), assuming that the parameters of the curves are related to each other hierarchically through a multivariate normal distribution. However, this version takes several hours to run with a similar number of samples, and it often fails to converge at all.
I’m hoping I’ve just done something silly with the parametrization of the model. If anyone with a bit more experience could look over the Stan code below, I’d be very grateful.

//Taken from https://jrmihalj.github.io/estimating-transmission-by-fitting-mechanistic-models-in-Stan/

functions {
 
  // An ODE function for generating an epidemic curve given four parameters
  real[] SEIR(
           real t,       // time
           real[] y,      // state
           real[] theta,  // parameters
           data real[] x_r,    // data (real)
           data int[] x_i) {   // data (integer)
            
            
    
    real dydt[6];
    
    // state
    real S;
    real E;
    real I_symp;
    real I_asymp;
    real R1;
    real R2;
    
    // parameters
    real epsilon;
    real p;
    real theta_local;
    
    
    // fixed parameters for the model
    real p_lower_inf;  // lower infection for asymptomatic ? 
    real eta;  //1 / latency phase 
    real p_symp; // probability of being symptomatic
    real gammaD;  // 1/length of infectiousness
    real gamma_pos;  // 1/length of being positive

    real N;
    real t_today;
    
    
    real b_t;
    
    S = y[1];
    E = y[2];
    I_symp = y[3];
    I_asymp = y[4];
    R1 = y[5];
    R2 = y[6];
    
    
    p_lower_inf= x_r[1];
    eta= x_r[2];
    gammaD= x_r[3];
    gamma_pos= x_r[4];
    
    t_today=x_i[1];
    N = x_i[2];
    
    p=theta[1];
    epsilon=theta[2];
    theta_local=theta[3];
    p_symp= theta[4];

    b_t = ((1-p)/(1+exp(-epsilon*((t-t_today))))+p)*theta_local;
      
    
    dydt[1] = -b_t * S * I_symp/N - p_lower_inf*b_t * S * I_asymp/N;
    dydt[2] =  b_t * S * I_symp/N + p_lower_inf*b_t * S * I_asymp/N - eta*E;
    dydt[3] =  p_symp * eta * E      - gammaD * I_symp;
    dydt[4] =  (1 - p_symp)* eta * E - gammaD * I_asymp;
    dydt[5] = gammaD * (I_symp + I_asymp) - gamma_pos * R1;
    dydt[6] =  gamma_pos * R1;
    
    return dydt;
  }
  
}

data {
  int<lower=1> nRegions;
  int<lower=1> maxObs;
  
  int<lower=1> nObs[nRegions]; // number of observations per region
  real i0[nRegions]; // starting (observed) incidence
  
  real  t0[nRegions];    // starting time
  matrix[maxObs,nRegions] ts; //time points for observations
  matrix[maxObs,nRegions] incidence;   // observed incidence values over time

   // fixed parameters for the model
  real p_lower_inf;  // lower infection for asymptomatic ? 
  real eta;  //1 / latency phase 
  real gammaD;  // 1/length of infectiousness
  real gamma_pos;  // 1/length of being positive

  int t_today;  // time point of lockdown
  int N[nRegions];       // Population size
  
  
  
  // Data for surveys 
  
  int n_surveys;
  int survey_counts[2,n_surveys];
  int survey_t[2,n_surveys];
  int survey_regions[n_surveys];
}
transformed data {
  vector[4] theta_fixed[nRegions];
  int theta_int[nRegions,2];
    
  for(r in 1:nRegions){
    theta_fixed[r][1] = p_lower_inf;
    theta_fixed[r][2] = eta;
    theta_fixed[r][3] = gammaD;
    theta_fixed[r][4] = gamma_pos;
    
    theta_int[r,1] = t_today;
    theta_int[r,2] = N[r];
  }
  
}
parameters {
    // following https://mc-stan.org/docs/2_19/stan-users-guide/multivariate-hierarchical-priors-section.html   
    
    cholesky_factor_corr[4] Omega;
    vector<lower=0>[4] tau;
                                                                               
    vector[4] gamma;           // group coeffs

    vector[4] theta_untransformed[nRegions];
    
    real<lower=0> sigma;  // not used anywhere else in the model, so it has no prior and an improper (flat) posterior
 
}

transformed parameters{
  
    vector[4] theta[nRegions];
    vector[6] y0[nRegions]; // starting state of the SEIR
    matrix[maxObs,6] y_hat[nRegions];   // S,E,Is,Ia,R1,R2 values over time
    vector[maxObs] inc_hat[nRegions];
    
    real Pos_hat_surveys[n_surveys];
    
    vector[maxObs] Pos_hat[nRegions];  // estimated number that would test positive in a survey. 
    vector[maxObs] sq_err[nRegions];
   
    matrix[4,4] sigma_pars;
   

    for(r in 1:nRegions){
    theta[r][1]=inv_logit(theta_untransformed[r][1]); //p
    theta[r][2]=theta_untransformed[r][2]; //epsilon
    theta[r][3]=exp(theta_untransformed[r][3]); //theta_local
    theta[r][4]=inv_logit(theta_untransformed[r][4]); //p_symp
   
   
    y0[r][1]= (N[r] - i0[r]*(1 + (1-theta[r][4])/theta[r][4]));
    y0[r][2] = 0;
    y0[r][3] = i0[r];
    y0[r][4] = i0[r]*(1-theta[r][4])/theta[r][4]; 
    y0[r][5] = 0;
    y0[r][6] = 0;
          
    y_hat[r][1:nObs[r],1:6] = to_matrix(integrate_ode_rk45(SEIR, to_array_1d(y0[r]), t0[r], to_array_1d(ts[1:nObs[r],r]), to_array_1d(theta[r]), 
    to_array_1d(theta_fixed[r]), to_array_1d(theta_int[r]))); // pass all four fixed rates: SEIR() reads x_r[1] to x_r[4]

    Pos_hat[r][1:nObs[r]]    =  to_vector(y_hat[r][1:nObs[r],3])+ to_vector(y_hat[r][1:nObs[r],4]) +
                                to_vector(y_hat[r][1:nObs[r],5]);
    
    inc_hat[r][1:nObs[r]]= (eta*theta[r][4])*to_vector(y_hat[r][1:nObs[r],2]); 
                                                         
    for(t in 1:nObs[r]){
      sq_err[r][t]=(incidence[t,r] - inc_hat[r][t])^2;
      
    }
    
  }
  
   sigma_pars = diag_pre_multiply(tau, Omega);
  
  for(s in 1:n_surveys){
    // survey_t is declared as int survey_t[2,n_surveys], so its columns are indexed by survey s
    Pos_hat_surveys[s] = mean(Pos_hat[survey_regions[s]][survey_t[1,s]:survey_t[2,s]]);
  }
        
        
}
model {
  
  //using cholesky decomposition per 
  //https://discourse.mc-stan.org/t/trouble-with-prior-selection-for-multivariate-normal-inverse-wishart-analysis/6088/13
                                                
  tau ~ cauchy(0, 2.5);
  Omega ~ lkj_corr_cholesky(2);
  
  to_vector(gamma) ~ normal(0, 5);
  // print(tau);  // debug output: printing on every log-density evaluation floods the console and slows sampling
  for(r in 1:nRegions){      
  theta_untransformed[r] ~ multi_normal_cholesky(gamma, sigma_pars);
  }
  
 for(r in 1:nRegions){
 
  target += - nObs[r] * log(sum(to_vector(sq_err[r][1:nObs[r]])))/2;  
 }
 
  for(s in 1:n_surveys){
  survey_counts[1,s] ~ binomial(survey_counts[2,s],Pos_hat_surveys[s]/N[survey_regions[s]]);
  }
  
 
}
generated quantities {
    vector<lower=0,upper=1>[nRegions] p;        //  for infectivity function
    vector[nRegions] epsilon;  //  for infectivity function - on the whole real line
    vector[nRegions] theta_local;  //  for infectivity function - positive (exp of the unconstrained value)
    vector<lower=0,upper=1>[nRegions] p_symp;  // proportion of symptomatic cases
    
     for(r in 1:nRegions){
     p[r]=theta[r][1] ;
     epsilon[r]=theta[r][2] ;
     theta_local[r]=theta[r][3] ;
     p_symp[r]=theta[r][4] ;
     }
}
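The per-region term target += - nObs[r] * log(sum(sq_err[r][1:nObs[r]]))/2 in the model block looks like the Normal likelihood of the squared errors with the residual standard deviation integrated out under a 1/sigma prior (which, up to a constant, is also the profile likelihood). If that reading is right, it would explain why the declared sigma parameter never appears in the model. The Python sketch below (standard library only; log_marginal_sigma is a hypothetical helper, and n and the sums of squares are arbitrary illustration values) checks numerically that the integral over sigma changes with the sum of squares exactly as -(n/2)*log(sum of squares) does:

```python
import math

def log_marginal_sigma(n, sse, lo=-10.0, hi=10.0, steps=20000):
    """Log of the integral over sigma in (0, inf) of sigma^-n * exp(-sse / (2 sigma^2)) / sigma.

    sigma^-n * exp(-sse / (2 sigma^2)) is the Normal likelihood of n residuals with
    sum of squares sse (constants dropped); 1/sigma is the improper Jeffreys prior.
    The integral is computed on t = log(sigma) with the trapezoidal rule, in log
    space so that large n does not overflow.
    """
    h = (hi - lo) / steps
    log_f = []
    for k in range(steps + 1):
        t = lo + k * h
        # sigma^-(n+1) * exp(-sse / (2 sigma^2)), times the Jacobian d(sigma)/dt = sigma
        log_f.append(-n * t - 0.5 * sse * math.exp(-2.0 * t))
    m = max(log_f)
    total = sum((0.5 if k in (0, steps) else 1.0) * math.exp(v - m)
                for k, v in enumerate(log_f))
    return m + math.log(total * h)

# The Stan term is -(n/2) * log(sse); compare how both change when sse is quadrupled.
n = 109
stan_term = lambda sse: -n / 2 * math.log(sse)
delta_numeric = log_marginal_sigma(n, 1000.0) - log_marginal_sigma(n, 4000.0)
delta_stan = stan_term(1000.0) - stan_term(4000.0)
print(delta_numeric, delta_stan)   # the two differences agree
```

If so, the leftover sigma has neither a prior nor a likelihood contribution, so its posterior is flat on (0, infinity); removing it (together with its entry in the init function) may be worth trying.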


It seems that I can’t edit my above post anymore.

In case it helps anyone, I’ve included the data I’m using at the bottom, together with an initial value function. These initial values are close to the “true” values to the best of my knowledge, but Stan still has huge problems sampling: it can take up to half an hour just to get 20 samples, likely because most proposals are rejected. I’ve tried loosening the tolerance to 1e-3, which helps a bit, but it’s still very slow.
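Why loosening the tolerance helps, and what "Failed to integrate to next output time ... in less than max_num_steps steps" means: an adaptive solver shrinks its step until the local error estimate meets the tolerance, so a tighter tolerance (or faster dynamics) costs more steps, and a step budget turns that into an error. The toy Python sketch below is not Stan's rk45; it assumes a simple embedded Euler/Heun pair for brevity, and integrate_adaptive and MaxStepsExceeded are hypothetical names.

```python
import math

class MaxStepsExceeded(RuntimeError):
    """Raised when the step budget runs out before reaching the output time."""

def integrate_adaptive(f, y0, t0, t1, rtol, atol=1e-10, max_num_steps=100000, h=0.1):
    """Integrate the scalar ODE dy/dt = f(t, y) from t0 to t1 with an embedded Euler/Heun pair.

    Returns (y(t1), number of attempted steps). The difference between the
    1st-order (Euler) and 2nd-order (Heun) estimates is the local error estimate.
    """
    t, y, steps = t0, y0, 0
    while t < t1:
        if steps >= max_num_steps:
            raise MaxStepsExceeded(
                f"failed to integrate to t={t1} in less than {max_num_steps} steps")
        steps += 1
        h = min(h, t1 - t)
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        y_low = y + h * k1                   # Euler estimate
        y_high = y + 0.5 * h * (k1 + k2)     # Heun estimate
        err = abs(y_high - y_low)
        tol = atol + rtol * abs(y_high)
        if err <= tol:                       # accept the step, otherwise retry with smaller h
            t, y = t + h, y_high
        # shrink or grow h so the next error estimate lands near the tolerance
        factor = 0.9 * math.sqrt(tol / err) if err > 0 else 2.0
        h *= min(2.0, max(0.2, factor))
    return y, steps

decay = lambda t, y: -5.0 * y                # fast decay: y(t) = exp(-5 t)
y_loose, n_loose = integrate_adaptive(decay, 1.0, 0.0, 1.0, rtol=1e-3)
y_tight, n_tight = integrate_adaptive(decay, 1.0, 0.0, 1.0, rtol=1e-6)
print(n_loose, n_tight)   # the tighter tolerance needs many more steps
```

With a small max_num_steps the tight-tolerance call raises MaxStepsExceeded instead of returning, which is the same failure mode as in the log messages.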

Typical error messages look like the following (the binomial errors are likely due to the ODE solution over- or undershooting, so that some compartments become negative or larger than the total population; they disappear after a while):

Chain 4 Rejecting initial value:
Chain 4   Gradient evaluated at the initial value is not finite.
Chain 4   Stan can't start sampling from this initial value.
Chain 1 Exception: integrate_ode_rk45:  Failed to integrate to next output time (53) in less than max_num_steps steps (in 'C:/Users/Owner/AppData/Local/Temp/RtmpmEqb6q/model-8b6869886312.stan', line 77, column 4 to column 16)
Chain 1 Exception: integrate_ode_rk45:  Failed to integrate to next output time (53) in less than max_num_steps steps (in 'C:/Users/Owner/AppData/Local/Temp/RtmpmEqb6q/model-8b6869886312.stan', line 77, column 4 to column 16)
Warning: Chain 1 finished unexpectedly!

Chain 4 Exception: integrate_ode_rk45:  Failed to integrate to next output time (56) in less than max_num_steps steps (in 'C:/Users/Owner/AppData/Local/Temp/RtmpmEqb6q/model-8b6869886312.stan', line 77, column 4 to column 16)
Chain 4 Exception: integrate_ode_rk45:  Failed to integrate to next output time (56) in less than max_num_steps steps (in 'C:/Users/Owner/AppData/Local/Temp/RtmpmEqb6q/model-8b6869886312.stan', line 77, column 4 to column 16)
Chain 4 Exception: integrate_ode_rk45:  Failed to integrate to next output time (56) in less than max_num_steps steps (in 'C:/Users/Owner/AppData/Local/Temp/RtmpmEqb6q/model-8b6869886312.stan', line 77, column 4 to column 16)
Chain 4 Exception: integrate_ode_rk45:  Failed to integrate to next output time (56) in less than max_num_steps steps (in 'C:/Users/Owner/AppData/Local/Temp/RtmpmEqb6q/model-8b6869886312.stan', line 77, column 4 to column 16)
Warning: Chain 4 finished unexpectedly!

Chain 2 Informational Message: The current Metropolis proposal is about to be rejected because of the following issue:
Chain 2 Exception: integrate_ode_rk45:  Failed to integrate to next output time (67) in less than max_num_steps steps (in 'C:/Users/Owner/AppData/Local/Temp/RtmpmEqb6q/model-8b6869886312.stan', line 77, column 4 to column 16)
Chain 2 If this warning occurs sporadically, such as for highly constrained variable types like covariance matrices, then the sampler is fine,
Chain 2 but if this warning occurs often then your model may be either severely ill-conditioned or misspecified.
Chain 2 
Chain 3 Exception: integrate_ode_rk45:  Failed to integrate to next output time (106) in less than max_num_steps steps (in 'C:/Users/Owner/AppData/Local/Temp/RtmpmEqb6q/model-8b6869886312.stan', line 77, column 4 to column 16)
Chain 3 Exception: integrate_ode_rk45:  Failed to integrate to next output time (106) in less than max_num_steps steps (in 'C:/Users/Owner/AppData/Local/Temp/RtmpmEqb6q/model-8b6869886312.stan', line 77, column 4 to column 16)
Warning: Chain 3 finished unexpectedly!

Chain 2 Informational Message: The current Metropolis proposal is about to be rejected because of the following issue:
Chain 2 Exception: binomial_lpmf: Probability parameter is -1.28842e-10, but must be in the interval [0, 1] (in 'C:/Users/Owner/AppData/Local/Temp/RtmpmEqb6q/model-8b6869886312.stan', line 212, column 2 to column 92)
Chain 2 If this warning occurs sporadically, such as for highly constrained variable types like covariance matrices, then the sampler is fine,
Chain 2 but if this warning occurs often then your model may be either severely ill-conditioned or misspecified.
Chain 2 
Chain 2 Informational Message: The current Metropolis proposal is about to be rejected because of the following issue:
Chain 2 Exception: integrate_ode_rk45:  Failed to integrate to next output time (55) in less than max_num_steps steps (in 'C:/Users/Owner/AppData/Local/Temp/RtmpmEqb6q/model-8b6869886312.stan', line 77, column 4 to column 16)
Chain 2 If this warning occurs sporadically, such as for highly constrained variable types like covariance matrices, then the sampler is fine,
Chain 2 but if this warning occurs often then your model may be either severely ill-conditioned or misspecified.
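One possible guard against tiny negative probabilities such as the -1.28842e-10 in the binomial_lpmf message above is to keep the survey probability strictly inside (0, 1) before it reaches the binomial; in Stan that could be written with fmin and fmax. The Python sketch below mirrors the Pos_hat_surveys / binomial step under that assumption. The function name and the eps floor are hypothetical, indices are 0-based, and the survey window is read from column s of survey_t (the Stan code indexes it by survey_regions[s], which looks like a slip since survey_t is declared as [2, n_surveys]).

```python
def survey_probability(pos_hat, survey_t, survey_regions, N, s, eps=1e-12):
    """Expected fraction testing positive in survey s (0-based indices).

    pos_hat[r]        : model-predicted number of positives in region r, one per observation
    survey_t[0][s], survey_t[1][s] : first and last observation index of survey s (inclusive)
    survey_regions[s] : region covered by survey s
    N[r]              : population of region r
    eps               : hypothetical floor/ceiling so solver round-off cannot push p outside (0, 1)
    """
    r = survey_regions[s]
    first, last = survey_t[0][s], survey_t[1][s]   # window belongs to survey s, not region r
    window = pos_hat[r][first:last + 1]
    p = sum(window) / len(window) / N[r]
    return min(max(p, eps), 1.0 - eps)

pos_hat = [[0.0, 10.0, 20.0, 30.0, 40.0]]
survey_t = [[0, 3], [1, 4]]        # survey 0 covers obs 0..1, survey 1 covers obs 3..4
print(survey_probability(pos_hat, survey_t, [0, 0], [100.0], 0))
print(survey_probability(pos_hat, survey_t, [0, 0], [100.0], 1))
```

Clamping hides the symptom rather than the cause; if these rejections are frequent, the underlying ODE accuracy problem still needs attention.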

R code for reading in the data and running the model from the first post (the Stan file is assumed to be saved at the path stored in modfile).

dat<-list(nRegions = 3, nObs = c(109L, 96L, 109L), maxObs = 109L, 
    i0 = c(1L, 1L, 1L), t0 = c(48, 61, 48), ts = structure(c(49, 
    50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 
    65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 
    80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 
    95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 
    108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 
    120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 
    132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 
    144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 
    156, 157, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 
    74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 
    89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 
    103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 
    115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 
    127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 
    139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 
    151, 152, 153, 154, 155, 156, 157, Inf, Inf, Inf, Inf, Inf, 
    Inf, Inf, Inf, Inf, Inf, Inf, Inf, Inf, 49, 50, 51, 52, 53, 
    54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 
    69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 
    84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 
    99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 
    111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 
    123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 
    135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 
    147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157), .Dim = c(109L, 
    3L)), incidence = structure(c(0, 0, 0, 1, 1, 0, 1, 0, 1, 
    2, 6, 28, 5, 3, 5, 6, 12, 10, 29, 26, 29, 31, 44, 46, 45, 
    32, 55, 60, 58, 78, 81, 118, 70, 70, 102, 95, 119, 104, 156, 
    97, 98, 137, 162, 180, 143, 118, 104, 90, 124, 132, 132, 
    124, 101, 81, 95, 100, 109, 112, 104, 115, 73, 78, 98, 91, 
    92, 79, 87, 81, 58, 83, 83, 78, 86, 77, 57, 62, 60, 74, 85, 
    74, 81, 35, 58, 74, 57, 77, 60, 61, 39, 48, 75, 78, 63, 45, 
    60, 47, 49, 73, 77, 75, 58, 84, 40, 63, 69, 63, 65, 76, 84, 
    0, 0, 1, 0, 1, 0, 2, 1, 5, 2, 6, 4, 5, 3, 2, 7, 4, 3, 11, 
    1, 7, 7, 10, 3, 9, 11, 8, 9, 11, 7, 9, 5, 7, 3, 6, 12, 6, 
    10, 8, 8, 5, 7, 7, 7, 7, 6, 8, 3, 4, 10, 6, 2, 4, 7, 5, 0, 
    5, 5, 3, 2, 3, 3, 0, 5, 4, 3, 1, 3, 3, 1, 3, 3, 1, 4, 2, 
    1, 5, 4, 2, 3, 1, 2, 4, 3, 2, 1, 3, 1, 4, 6, 4, 1, 1, 6, 
    3, 2, Inf, Inf, Inf, Inf, Inf, Inf, Inf, Inf, Inf, Inf, Inf, 
    Inf, Inf, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 3, 1, 2, 0, 6, 
    2, 3, 2, 5, 2, 4, 4, 6, 11, 8, 4, 11, 6, 16, 16, 16, 6, 6, 
    14, 14, 13, 11, 14, 17, 20, 19, 15, 37, 22, 23, 27, 29, 23, 
    29, 27, 37, 28, 36, 28, 31, 26, 29, 34, 32, 25, 26, 32, 23, 
    27, 31, 25, 30, 15, 29, 27, 27, 41, 23, 24, 21, 30, 27, 23, 
    35, 29, 13, 11, 28, 21, 25, 22, 26, 16, 15, 30, 23, 34, 20, 
    32, 14, 14, 22, 24, 22, 30, 20, 31, 15, 27, 40, 31, 47, 32
    ), .Dim = c(109L, 3L)), N = c(2374550, 287795, 1724529), 
    survey_counts = structure(c(18, 707, 16, 679), .Dim = c(2L, 
    2L)), survey_t = structure(c(40, 47, 65, 68), .Dim = c(2L, 
    2L)), survey_regions = c(1, 1), n_surveys = 2L, p_lower_inf = 1, 
    eta = 0.196078431372549, gammaD = 0.2, gamma_pos = 0.2, t_today = 76)

init<-function(nRegions){
  
  function(chain_id){ # guesses for the optimisation
    p_lower_inf=1
    #transformed 
    u_p <- runif(nRegions, 0.18, 0.22) #
    u_e <- runif(nRegions,-0.28,-0.22 )    #
    u_t <- runif(nRegions, 0.92, 0.98)   #
    #if(p_lower_inf >= 0.5){ u_t <- runif(nRegions, 0, 2) }
    #if(p_lower_inf >= 0.8){ u_t <- runif(nRegions, 0, 1) }
    u_pb <- rbeta(nRegions, 4, 190)# prob reported
    sigma <- log(1 + runif(1, 0, 1)) # sd around the mean incidence; sigma is a scalar in the Stan model
    
    theta_untransformed<-lapply(1:nRegions,function(x){
      c(qlogis(u_p[x]), u_e[x], log(u_t[x]), qlogis(u_pb[x]))  # qlogis() is base R's logit
    })
    return(list(theta_untransformed=theta_untransformed,
                gamma=theta_untransformed[[1]],
                sigma=sigma,
                tau=(abs(theta_untransformed[[1]])/4)))
  }}
library(cmdstanr)
seir_stan <- cmdstan_model(modfile)
seir_fit <- seir_stan$sample(data = dat, iter_warmup = iter_warmup,
                             iter_sampling = iter_sampling,
                             chains = chains,
                             seed = seed,
                             refresh = NULL,
                             thin = 1,
                             max_treedepth = NULL,
                             init = init(dat$nRegions))
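To sanity-check what these initial values imply, here is a Python transcription of the SEIR right-hand side (a sketch, not part of the original code; b_t and seir_rhs are hypothetical names), evaluated at the centre of the init ranges: p = 0.2, epsilon = -0.25, theta_local = 0.95 and p_symp = 4/194 (the mean of Beta(4, 190)), with the fixed rates and t_today = 76 from dat and region 1's population. It shows two things: the six derivatives sum to zero, so the total population N is conserved and any drift in it comes from solver error; and with a negative epsilon the transmission rate b_t falls from about theta_local well before t_today to about p * theta_local well after it.

```python
import math

def b_t(t, p, epsilon, theta_local, t_today):
    """Time-varying transmission rate, as in the Stan SEIR function."""
    return ((1 - p) / (1 + math.exp(-epsilon * (t - t_today))) + p) * theta_local

def seir_rhs(t, y, theta, x_r, t_today, N):
    """Python transcription of SEIR(); y = [S, E, I_symp, I_asymp, R1, R2]."""
    S, E, I_symp, I_asymp, R1, R2 = y
    p, epsilon, theta_local, p_symp = theta
    p_lower_inf, eta, gammaD, gamma_pos = x_r
    b = b_t(t, p, epsilon, theta_local, t_today)
    infection = b * S * I_symp / N + p_lower_inf * b * S * I_asymp / N
    return [
        -infection,
        infection - eta * E,
        p_symp * eta * E - gammaD * I_symp,
        (1 - p_symp) * eta * E - gammaD * I_asymp,
        gammaD * (I_symp + I_asymp) - gamma_pos * R1,
        gamma_pos * R1,
    ]

theta = [0.2, -0.25, 0.95, 4 / 194]          # centre of the init ranges (assumed)
x_r = [1.0, 0.196078431372549, 0.2, 0.2]     # p_lower_inf, eta, gammaD, gamma_pos from dat
N, t_today, i0 = 2374550.0, 76, 1.0
# same initial state as the Stan code: N - i0 * (1 + (1 - p_symp) / p_symp) = N - i0 / p_symp
y0 = [N - i0 / theta[3], 0.0, i0, i0 * (1 - theta[3]) / theta[3], 0.0, 0.0]

dydt = seir_rhs(48.0, y0, theta, x_r, t_today, N)
print(sum(dydt))                           # ~0 up to rounding: N is conserved
print(b_t(0.0, 0.2, -0.25, 0.95, t_today))     # close to theta_local = 0.95
print(b_t(150.0, 0.2, -0.25, 0.95, t_today))   # close to p * theta_local = 0.19
```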

I honestly don’t see anything immediately problematic, but I also don’t think I understand the model well.

Just to make sure I understand you correctly: you are able to fit each of the 5 curves individually, but not together?

My guess would be that the way you transform the theta_untransformed to theta and the priors on the relevant parameters mean that the sampler can - at least initially - explore weird parts of the posterior (inits can mitigate this, but not completely).
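A rough illustration of how wide the implied priors are: the sketch below pushes the central 95% range of a normal(0, 5) group mean gamma[k] through the transform used for each parameter in the transformed parameters block. It ignores the extra spread the hierarchy adds through tau, so the real prior ranges are wider still. For theta_local the range runs from roughly 5.5e-5 to 1.8e4, more than eight orders of magnitude; rates at the top end produce extremely fast dynamics, which is where an explicit solver runs out of steps.

```python
import math

z = 1.959963984540054      # 97.5% quantile of the standard normal
sd = 5.0                   # scale of the normal(0, 5) prior on the group means gamma

def inv_logit(x):
    return 1.0 / (1.0 + math.exp(-x))

# Central 95% range of gamma[k] mapped through each parameter's transform
ranges = {
    "p (inv_logit)":      (inv_logit(-z * sd), inv_logit(z * sd)),
    "epsilon (identity)": (-z * sd, z * sd),
    "theta_local (exp)":  (math.exp(-z * sd), math.exp(z * sd)),
    "p_symp (inv_logit)": (inv_logit(-z * sd), inv_logit(z * sd)),
}
for name, (lo, hi) in ranges.items():
    print(f"{name:20s} {lo:.3g} to {hi:.3g}")
```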

An obvious thing would be to switch to a stiff solver (integrate_ode_bdf). Also consider using the new ODE interface (https://mc-stan.org/docs/2_25/functions-reference/functions-ode-solver.html); I think there are some changes to how tolerances are handled, and it could conceivably help you.
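Why a stiff solver can help, in miniature: for a fast-decaying component such as dy/dt = -k*y with large k, an explicit method is only stable when its step is below 2/k, whereas an implicit method (the family integrate_ode_bdf belongs to) stays stable at any step size. The toy comparison below uses explicit and implicit Euler as stand-ins for rk45 and BDF; it illustrates the principle only, not Stan's actual solvers, and the function names are hypothetical.

```python
def explicit_euler(k, y0, h, n_steps):
    """Explicit Euler for dy/dt = -k*y: y_{n+1} = (1 - k h) y_n."""
    y = y0
    for _ in range(n_steps):
        y = y + h * (-k * y)
    return y

def implicit_euler(k, y0, h, n_steps):
    """Implicit Euler for dy/dt = -k*y: solves y_{n+1} = y_n - k h y_{n+1}."""
    y = y0
    for _ in range(n_steps):
        y = y / (1 + k * h)
    return y

k, h, n = 50.0, 0.1, 10           # k*h = 5, far above the explicit stability limit of 2
print(explicit_euler(k, 1.0, h, 1))   # one step already gives a negative "compartment"
print(explicit_euler(k, 1.0, h, n))   # and the error grows by a factor of 4 per step
print(implicit_euler(k, 1.0, h, n))   # decays smoothly towards 0, like the true solution
```

An adaptive explicit solver avoids the blow-up by shrinking its step below 2/k, which is exactly what burns through max_num_steps when some rates in the model become large.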

I would also definitely look at some pairs plots for some subsets of the parameters (one theta across regions / all thetas within region / Omega and tau…)

There are also a bunch of other posts on SEIR models (https://discourse.mc-stan.org/search?q=seir) so if you haven’t walked through them already, it might help.

Best of luck with your model!

What @martinmodrak said: try the new interface.

Why do you have Inf in your ts data?