Hello,
I am working with a 27-dimensional ODE model whose right-hand side contains step functions of the parameters. The parameter vector has 13 dimensions, and the log-likelihood (or cost) function depends on the output of the ODE model.
I would like to perform Bayesian inference with NUTS in Stan, where the partial derivative of the log-likelihood with respect to each parameter is computed using adjoint sensitivity analysis.
Given the specifications above, is Stan able to compute the sensitivities and output them at each iteration?
If not, does Stan support user-defined custom gradients, so that I could implement the adjoint sensitivity method myself? I could not find clear documentation or examples of this.
My latest attempt at a Stan model file (using `ode_adjoint_tol_ctl`) is provided below.
Thanks in advance for any help.
functions {
  /**
   * Right-hand side dy/dt of the 27-state ODE system.
   *
   * Structure (as written in the code below):
   *  - Every derivative is multiplied by gate = int_step(t - theta[1]), so the
   *    state is frozen (dy/dt = 0) until t exceeds theta[1].
   *  - Nine coupled compartment pairs (1,2), (3,4), (5,6), (8,9), (11,12),
   *    (14,15), (17,18), (20,21), (24,25) exchange mass through the net flux
   *    a*y[k] - b*y[k+1], which is subtracted from y[k] and added to y[k+1].
   *  - a and b are piecewise constant in t. For non-negative theta[3] and
   *    theta[4], they are 0 before theta[1]+theta[2].
   *      - (a, b) = (theta[6]*theta[7],   theta[6]*(1-theta[7]))   after theta[1]+theta[2];
   *      - (a, b) = (theta[8]*theta[9],   theta[8]*(1-theta[9]))   after theta[1]+theta[2]+theta[3];
   *      - (a, b) = (theta[10]*theta[11], theta[10]*(1-theta[11])) after theta[1]+...+theta[4].
   *  - The terms feeding y[1]..y[4] use theta[5] times a weighted sum of states
   *    5, 6, 8, 9, 11, 12, 14, 15, 17, 18, 20, 21, divided by 19216182.
   *
   * NOTE(review): int_step is piecewise constant, so autodiff (and hence the
   * adjoint sensitivities and the NUTS gradient) sees zero derivative of these
   * switches with respect to theta[1]..theta[4]. The effect of moving a switch
   * time is therefore missing from the gradient of those parameters. The
   * discontinuities may also force the solver to take many small steps. If those
   * parameters must be inferred, consider integrating piecewise between switch
   * times, or a smooth surrogate such as inv_logit((t - tau) / eps).
   *
   * @param t     time
   * @param y     current state, 27 entries
   * @param theta parameter vector; only theta[1]..theta[11] are read here
   * @return dy/dt, 27 entries
   */
  vector rhs(real t, vector y, vector theta) {
    vector[27] dydt;

    // Normalising constant for the infection terms (equals the initial size of
    // state 1 used by this program; presumably the total population -- confirm).
    real pop_size = 19216182.0;

    // Global on/off switch and the three rate-switch indicators.
    int gate = int_step(t - theta[1]);
    int sw2 = int_step(t - theta[1] - theta[2]);
    int sw3 = int_step(t - theta[1] - theta[2] - theta[3]);
    int sw4 = int_step(t - theta[1] - theta[2] - theta[3] - theta[4]);

    // Piecewise-constant transfer rates. Each increment replaces the previous
    // level, so exactly one level is active once the first switch has fired.
    real a = theta[6] * theta[7] * sw2
             + (theta[8] * theta[9] - theta[6] * theta[7]) * sw3
             + (theta[10] * theta[11] - theta[8] * theta[9]) * sw4;
    real b = theta[6] * (1 - theta[7]) * sw2
             + (theta[8] * (1 - theta[9]) - theta[6] * (1 - theta[7])) * sw3
             + (theta[10] * (1 - theta[11]) - theta[8] * (1 - theta[9])) * sw4;

    // Weighted sum of states driving the theta[5] terms in y[1]..y[4].
    real pressure = y[20] + 1.1 * (y[5] + y[8] + y[11] + y[14]) + 0.9 * y[17]
                    + 0.1 * (y[21] + 1.1 * (y[6] + y[9] + y[12] + y[15]) + 0.9 * y[18]);

    // Net flux from the first to the second compartment of each coupled pair.
    real f_1_2 = a * y[1] - b * y[2];
    real f_3_4 = a * y[3] - b * y[4];
    real f_5_6 = a * y[5] - b * y[6];
    real f_8_9 = a * y[8] - b * y[9];
    real f_11_12 = a * y[11] - b * y[12];
    real f_14_15 = a * y[14] - b * y[15];
    real f_17_18 = a * y[17] - b * y[18];
    real f_20_21 = a * y[20] - b * y[21];
    real f_24_25 = a * y[24] - b * y[25];

    dydt[1] = (-theta[5] * y[1] * pressure / pop_size - f_1_2) * gate;
    dydt[2] = (-theta[5] * 0.1 * y[2] * pressure / pop_size + f_1_2) * gate;
    dydt[3] = (theta[5] * y[1] * pressure / pop_size - 0.94 * y[3] - f_3_4) * gate;
    dydt[4] = (theta[5] * 0.1 * y[2] * pressure / pop_size - 0.94 * y[4] + f_3_4) * gate;
    dydt[5] = (0.94 * (y[3] - y[5]) - 0.00380 * y[5] - f_5_6) * gate;
    dydt[6] = (0.94 * (y[4] - y[6]) - 0.00380 * y[6] + f_5_6) * gate;
    dydt[7] = (0.00380 * (y[5] + y[6]) - 0.94 * y[7]) * gate;
    dydt[8] = (0.94 * (y[5] - y[8]) - 0.00380 * y[8] - f_8_9) * gate;
    dydt[9] = (0.94 * (y[6] - y[9]) - 0.00380 * y[9] + f_8_9) * gate;
    dydt[10] = (0.00380 * (y[8] + y[9]) - 0.94 * y[10]) * gate;
    dydt[11] = (0.94 * (y[8] - y[11]) - 0.00380 * y[11] - f_11_12) * gate;
    dydt[12] = (0.94 * (y[9] - y[12]) - 0.00380 * y[12] + f_11_12) * gate;
    dydt[13] = (0.00380 * (y[11] + y[12]) - 0.94 * y[13]) * gate;
    dydt[14] = (0.94 * (y[11] - y[14]) - 0.00380 * y[14] - f_14_15) * gate;
    dydt[15] = (0.94 * (y[12] - y[15]) - 0.00380 * y[15] + f_14_15) * gate;
    dydt[16] = (0.00380 * (y[14] + y[15]) - 0.94 * y[16]) * gate;
    dydt[17] = (0.44 * 0.94 * y[14] - 0.00380 * y[17] - f_17_18 - 0.26 * y[17]) * gate;
    dydt[18] = (0.44 * 0.94 * y[15] - 0.00380 * y[18] + f_17_18 - 0.26 * y[18]) * gate;
    dydt[19] = (0.44 * 0.94 * y[16] + 0.00380 * (y[17] + y[18]) - 0.26 * y[19]) * gate;
    dydt[20] = ((1 - 0.44) * 0.94 * y[14] - (0.00380 + 0.4) * y[20] - f_20_21
                - (0.00648 + 0.11352) * y[20]) * gate;
    dydt[21] = ((1 - 0.44) * 0.94 * y[15] - (0.00380 + 0.4) * y[21] + f_20_21
                - (0.00648 + 0.11352) * y[21]) * gate;
    dydt[22] = ((1 - 0.44) * 0.94 * y[16] + (0.00380 + 0.4) * (y[21] + y[20])
                - (0.00648 + 0.11352) * y[22]) * gate;
    dydt[23] = (0.00648 * (y[20] + y[21] + y[22]) - (0.1343 + 0.0357) * y[23]) * gate;
    dydt[24] = (0.26 * y[17] + 0.11352 * y[20] + 0.1343 * y[23] - f_24_25) * gate;
    dydt[25] = (0.26 * (y[18] + y[19]) + 0.11352 * (y[21] + y[22]) + f_24_25) * gate;
    dydt[26] = (0.0357 * y[23]) * gate;
    // Cumulative counter: only ever increases while states 14 and 15 are non-negative.
    dydt[27] = ((1 - 0.44) * 0.94 * (y[14] + y[15])) * gate;

    return dydt;
  }
}
data {
  int<lower=0> n_days;                    // number of observed days
  array[n_days] int<lower=0> cases;       // observed counts, one per day
  // ODE output times. The solver requires them to be strictly increasing and
  // strictly greater than the initial time 0.0 used below.
  array[n_days + 1] real tSpanSimulation;
  vector[13] L; // lower bounds
  vector[13] U; // upper bounds
}
transformed data {
  // Constant initial ODE state. It is kept as data so it is not an autodiff variable.
  vector[27] y_init = rep_vector(0, 27);
  // Absolute tolerance vector shared by the forward and backward adjoint solves.
  vector[27] abs_tol = rep_vector(1e-8, 27);
  y_init[1] = 19216182;
  y_init[20] = 1;
  y_init[27] = 1;
}
parameters {
  vector<lower=L, upper=U>[13] theta;
}
transformed parameters {
  // Expected count per day: theta[12] times the daily increment of state 27.
  array[n_days] real predicted;
  {
    // The ODE is solved here rather than only in generated quantities, so the
    // likelihood in the model block depends on theta. ode_adjoint_tol_ctl then
    // supplies the gradient of the log density via adjoint sensitivities.
    array[n_days + 1] vector[27] y = ode_adjoint_tol_ctl(rhs, y_init, 0.0, tSpanSimulation,
      1e-8, abs_tol,  // forward relative / absolute tolerance
      1e-8, abs_tol,  // backward relative / absolute tolerance
      1e-8, 1e-8,     // quadrature relative / absolute tolerance
      1000000,        // maximum number of steps
      1000000,        // number of steps between checkpoints
      1,              // interpolation polynomial: 1=Hermite, 2=polynomial
      2,              // solver for forward phase: 1=Adams, 2=BDF
      2,              // solver for backward phase: 1=Adams, 2=BDF
      theta);
    for (i in 1:n_days) {
      predicted[i] = theta[12] * (y[i + 1, 27] - y[i, 27]);
    }
  }
}
model {
  // Mean passed to the negative binomial. Solver error can make `predicted`
  // slightly negative (or exactly zero), so it is floored at a tiny positive
  // value instead of dropping the observation. Dropping would make the log
  // density jump discontinuously as theta varies.
  vector[n_days] mu;

  // Uniform priors: constant log density inside each range and -infinity
  // outside. When the bounds L and U lie within these ranges, they add nothing.
  theta[1:4] ~ uniform(0, 600);
  theta[5] ~ uniform(0, 10);
  theta[6] ~ uniform(0, 10);
  theta[7] ~ uniform(0, 1);
  theta[8] ~ uniform(0, 10);
  theta[9] ~ uniform(0, 1);
  theta[10] ~ uniform(0, 10);
  theta[11] ~ uniform(0, 1);
  theta[12] ~ uniform(0, 1);
  theta[13] ~ uniform(0, 100);

  for (i in 1:n_days) {
    mu[i] = fmax(predicted[i], 0.0) + 1e-10;
  }
  // Full negative-binomial (mean mu, dispersion theta[13]) log pmf. This is the
  // same expression previously written out by hand with lgamma/log terms.
  target += neg_binomial_2_lpmf(cases | mu, theta[13]);
}
generated quantities {
  // Constant initial state, still emitted so existing output columns keep their names.
  vector[27] init = y_init;
  // Posterior predictive draws; `predicted` is emitted as a transformed parameter.
  array[n_days] real output;
  for (i in 1:n_days) {
    // Same floor as in the model block; neg_binomial_2_rng rejects a negative mean.
    output[i] = neg_binomial_2_rng(fmax(predicted[i], 0.0) + 1e-10, theta[13]);
  }
}