# Code review of slow hierarchical (ODE) model

Hi all,

I’m having trouble getting a reasonable speed out of an ODE model I’m working on, and it throws a lot of divergent-transition warnings.
I have a version that fits a single ODE curve and runs reasonably fast (~10 min for 500 samples). I’ve since tried to extend it to fit multiple ODE curves (usually 4-5), assuming that the parameters of the curves are related to each other hierarchically through a multivariate normal distribution. However, this version takes several hours to run with a similar number of samples and often struggles to converge at all.
I’m hoping it’s just that I’ve done something stupid with the parametrization of the model. If anyone a bit more experienced could have a look over the Stan code below, I’d be very grateful.
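In the Stan code below, each region's `theta_untransformed[r]` is drawn from `multi_normal_cholesky(gamma, diag_pre_multiply(tau, Omega))`: the group mean `gamma` plus a correlated deviation scaled by `tau`. The following is a minimal Python sketch of that generative step, using the standard library only. The correlation factor `L_Omega` and the numbers for `gamma` and `tau` are illustrative assumptions, roughly on the scale of the posted init function, and are not estimates from the data.

```python
import math
import random


def diag_pre_multiply(tau, L):
    """Scale row i of L by tau[i], like Stan's diag_pre_multiply(tau, L)."""
    return [[tau[i] * x for x in row] for i, row in enumerate(L)]


def mat_vec(A, z):
    """Plain matrix-vector product."""
    return [sum(a * b for a, b in zip(row, z)) for row in A]


def implied_cov(sigma_pars):
    """Covariance of gamma + sigma_pars @ z for standard-normal z: sigma_pars @ sigma_pars^T."""
    return [[sum(a * b for a, b in zip(ri, rj)) for rj in sigma_pars] for ri in sigma_pars]


def draw_regions(gamma, tau, L_Omega, n_regions, rng):
    """Draw theta_untransformed[r] ~ multi_normal_cholesky(gamma, diag_pre_multiply(tau, L_Omega))."""
    sigma_pars = diag_pre_multiply(tau, L_Omega)
    draws = []
    for _ in range(n_regions):
        z = [rng.gauss(0.0, 1.0) for _ in gamma]  # independent N(0, 1) draws
        draws.append([g + d for g, d in zip(gamma, mat_vec(sigma_pars, z))])
    return draws


# Illustrative Cholesky factor of a correlation matrix: correlation 0.5 between
# the first two components; every row has unit length.
L_Omega = [
    [1.0, 0.0, 0.0, 0.0],
    [0.5, math.sqrt(0.75), 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
]
gamma = [-1.4, -0.25, -0.05, -3.9]  # roughly logit(0.2), -0.25, log(0.95), logit(0.02)
tau = [0.35, 0.06, 0.01, 0.97]

regions = draw_regions(gamma, tau, L_Omega, n_regions=3, rng=random.Random(1))
```

Because every row of a correlation Cholesky factor has unit length, the implied marginal standard deviations are exactly `tau`. The sampler has to learn the region-level deviations jointly with `gamma`, `tau` and `Omega`.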

``````
//Taken from https://jrmihalj.github.io/estimating-transmission-by-fitting-mechanistic-models-in-Stan/

functions {

  // An ODE function for generating an epidemic curve given four parameters
  real[] SEIR(
      real t,            // time
      real[] y,          // state
      real[] theta,      // parameters: p, epsilon, theta_local, p_symp
      data real[] x_r,   // data (real): p_lower_inf, eta, gammaD, gamma_pos
      data int[] x_i) {  // data (integer): t_today, N

    real dydt[6];

    // state
    real S;
    real E;
    real I_symp;
    real I_asymp;
    real R1;
    real R2;

    // parameters
    real epsilon;
    real p;
    real theta_local;

    // fixed parameters for the model
    real p_lower_inf;  // lower infection for asymptomatic ?
    real eta;          // 1 / latency phase
    real p_symp;       // probability of being symptomatic
    real gammaD;       // 1 / length of infectiousness
    real gamma_pos;    // 1 / length of being positive

    real N;
    real t_today;

    real b_t;

    S = y[1];
    E = y[2];
    I_symp = y[3];
    I_asymp = y[4];
    R1 = y[5];
    R2 = y[6];

    p_lower_inf = x_r[1];
    eta = x_r[2];
    gammaD = x_r[3];
    gamma_pos = x_r[4];

    t_today = x_i[1];
    N = x_i[2];

    p = theta[1];
    epsilon = theta[2];
    theta_local = theta[3];
    p_symp = theta[4];

    b_t = ((1 - p) / (1 + exp(-epsilon * (t - t_today))) + p) * theta_local;

    dydt[1] = -b_t * S * I_symp / N - p_lower_inf * b_t * S * I_asymp / N;
    dydt[2] =  b_t * S * I_symp / N + p_lower_inf * b_t * S * I_asymp / N - eta * E;
    dydt[3] =  p_symp * eta * E       - gammaD * I_symp;
    dydt[4] =  (1 - p_symp) * eta * E - gammaD * I_asymp;
    dydt[5] =  gammaD * (I_symp + I_asymp) - gamma_pos * R1;
    dydt[6] =  gamma_pos * R1;

    return dydt;
  }

}

data {
  int<lower=1> nRegions;
  int<lower=1> maxObs;

  int<lower=1> nObs[nRegions];  // number of observations per region
  real i0[nRegions];            // starting (observed) incidence

  real t0[nRegions];                   // starting time
  matrix[maxObs, nRegions] ts;         // time points for observations
  matrix[maxObs, nRegions] incidence;  // observed incidence values over time

  // fixed parameters for the model
  real p_lower_inf;  // lower infection for asymptomatic ?
  real eta;          // 1 / latency phase
  real gammaD;       // 1 / length of infectiousness
  real gamma_pos;    // 1 / length of being positive

  int t_today;       // time point of lockdown
  int N[nRegions];   // population size

  // data for surveys
  int n_surveys;
  int survey_counts[2, n_surveys];  // [1, s]: positives, [2, s]: number tested in survey s
  int survey_t[2, n_surveys];       // [1, s], [2, s]: first and last index into Pos_hat for survey s
  int survey_regions[n_surveys];
}
transformed data {
  vector[4] theta_fixed[nRegions];  // p_lower_inf, eta, gammaD, gamma_pos (passed as x_r)
  int theta_int[nRegions, 2];       // t_today, N (passed as x_i)

  for (r in 1:nRegions) {
    theta_fixed[r][1] = p_lower_inf;
    theta_fixed[r][2] = eta;
    theta_fixed[r][3] = gammaD;
    theta_fixed[r][4] = gamma_pos;

    theta_int[r, 1] = t_today;
    theta_int[r, 2] = N[r];
  }
}
parameters {
  // following https://mc-stan.org/docs/2_19/stan-users-guide/multivariate-hierarchical-priors-section.html

  cholesky_factor_corr[4] Omega;
  vector<lower=0>[4] tau;

  vector[4] gamma;  // group coeffs

  vector[4] theta_untransformed[nRegions];

  real<lower=0> sigma;  // not used anywhere in the model block
}

transformed parameters {

  vector[4] theta[nRegions];
  vector[6] y0[nRegions];             // starting state of the SEIR
  matrix[maxObs, 6] y_hat[nRegions];  // S, E, Is, Ia, R1, R2 values over time
  vector[maxObs] inc_hat[nRegions];

  real Pos_hat_surveys[n_surveys];

  vector[maxObs] Pos_hat[nRegions];  // estimated number that would test positive in a survey
  vector[maxObs] sq_err[nRegions];

  matrix[4, 4] sigma_pars;

  for (r in 1:nRegions) {
    theta[r][1] = inv_logit(theta_untransformed[r][1]);  // p
    theta[r][2] = theta_untransformed[r][2];             // epsilon
    theta[r][3] = exp(theta_untransformed[r][3]);        // theta_local
    theta[r][4] = inv_logit(theta_untransformed[r][4]);  // p_symp

    y0[r][1] = N[r] - i0[r] * (1 + (1 - theta[r][4]) / theta[r][4]);
    y0[r][2] = 0;
    y0[r][3] = i0[r];
    y0[r][4] = i0[r] * (1 - theta[r][4]) / theta[r][4];
    y0[r][5] = 0;
    y0[r][6] = 0;

    y_hat[r][1:nObs[r], 1:6] = to_matrix(integrate_ode_rk45(
        SEIR, to_array_1d(y0[r]), t0[r], to_array_1d(ts[1:nObs[r], r]),
        to_array_1d(theta[r]),
        to_array_1d(theta_fixed[r]),   // x_r: all four fixed rates
        to_array_1d(theta_int[r])));   // x_i: t_today, N

    Pos_hat[r][1:nObs[r]] = to_vector(y_hat[r][1:nObs[r], 3])
                            + to_vector(y_hat[r][1:nObs[r], 4])
                            + to_vector(y_hat[r][1:nObs[r], 5]);

    inc_hat[r][1:nObs[r]] = (eta * theta[r][4]) * to_vector(y_hat[r][1:nObs[r], 2]);

    for (t in 1:nObs[r]) {
      sq_err[r][t] = (incidence[t, r] - inc_hat[r][t])^2;
    }
  }

  sigma_pars = diag_pre_multiply(tau, Omega);

  for (s in 1:n_surveys) {
    // the window is looked up by survey s (survey_t is [2, n_surveys])
    Pos_hat_surveys[s] = mean(Pos_hat[survey_regions[s]][survey_t[1, s]:survey_t[2, s]]);
  }
}
model {

  // using Cholesky decomposition per
  // https://discourse.mc-stan.org/t/trouble-with-prior-selection-for-multivariate-normal-inverse-wishart-analysis/6088/13

  tau ~ cauchy(0, 2.5);
  Omega ~ lkj_corr_cholesky(2);

  to_vector(gamma) ~ normal(0, 5);
  print(tau);  // prints at every log-density evaluation
  for (r in 1:nRegions) {
    theta_untransformed[r] ~ multi_normal_cholesky(gamma, sigma_pars);
  }

  for (r in 1:nRegions) {
    target += -nObs[r] * log(sum(to_vector(sq_err[r][1:nObs[r]]))) / 2;
  }

  for (s in 1:n_surveys) {
    survey_counts[1, s] ~ binomial(survey_counts[2, s], Pos_hat_surveys[s] / N[survey_regions[s]]);
  }
}
generated quantities {
  vector<lower=0, upper=1>[nRegions] p;       // for infectivity function
  vector[nRegions] epsilon;                   // for infectivity function
  vector[nRegions] theta_local;               // for infectivity function (exp of theta_untransformed[r][3], so positive)
  vector<lower=0, upper=1>[nRegions] p_symp;  // proportion of symptomatic cases

  for (r in 1:nRegions) {
    p[r] = theta[r][1];
    epsilon[r] = theta[r][2];
    theta_local[r] = theta[r][3];
    p_symp[r] = theta[r][4];
  }
}

``````
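The model contains two small pieces of algebra: the time-varying transmission rate `b_t` inside `SEIR` and the initial state `y0` in `transformed parameters`. To sanity-check them, here is a direct Python transcription of both formulas from the code above. The parameter values in the example are illustrative only, in the range of the posted init function.

```python
import math


def transmission_rate(t, p, epsilon, theta_local, t_today):
    """b_t from SEIR(): a logistic switch between theta_local and p * theta_local around t_today."""
    return ((1 - p) / (1 + math.exp(-epsilon * (t - t_today))) + p) * theta_local


def initial_state(N, i0, p_symp):
    """y0 from transformed parameters: i0 symptomatic cases plus the implied asymptomatic ones."""
    i_asymp = i0 * (1 - p_symp) / p_symp
    S = N - i0 * (1 + (1 - p_symp) / p_symp)  # algebraically N - i0 / p_symp
    return [S, 0.0, i0, i_asymp, 0.0, 0.0]    # S, E, I_symp, I_asymp, R1, R2


# Illustrative values in the range of the posted init function
p, epsilon, theta_local, t_today = 0.2, -0.25, 0.95, 76
for t in (48, 76, 157):
    print(t, transmission_rate(t, p, epsilon, theta_local, t_today))
print(initial_state(287795, 1, 0.02))
```

With `epsilon < 0`, as in the inits, `b_t` falls from `theta_local` well before `t_today` to `p * theta_local` well after it, and equals `(1 + p) / 2 * theta_local` exactly at `t_today`. The initial state always sums to `N`.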

It seems I can no longer edit my post above.

In case it helps anyone, I’ve included the data I’m using below, together with an initial-value function. To the best of my knowledge these initial values are close to the “true” values, but Stan still has huge problems sampling: it can take up to half an hour just to get 20 samples, probably because most proposals are rejected. I’ve tried setting the ODE tolerance to 1e-3, which helps a bit, but it’s still very, very slow.

Typical error messages look like the following. The binomial errors are probably caused by the ODE solution over- or undershooting, so that some compartments become negative or larger than the total population; they disappear after a while.

``````
Chain 4 Rejecting initial value:
Chain 4   Gradient evaluated at the initial value is not finite.
Chain 4   Stan can't start sampling from this initial value.
Chain 1 Exception: integrate_ode_rk45:  Failed to integrate to next output time (53) in less than max_num_steps steps (in 'C:/Users/Owner/AppData/Local/Temp/RtmpmEqb6q/model-8b6869886312.stan', line 77, column 4 to column 16)
Chain 1 Exception: integrate_ode_rk45:  Failed to integrate to next output time (53) in less than max_num_steps steps (in 'C:/Users/Owner/AppData/Local/Temp/RtmpmEqb6q/model-8b6869886312.stan', line 77, column 4 to column 16)
Warning: Chain 1 finished unexpectedly!

Chain 4 Exception: integrate_ode_rk45:  Failed to integrate to next output time (56) in less than max_num_steps steps (in 'C:/Users/Owner/AppData/Local/Temp/RtmpmEqb6q/model-8b6869886312.stan', line 77, column 4 to column 16)
Chain 4 Exception: integrate_ode_rk45:  Failed to integrate to next output time (56) in less than max_num_steps steps (in 'C:/Users/Owner/AppData/Local/Temp/RtmpmEqb6q/model-8b6869886312.stan', line 77, column 4 to column 16)
Chain 4 Exception: integrate_ode_rk45:  Failed to integrate to next output time (56) in less than max_num_steps steps (in 'C:/Users/Owner/AppData/Local/Temp/RtmpmEqb6q/model-8b6869886312.stan', line 77, column 4 to column 16)
Chain 4 Exception: integrate_ode_rk45:  Failed to integrate to next output time (56) in less than max_num_steps steps (in 'C:/Users/Owner/AppData/Local/Temp/RtmpmEqb6q/model-8b6869886312.stan', line 77, column 4 to column 16)
Warning: Chain 4 finished unexpectedly!

Chain 2 Informational Message: The current Metropolis proposal is about to be rejected because of the following issue:
Chain 2 Exception: integrate_ode_rk45:  Failed to integrate to next output time (67) in less than max_num_steps steps (in 'C:/Users/Owner/AppData/Local/Temp/RtmpmEqb6q/model-8b6869886312.stan', line 77, column 4 to column 16)
Chain 2 If this warning occurs sporadically, such as for highly constrained variable types like covariance matrices, then the sampler is fine,
Chain 2 but if this warning occurs often then your model may be either severely ill-conditioned or misspecified.
Chain 2
Chain 3 Exception: integrate_ode_rk45:  Failed to integrate to next output time (106) in less than max_num_steps steps (in 'C:/Users/Owner/AppData/Local/Temp/RtmpmEqb6q/model-8b6869886312.stan', line 77, column 4 to column 16)
Chain 3 Exception: integrate_ode_rk45:  Failed to integrate to next output time (106) in less than max_num_steps steps (in 'C:/Users/Owner/AppData/Local/Temp/RtmpmEqb6q/model-8b6869886312.stan', line 77, column 4 to column 16)
Warning: Chain 3 finished unexpectedly!

Chain 2 Informational Message: The current Metropolis proposal is about to be rejected because of the following issue:
Chain 2 Exception: binomial_lpmf: Probability parameter is -1.28842e-10, but must be in the interval [0, 1] (in 'C:/Users/Owner/AppData/Local/Temp/RtmpmEqb6q/model-8b6869886312.stan', line 212, column 2 to column 92)
Chain 2 If this warning occurs sporadically, such as for highly constrained variable types like covariance matrices, then the sampler is fine,
Chain 2 but if this warning occurs often then your model may be either severely ill-conditioned or misspecified.
Chain 2
Chain 2 Informational Message: The current Metropolis proposal is about to be rejected because of the following issue:
Chain 2 Exception: integrate_ode_rk45:  Failed to integrate to next output time (55) in less than max_num_steps steps (in 'C:/Users/Owner/AppData/Local/Temp/RtmpmEqb6q/model-8b6869886312.stan', line 77, column 4 to column 16)
Chain 2 If this warning occurs sporadically, such as for highly constrained variable types like covariance matrices, then the sampler is fine,
Chain 2 but if this warning occurs often then your model may be either severely ill-conditioned or misspecified.
``````
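The six derivatives in `SEIR` sum to zero. The exact solution therefore always satisfies S + E + I_symp + I_asymp + R1 + R2 = N and stays non-negative. A negative compartment, or a `Pos_hat_surveys[s] / N` outside [0, 1], must come from the numerical solution rather than from the model itself.

Below is a Python transcription of the same right-hand side together with a fixed-step RK4 integrator. Both are assumptions for illustration:
- The integrator is a swapped-in technique, not Stan's adaptive `rk45`.
- The inputs are region 2 of the posted data, with parameter values in the range of the init function.

```python
import math


def seir_rhs(t, y, theta, x_r, x_i):
    """Python transcription of SEIR(): theta = (p, epsilon, theta_local, p_symp),
    x_r = (p_lower_inf, eta, gammaD, gamma_pos), x_i = (t_today, N)."""
    S, E, I_symp, I_asymp, R1, _R2 = y
    p, epsilon, theta_local, p_symp = theta
    p_lower_inf, eta, gammaD, gamma_pos = x_r
    t_today, N = x_i
    b_t = ((1 - p) / (1 + math.exp(-epsilon * (t - t_today))) + p) * theta_local
    new_inf = b_t * S * I_symp / N + p_lower_inf * b_t * S * I_asymp / N
    return [
        -new_inf,
        new_inf - eta * E,
        p_symp * eta * E - gammaD * I_symp,
        (1 - p_symp) * eta * E - gammaD * I_asymp,
        gammaD * (I_symp + I_asymp) - gamma_pos * R1,
        gamma_pos * R1,
    ]


def rk4(f, y, t0, t1, h):
    """Fixed-step classical Runge-Kutta; like any Runge-Kutta method it preserves linear invariants."""
    t = t0
    for _ in range(int(round((t1 - t0) / h))):
        k1 = f(t, y)
        k2 = f(t + h / 2, [a + h / 2 * k for a, k in zip(y, k1)])
        k3 = f(t + h / 2, [a + h / 2 * k for a, k in zip(y, k2)])
        k4 = f(t + h, [a + h * k for a, k in zip(y, k3)])
        y = [a + h / 6 * (b + 2 * c + 2 * d + e) for a, b, c, d, e in zip(y, k1, k2, k3, k4)]
        t += h
    return y


# Region 2 of the posted data; parameter values in the range of the posted inits
N = 287795
theta = (0.2, -0.25, 0.95, 0.02)
x_r = (1.0, 0.196078431372549, 0.2, 0.2)
x_i = (76, N)
y0 = [N - 1 / 0.02, 0.0, 1.0, 1 * 0.98 / 0.02, 0.0, 0.0]


def f(t, y):
    return seir_rhs(t, y, theta, x_r, x_i)


y_end = rk4(f, y0, 61.0, 157.0, 0.1)
print(sum(y_end) / N)  # the compartments still add up to N, up to rounding
```

An adaptive solver keeps the total close to `N` as well. However, once the local error gets large (for example at extreme parameter values), individual compartments can still dip slightly below zero, which is consistent with the tiny negative probability in the log above.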

R code for reading in the data and running the model from the first post (its file path is assumed to be stored in `modfile`).

``````
dat<-list(nRegions = 3, nObs = c(109L, 96L, 109L), maxObs = 109L,
i0 = c(1L, 1L, 1L), t0 = c(48, 61, 48), ts = structure(c(49,
50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94,
95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107,
108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131,
132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73,
74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88,
89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102,
103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114,
115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138,
139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150,
151, 152, 153, 154, 155, 156, 157, Inf, Inf, Inf, Inf, Inf,
Inf, Inf, Inf, Inf, Inf, Inf, Inf, Inf, 49, 50, 51, 52, 53,
54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68,
69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98,
99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110,
111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122,
123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134,
135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146,
147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157), .Dim = c(109L,
3L)), incidence = structure(c(0, 0, 0, 1, 1, 0, 1, 0, 1,
2, 6, 28, 5, 3, 5, 6, 12, 10, 29, 26, 29, 31, 44, 46, 45,
32, 55, 60, 58, 78, 81, 118, 70, 70, 102, 95, 119, 104, 156,
97, 98, 137, 162, 180, 143, 118, 104, 90, 124, 132, 132,
124, 101, 81, 95, 100, 109, 112, 104, 115, 73, 78, 98, 91,
92, 79, 87, 81, 58, 83, 83, 78, 86, 77, 57, 62, 60, 74, 85,
74, 81, 35, 58, 74, 57, 77, 60, 61, 39, 48, 75, 78, 63, 45,
60, 47, 49, 73, 77, 75, 58, 84, 40, 63, 69, 63, 65, 76, 84,
0, 0, 1, 0, 1, 0, 2, 1, 5, 2, 6, 4, 5, 3, 2, 7, 4, 3, 11,
1, 7, 7, 10, 3, 9, 11, 8, 9, 11, 7, 9, 5, 7, 3, 6, 12, 6,
10, 8, 8, 5, 7, 7, 7, 7, 6, 8, 3, 4, 10, 6, 2, 4, 7, 5, 0,
5, 5, 3, 2, 3, 3, 0, 5, 4, 3, 1, 3, 3, 1, 3, 3, 1, 4, 2,
1, 5, 4, 2, 3, 1, 2, 4, 3, 2, 1, 3, 1, 4, 6, 4, 1, 1, 6,
3, 2, Inf, Inf, Inf, Inf, Inf, Inf, Inf, Inf, Inf, Inf, Inf,
Inf, Inf, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 3, 1, 2, 0, 6,
2, 3, 2, 5, 2, 4, 4, 6, 11, 8, 4, 11, 6, 16, 16, 16, 6, 6,
14, 14, 13, 11, 14, 17, 20, 19, 15, 37, 22, 23, 27, 29, 23,
29, 27, 37, 28, 36, 28, 31, 26, 29, 34, 32, 25, 26, 32, 23,
27, 31, 25, 30, 15, 29, 27, 27, 41, 23, 24, 21, 30, 27, 23,
35, 29, 13, 11, 28, 21, 25, 22, 26, 16, 15, 30, 23, 34, 20,
32, 14, 14, 22, 24, 22, 30, 20, 31, 15, 27, 40, 31, 47, 32
), .Dim = c(109L, 3L)), N = c(2374550, 287795, 1724529),
survey_counts = structure(c(18, 707, 16, 679), .Dim = c(2L,
2L)), survey_t = structure(c(40, 47, 65, 68), .Dim = c(2L,
2L)), survey_regions = c(1, 1), n_surveys = 2L, p_lower_inf = 1,
eta = 0.196078431372549, gammaD = 0.2, gamma_pos = 0.2, t_today = 76)

init<-function(nRegions){

  function(chain_id){ # guesses for the optimisation
    p_lower_inf=1
    #transformed
    u_p <- runif(nRegions, 0.18, 0.22) #
    u_e <- runif(nRegions,-0.28,-0.22 )    #
    u_t <- runif(nRegions, 0.92, 0.98)   #
    #if(p_lower_inf >= 0.5){ u_t <- runif(nRegions, 0, 2) }
    #if(p_lower_inf >= 0.8){ u_t <- runif(nRegions, 0, 1) }
    u_pb <- rbeta(nRegions, 4, 190)# prob reported
    sigma <- log(1+runif(1,0,1)) # Sd around the mean incidence (a scalar, like sigma in the Stan model)

    theta_untransformed<-lapply(1:nRegions,function(x){
      # qlogis() is base R's logit
      c(qlogis(u_p[x]),u_e[x],log(u_t[x]),qlogis(u_pb[x]))
    })
    return(list(theta_untransformed=theta_untransformed,
                gamma=theta_untransformed[[1]],
                sigma=sigma,
                tau=(abs(theta_untransformed[[1]])/4)))
  }}

library(cmdstanr)
# Sampler settings: placeholder values, not given in the original post
iter_warmup <- 500
iter_sampling <- 500
chains <- 4
seed <- 123

seir_stan<-cmdstan_model(modfile)
seir_fit<-seir_stan$sample(data=dat,iter_warmup=iter_warmup,
                           iter_sampling=iter_sampling,
                           chains=chains,
                           seed=seed,
                           refresh = NULL,
                           thin=1,
                           max_treedepth = NULL,
                           init=init(dat$nRegions))
``````
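One thing worth double-checking in the data above is the survey windows. `survey_t` is declared as `int survey_t[2,n_surveys]`, so the averaging window for survey `s` is `survey_t[1,s]:survey_t[2,s]`. Indexing its second dimension with `survey_regions[s]` instead would, with `survey_regions = c(1, 1)`, use the first window for both surveys. A quick Python mirror of both lookups, using the posted `survey_t` (the helper names are hypothetical):

```python
# survey_t from the posted data: structure(c(40, 47, 65, 68), .Dim = c(2L, 2L)).
# R fills matrices column by column, so column s holds the window of survey s.
survey_t = [
    [40, 65],  # survey_t[1, s]: first index of survey s
    [47, 68],  # survey_t[2, s]: last index of survey s
]
survey_regions = [1, 1]
n_surveys = 2


def windows_by_survey(survey_t, n_surveys):
    """survey_t[1, s]:survey_t[2, s] for each survey s (1-based in Stan, 0-based here)."""
    return [(survey_t[0][s], survey_t[1][s]) for s in range(n_surveys)]


def windows_by_region(survey_t, survey_regions):
    """survey_t[1, survey_regions[s]]:survey_t[2, survey_regions[s]], i.e. indexed by region."""
    return [(survey_t[0][r - 1], survey_t[1][r - 1]) for r in survey_regions]


def window_mean(values, first, last):
    """Mean of values[first:last] with Stan's 1-based, inclusive slice semantics."""
    return sum(values[first - 1:last]) / (last - first + 1)


print(windows_by_survey(survey_t, n_surveys))       # [(40, 47), (65, 68)]
print(windows_by_region(survey_t, survey_regions))  # [(40, 47), (40, 47)]
```

Note also that the model uses these numbers as positions in `Pos_hat[r]`, where index 1 corresponds to `ts[1, r]`, not as calendar days.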

I honestly don’t see anything immediately problematic, but I also don’t think I understand the model well.

So just to be clear that I understand you well: you are able to fit each of the 5 curves individually but not together?

My guess would be that the way you transform the `theta_untransformed` to `theta` and the priors on the relevant parameters mean that the sampler can - at least initially - explore weird parts of the posterior (inits can mitigate this, but not completely).
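To put a rough number on how weird those regions can get:
- `gamma` has a `normal(0, 5)` prior, and `theta_local = exp(theta_untransformed[r][3])`.
- The region-level values spread further around `gamma` through `tau ~ cauchy(0, 2.5)`.
- Even ignoring `tau`, the central 95% prior interval for `exp(gamma[3])` runs from about 5.5e-5 to about 1.8e4 (per time unit of `ts`).
- An `epsilon` of similar magnitude turns the lockdown switch in `b_t` into a near-step.

A quick Python calculation, standard library only:

```python
import math
from statistics import NormalDist

# gamma ~ normal(0, 5) in the posted model; theta_local = exp(theta_untransformed[r][3])
prior_sd = 5.0
z = NormalDist().inv_cdf(0.975)  # about 1.96

theta_local_lo = math.exp(-z * prior_sd)  # lower end of the central 95% prior interval for exp(gamma[3])
theta_local_hi = math.exp(z * prior_sd)   # upper end


def switch_days(epsilon):
    """Time for the logistic factor in b_t to go from 10% to 90% of its change: 2 * ln(9) / |epsilon|."""
    return 2 * math.log(9) / abs(epsilon)


print(theta_local_lo, theta_local_hi)
print(switch_days(0.25), switch_days(z * prior_sd))
```

A very large `theta_local` or a near-instant switch is the kind of fast dynamics that can push `integrate_ode_rk45` past `max_num_steps`. Whether the sampler actually wanders there in this model is something the `pairs` plots should show.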

An obvious thing would be to switch to a stiff solver (`integrate_ode_bdf`). Also consider using the new ODE interface (https://mc-stan.org/docs/2_25/functions-reference/functions-ode-solver.html); I think there are some changes to how tolerances are handled, and it could possibly help you.
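Here is why a stiff solver can help, in miniature. Explicit methods such as the Runge-Kutta pair behind `integrate_ode_rk45` need steps on the order of the fastest time scale in the system to stay stable. A few very fast components (e.g. when the sampler proposes huge transmission rates) can therefore exhaust `max_num_steps`. Implicit methods like the BDF family used by `integrate_ode_bdf` remain stable with much larger steps.

The toy Python example below uses the linear test equation y' = -lambda * y and compares forward Euler with backward Euler, the first-order BDF method. It illustrates the general idea only and says nothing specific about whether this SEIR system is stiff.

```python
def explicit_euler(lam, y0, h, n_steps):
    """Forward Euler for y' = -lam * y: each step multiplies y by (1 - h * lam)."""
    y = y0
    for _ in range(n_steps):
        y = y + h * (-lam * y)
    return y


def implicit_euler(lam, y0, h, n_steps):
    """Backward Euler (first-order BDF) for y' = -lam * y: solves y_new = y + h * (-lam * y_new)."""
    y = y0
    for _ in range(n_steps):
        y = y / (1 + h * lam)
    return y


lam, h = 1000.0, 0.01  # a fast-decaying mode and a step 10x larger than 1 / lam
print(explicit_euler(lam, 1.0, h, 20))  # grows: |1 - h * lam| = 9 > 1
print(implicit_euler(lam, 1.0, h, 20))  # decays towards 0, as the exact solution does
```

An adaptive explicit solver avoids the blow-up by shrinking `h` below roughly 2 / lambda, which is exactly where the step budget goes.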

I would also definitely look at some `pairs` plots for some subsets of the parameters (one theta across regions / all thetas within region / Omega and tau…)

There are also a bunch of other posts on SEIR models (https://discourse.mc-stan.org/search?q=seir) so if you haven’t walked through them already, it might help.

Best of luck with your model!

What @martinmodrak said: try the new interface.

Why do you have `Inf` in your `ts` data?
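If the `Inf` entries are only padding up to `maxObs`, they should never reach the solver. The model passes `ts[1:nObs[r], r]`, which for region 2 (`nObs = 96`, `maxObs = 109`, `t0 = 61`) keeps just the 96 finite days. The padding interpretation is an assumption. Here is a quick Python mirror of that slice:

```python
import math

# Region 2 of the posted ts matrix: days 62..157 (nObs = 96), padded with Inf up to maxObs = 109
ts_region2 = [float(day) for day in range(62, 158)] + [math.inf] * 13
n_obs, t0 = 96, 61.0


def observed_times(ts_column, n_obs):
    """Mirror ts[1:nObs[r], r] in the model: keep only the first n_obs entries of a padded column."""
    return ts_column[:n_obs]


obs = observed_times(ts_region2, n_obs)
# The solver should only ever see finite output times, in increasing order, after t0
print(len(ts_region2), len(obs), obs[0], obs[-1])
```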