Speeding up a bioenergetic ODE model

Dear all,

After a lot of work, I have implemented in Stan a bioenergetic model for the size and energy dynamics of an animal (a fish in our case). The model (dynamic energy budget model, or DEB) has theoretical support (see the Wikipedia article "Dynamic energy budget theory"). It essentially consists of a system of four coupled ordinary differential equations with many parameters. Temperature and food are forcing variables. In addition, none of the state variables is directly observable, so additional equations (and parameters) are needed to connect the state variables with the observations (typically length, weight, fecundity, …).

I am interested in estimating the DEB parameters at the fish level. We have several measurements along the lifespan of each fish, and many fish (aquaculture data). Thus, the objective is a hierarchical model (i.e., the fish-level DEB parameters are drawn from a population-level distribution).

The code is at the end.

After simulating data for ten fish (length, weight and fecundity; one simulated observation per year; 12-year lifespan; parameter values, between-fish variability and measurement error were all realistic), I can recover three theoretically relevant DEB parameters when fixing all the other parameters (i.e., assuming they are known). There are no divergent transitions and Rhat is close to 1. The accuracy and precision of the fish-level estimates of these three target parameters are satisfactory for our purpose (except for the measurement error of fecundity, for which a bug may remain).

Initial values and priors can be defined based on previous results (see the "About AmP" page of the Add-my-Pet collection).

The advice we would like to ask the forum for today is how to improve computation time: 24 hours on a machine without memory limitations (10 fish, 12 observations per fish of three observable variables = 360 observations). Now I would like to start with real data and to explore the many details of the model (e.g., to decide whether it is worth measuring additional variables), and to include more fish (hundreds) or more observations per fish…, but the long computation time is discouraging.

So, are there flaws in the code? Is there any chance of moving from days to minutes, or at least to a few hours (e.g., by changing the tolerances or the maximum number of steps of the numerical integrator)?

We recognize that identifiability is also a concern when dealing with this many parameters. In our case, after adding a fourth parameter, the posteriors of two parameters became highly correlated. Fortunately, most of the parameters can be assumed to be invariant at the between-fish level and/or can be estimated from ad-hoc experiments. Thus, any comment on identifiability will be very welcome, but I would like to focus first on the technical problem of computation time.

I run the code with RStan, using

library(rstan)
options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)

Thank you
m

The code:

code=" 
functions {
  real[] DEB(
    // input
    real t,         // time
    real[] z,       // state variables
    real[] theta,   // parameters to be estimated
    real[] x_r,     // for passing the fixed parameters (=assumed to be known) to the numerical integrator 
    int[] x_i       // not used, but required by the integrator's signature
    ){
      // Derivatives
      real dEdt;  // Reserve Energy (j)
      real dVdt;  // Structural volume (cm^3)
      real dHdt;  // Maturation Energy (j)
      real dURdt;  // Reproduction Energy (j)

      // state variables
      real E = z[1];
      real V = z[2];
      real H = z[3];
      real UR = z[4];

      // Fixed parameters (assumed to be known)
      real Kelvin = x_r[1];
      real Temp_mean = x_r[2];
      real Temp_amp = x_r[3];
      real pi2f = x_r[4];
      real Temp_phi = x_r[5];
      real TA = x_r[6];
      real T1 = x_r[7];
      real f = x_r[8];
      real EG = x_r[9];
      real kJ = x_r[10];
      //real kappa = x_r[11]; //Caution: the commented-out parameters are those to be estimated 
      //real pAm = x_r[12];
      real kM = x_r[13];
      //real v = x_r[14];
      real Hp = x_r[15];
      
      // parameters to be estimated
      real pAm = theta[1];
      real v = theta[2];
      real logit_kappa = theta[3];
      real kappa;
      
      // Auxiliary variables
      real Temp;         // Temperature
      real cT;           // Arrhenius temperature correction
      real E_V;          // energy density
      real V23;          // surface
      real pAm_T;        // model parameters after temperature correction
      real v_T;
      real kM_T;
      real kJ_T;
      
      // fluxes (j/day)
      real pA;       // energy assimilation
      real pM;       // somatic maintenance
      real pJ;       // maturity maintenance
      real pC;       // mobilization
      
      // temperature correction and auxiliary variables
      Temp = Kelvin + Temp_mean + Temp_amp*sin(pi2f*t + Temp_phi); // sinusoidal function for temperature
      cT = exp(TA/T1 - TA/Temp);   // temperature correction (Arrhenius)
      E_V = E/V;
      V23 = pow(V,2.0/3.0);        // surface
      pAm_T=cT*pAm;                // parameters after temperature correction 
	    v_T=cT*v;
	    kM_T=cT*kM;
	    kJ_T=cT*kJ;
	    kappa = inv_logit(logit_kappa);
		  
      // fluxes
	    pA=f*pAm_T*V23;                                   // assimilation rate    
	    pM=kM_T*V;			                                  // somatic maintenance rate
	    pJ=kJ_T*H;			                                  // maturity maintenance rate
	    pC=E_V*((EG*v_T*V23+pM)/(kappa*E_V+EG));          // mobilization rate
	    
      // derivatives
	    dEdt=pA-pC;
	    dVdt=(kappa*pC-pM)/EG ;
	    dHdt=((1-kappa)*pC-pJ)*(H<Hp);
	    dURdt=(((1-kappa)*pC-pJ)*(H>=Hp));

      return {
        dEdt,  
        dVdt,
        dHdt,
        dURdt
      };
    }
}

data {
  int<lower=0> N;          // number of fish
  int<lower=0> n;          // number of replicated measures per fish
  real ts[n];              // times at which the system is observed
  real z0[4];              // initial values of the state variables (assumed to be the same for all the fish)
  
  row_vector[n] length_obs[N];  // simulated length-at-age
  row_vector[n] weight_obs[N];  // simulated weight-at-age
  row_vector[n] log_fecundity_obs[N]; // simulated (cumulative) fecundity-at-age
  
  int n_fixed;             // number of fixed parameters
  real fixed[n_fixed];     // fixed parameters
}

transformed data {
  real x_r[n_fixed];       // for passing the fixed parameters to the numerical integrator 
  int x_i[0];              // not used, but required by the integrator's signature
  x_r = fixed;
}

parameters {
  // fish level parameters
  real <lower = 50.0, upper = 250.0> pAm_i[N];
  real <lower = 0.0, upper = 0.05> v_i[N];
  real <lower = 2.0, upper = 5.0> logit_kappa_i[N];
  
  // population level parameters
  real <lower = 50.0, upper = 250.0> pAm;                
  real <lower = 0.0, upper = 0.05> v;
  real <lower = 2.0, upper = 5.0> logit_kappa;
  
  // between-fish variability
  real <lower = 0.0, upper = 30.0> pAm_sd;    
  real <lower = 0.0, upper = 1.0> v_sd;         
  real <lower = 0.0, upper = 3> logit_kappa_sd;
  
  // measurement error
  real <lower = 0, upper = 3> sd_length; 
  real <lower = 0, upper = 10> sd_weight;
  real <lower = 0, upper = 1> sd_fecundity; 
}

transformed parameters {
  real yhat[N,n,4];
  for (i in 1:N){
    yhat[i,,] = integrate_ode_rk45(DEB, z0, 0.0 , ts, {pAm_i[i],v_i[i],logit_kappa_i[i]}, x_r, x_i);
  }
}
    
model {
  // auxiliary fixed parameters
  real rho = fixed[16];
  real wE = fixed[17];
  real muE = fixed[18];
  real cw = fixed[19];
  real kapR = fixed[20];
  real egg_energy = fixed[21];

  // priors parameters
  // population level
  pAm ~ normal(130,100);
  v ~ normal(0.02,0.05);
  logit_kappa ~ normal(3,2);
  
  // between-fish variability
  pAm_sd ~ normal(10,10);
  v_sd ~ normal(0.2,0.5);
  logit_kappa_sd ~ normal(0.5,1);
  
  // fish level
  pAm_i ~ normal(pAm,pAm_sd);
  v_i ~ lognormal(log(v),v_sd);   // v is the population median on the natural scale
  logit_kappa_i ~ normal(logit_kappa,logit_kappa_sd);
  
  // measurement errors
  sd_length ~ normal(0.3,1);
  sd_weight ~ normal(1,2);
  sd_fecundity ~ normal(0.05,0.1);
  
  // Likelihood
  for (i in 1:N){
    length_obs[i,] ~ normal((to_row_vector(cbrt(yhat[i,,2]))/rho),sd_length);
    weight_obs[i,] ~ normal((to_row_vector(yhat[i,,2]) + to_row_vector(yhat[i,,1])*cw*wE/muE),sd_weight);
    log_fecundity_obs[i,] ~ normal((log(1e-20 + to_row_vector(yhat[i,,4])*kapR/egg_energy)),sd_fecundity);
  }
}

generated quantities {

}
"
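For checking the derivative function outside Stan (for example, against the solver used to simulate the data), the right-hand side of the DEB function above can be ported line by line. What follows is a minimal Python sketch, not the poster's code: it assumes the smoothed switch 0.5 + 0.5*tanh(slope*(H - Hp)) discussed later in the thread instead of the hard (H<Hp) indicator, the dictionary keys mirror the Stan parameter names, and the numbers are placeholders rather than AmP or fitted values.

```python
import math

def deb_rhs(t, y, p, slope=0.1):
    """Derivatives [dE/dt, dV/dt, dH/dt, dUR/dt] of the 4-state DEB model.

    Ported from the Stan DEB function; p maps the Stan parameter names to values.
    The hard switch (H < Hp) is replaced by 0.5 + 0.5*tanh(slope*(H - Hp)).
    """
    E, V, H, UR = y  # UR does not feed back into the other derivatives
    temp = p["Kelvin"] + p["Temp_mean"] + p["Temp_amp"] * math.sin(p["pi2f"] * t + p["Temp_phi"])
    cT = math.exp(p["TA"] / p["T1"] - p["TA"] / temp)  # Arrhenius temperature correction
    kappa = 1.0 / (1.0 + math.exp(-p["logit_kappa"]))  # inv_logit
    E_V = E / V                                        # energy density
    V23 = V ** (2.0 / 3.0)                             # surface
    pA = p["f"] * cT * p["pAm"] * V23                  # assimilation
    pM = cT * p["kM"] * V                              # somatic maintenance
    pJ = cT * p["kJ"] * H                              # maturity maintenance
    pC = E_V * (p["EG"] * cT * p["v"] * V23 + pM) / (kappa * E_V + p["EG"])  # mobilization
    step = 0.5 + 0.5 * math.tanh(slope * (H - p["Hp"]))  # ~0 before puberty, ~1 after
    flux = (1.0 - kappa) * pC - pJ                     # energy to maturation or reproduction
    return [pA - pC, (kappa * pC - pM) / p["EG"], flux * (1.0 - step), flux * step]

# Placeholder values for illustration only
params = {
    "Kelvin": 273.15, "Temp_mean": 15.0, "Temp_amp": 3.0, "pi2f": 2 * math.pi / 365,
    "Temp_phi": 0.0, "TA": 8000.0, "T1": 293.15, "f": 1.0, "EG": 5000.0,
    "kJ": 0.002, "kM": 0.02, "Hp": 1e5, "pAm": 130.0, "v": 0.02, "logit_kappa": 3.0,
}
print(deb_rhs(0.0, [1000.0, 1.0, 0.0, 0.0], params))
```

Evaluating deb_rhs at a few states and times and comparing it with the Stan function exposed via expose_stan_functions is a cheap way to catch transcription errors before running the sampler.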

Hi,
I can't help directly, but I have some general comments:

Unfortunately, identifiability and computation time usually cannot be separated. In my experience, the most common reason a model is slow is that it is non-identifiable or has some other computational problem (Andrew Gelman calls this the "folk theorem" of statistical computing).

You may benefit from within-chain parallelization with reduce_sum or from the recently implemented adjoint ODE solver. I'll tag @wds15, who implemented it and is generally much more experienced than me in speeding up ODE models.

I would definitely start by simulating data as if (some of) the ODE variables were directly, but noisily, observable, with no hierarchy, and only once you are sure this works well (and fast) add the additional layers of complexity. Or did you already try that?

The discrete step in the ODE system (H<Hp) is almost certain to cause problems, as Stan needs continuous derivatives everywhere. Maybe you can use a smooth approximation based on softplus, a.k.a. log1p_exp?
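As a rough illustration of the smoothing idea: softplus (Stan's log1p_exp) is a smooth version of max(0, x), and its derivative is the logistic function (Stan's inv_logit), which is a natural smooth replacement for the indicator (H >= Hp). The sketch below is plain Python; smooth_step and its steepness k are illustrative names, not part of the posted model. It also checks that the 0.5 + 0.5*tanh(k*x) form used later in the thread is the same logistic curve with slope parameter 2k.

```python
import math

def softplus(x):
    # Numerically stable log(1 + exp(x)), the same quantity as Stan's log1p_exp
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def logistic(x):
    # Stan's inv_logit, written to avoid overflow for large |x|
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def smooth_step(H, Hp, k):
    """Smooth replacement for the indicator (H >= Hp); larger k means a sharper switch."""
    return logistic(k * (H - Hp))

# d/dx softplus(x) equals logistic(x): compare a central difference at x = 1
h = 1e-6
print((softplus(1 + h) - softplus(1 - h)) / (2 * h), logistic(1.0))
# 0.5 + 0.5*tanh(k*x) is logistic(2*k*x)
print(0.5 + 0.5 * math.tanh(3 * 0.2), logistic(2 * 3 * 0.2))
```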

Best of luck with the model, it looks tough.


With 4 states and 4 parameters (if I read that right), the new adjoint solver will almost certainly not help.

What would help readability, and possibly also speed, is the new variadic ODE interface:

https://mc-stan.org/users/documentation/case-studies/convert_odes.html

and you should definitely use reduce_sum here!

For all of that you need to switch to a newer Stan version (>2.24), which you can do with cmdstanr, for example.

As for the rest… identifiability is indeed often a problem in such models. The discontinuity pointed out by @martinmodrak is also difficult to handle (maybe replace it with a steep tanh function?).


Martin, wds15,
Thanks for replying. I'll be able to get back to this issue in a few days and will try your suggestions. Just a short comment in the meantime: we started with very simple models. The last successful checkpoint was a model with only one fish, 3 parameters to be estimated, 3 observed variables with error and 12 observations per variable. It takes 4 minutes. A model with the same settings but 10 fish was successful in terms of accuracy and precision but took one day.
m


Do you model the parameters in a hierarchical way for the different fish? If so, try setting the taus (the hierarchical standard deviations) to known values. Going from 4 min to one day, as you describe, is suspicious. You can calculate the log-lik of the individual fish in parallel with reduce_sum, but what you describe sounds "fishy"…


Interesting. I think I can probably help you get this down to a few minutes, especially for exploration, but I don't have the time right now.

I am very sorry, but right now I have no time at all. I wrote "4 minutes" from memory. I will recheck this figure but, in any case, the difference in computation time between the hierarchical model (several fish; the model I attached) and a model with a single fish is huge. I promise to give you more details in a few days. Anyway, I AM VERY GRATEFUL for your suggestions and interest.


Dear all,
Finally, I have found some time to implement your suggestions. Again, thank you very much for your interest.

Let me recap where I am:

  1. The biological goal is to understand fish growth at the individual level and to apply this knowledge to aquaculture. The statistical goal is to estimate DEB parameters from real data, but I have started with simulated data. I started from scratch. First, I verified that the numerical integration solver I used for simulating data gives exactly the same results as Stan's rk45 integrator (using RStan's expose_stan_functions). After that, I added complexity step by step.

  2. The last fully successful checkpoint was the analysis of one fish, 3 observed variables (length, weight and fecundity, measured once a year over a 20-year lifespan) and 3 DEB parameters to be estimated (all the other DEB parameters were assumed to be known). At this point, accuracy and precision for the 3 target DEB parameters were very good. Accuracy and precision for the measurement errors of the 3 observed variables were also very good. Thus, at least with these settings, the system seems fully identifiable. Some technical details: 4 chains; no divergent transitions; Rhat close to 1; thousands of effective samples; priors wider than posteriors. Computation time was around 5 minutes.

  3. When moving from one fish to several fish, the model became hierarchical (the fish-level DEB parameters are assumed to be normally distributed around a population mean, with a given sigma). With the model in my first post (10 fish), accuracy and precision decreased compared with the single-fish analysis, but the results could still be acceptable for our objectives. The problem is that computation time sharply increased to 20-30 hours (or even more).

Very long computation times remain regardless of what I have tried. Some of these trials followed your suggestions. Note that I stopped any run that lasted more than one hour, so I have not assessed accuracy and precision for all the trials below (all with more than one fish):

  1. I have moved to cmdstan 2.27.0 (via cmdstanr).
  2. I have implemented reduce_sum at the fish level. I have no idea how to do it for the other levels. I have played with different grainsizes; the optimal grainsize seems to be 1. I am running 4 chains with 20 cores.
  3. I have used the new ODE interface you suggested.
  4. I have replaced the discontinuity (H>Hp) with a tanh function with a large slope.
  5. I have been playing with different tolerances and maximum numbers of steps for the integrator.
  6. I have tried both centered and non-centered parametrizations.
  7. I have tried fixing (i.e., treating as known) the between-fish variances (and other parameters).
  8. I have been playing with more or less informative priors at all levels.
  9. Initial values have been set very close to the actual values.
  10. I have been playing with max_treedepth and adapt_delta

At the end you will find one of the model versions for several fish (reduce_sum, non-centered parametrization and very informative priors). During warm-up, this type of run gives many warnings such as: "…Exception: DEB_model_namespace::log_prob: logit_kappa_i[sym1__] is 4.10475, but must be less than or equal to …", but they may be related to the many constraints. I do not know how to upload the *.RData file with the (simulated) input data.

I would be very happy if you could identify any flaw.

Note that I am not very happy about using the trick of adding a small constant (1e-20) to avoid log(0) in the lognormal_lpdf function when x is structurally zero (i.e., the fecundity of juveniles).

As a suboptimal solution, I am already considering running a separate analysis for each fish (five minutes per fish) and comparing the fish-level DEB parameters a posteriori.

Thanks again

The code

#-----
# 1: Loading libraries
#-----
remove(list=ls())
library(cmdstanr)

#----- 
# 2: Loading data
#-----
load("input.RData")

#dim(obs)
#[1] 10 20  3
# 10 years, 20 fish and 3 observed variables (length, weight and fecundity)

# the code crashes when fecundity = 0 (juveniles)  
for (i in 1:N){
  temp=which(obs[,i,3]==0)
  if(length(temp)>0){obs[temp,i,3]=1e-20}
}

#----- 
# 3: model
#-----
sink("DEB.stan")
cat(" // first line
functions {
  // Function for reduce_sum
  real partial_sum_lpdf(vector[] y2,         // observations (length, weight and fecundity at age)
                        int start, int end,  // reduce_sum internal parameters
                        vector[] x2,         // state variables (from the numerical integration)
                        real rho,            // parameters for length
                        real sd_length,
                        real cw,             // parameters for weight
                        real wE,
                        real muE,
                        real sd_weight,
                        real kapR,           // parameters for fecundity
                        real egg_energy,
                        real sd_fecundity
                         ) {
    return normal_lpdf(to_vector(y2[,1]) | pow(to_vector(x2[start:end,2]),(1/3.0))/rho , sd_length)
             +
           normal_lpdf(to_vector(y2[,2]) | to_vector(x2[start:end,2]) + cw*wE/muE*to_vector(x2[start:end,1]) , sd_weight)
             +
           lognormal_lpdf(to_vector(y2[,3]) | log(1e-20 + to_vector(x2[start:end,4])*kapR/egg_energy), sd_fecundity)
             ;
  }

  // DEB model   
  vector DEB(
    real t,          // time
    vector y,        // state variables
    // Parameters
    real Kelvin,     // 1
    real Temp_mean,  // 2
    real Temp_amp,   // 3
    real pi2f,       // 4
    real Temp_phi,   // 5
    real TA,         // 6
    real T1,         // 7
    real f,          // 8
    real EG,         // 9
    real kJ,         // 10
    real logit_kappa,  // 11 to be estimated
    real pAm,        // 12 to be estimated 
    real kM,         // 13
    real v,          // 14 to be estimated 
    real Hp          // 15
    ){
      // State variables
      // y[1] Reserve Energy (j)
      // y[2] Structural volume (cm^3)
      // y[3] Maturation Energy (j)
      // y[4] Reproduction Energy (j)
      // Derivatives
      vector[4] dydt;

      // Auxiliary variables
      real Temp;         // Temperature
      real cT;           // Arrhenius temperature correction
      real E_V;          // energy density
      real V23;          // surface
      real pAm_T;        // temperature corrections
      real v_T;
      real kM_T;
      real kJ_T;
      real step_fun;     // steep function for inequality H>Hp
      real kappa;        // inverse logit_kappa
      
      // fluxes (j/day)
      real pA; //assimilation
      real pM; //somatic maintenance
      real pJ; //maturity maintenance
      real pC; //mobilization rate
      
      // temperature correction and auxiliary variables
      Temp = Kelvin + Temp_mean + Temp_amp*sin(pi2f*t + Temp_phi); 
      cT = exp(TA/T1 - TA/Temp);
      //E_V = E/V;
      E_V = y[1]/y[2];
      V23 = pow(y[2],2.0/3.0);
      pAm_T=cT*pAm;
	    v_T=cT*v;
	    kM_T=cT*kM;
	    kJ_T=cT*kJ;
	    kappa = inv_logit(logit_kappa);
		  
      // fluxes
	    pA=f*pAm_T*V23;                                   // assimilation rate    
	    pM=kM_T*y[2];			                                // somatic maintenance rate
	    pJ=kJ_T*y[3];			                                // maturity maintenance rate
	    pC=E_V*((EG*v_T*V23+pM)/(kappa*E_V+EG));          // mobilization rate
	    
	    // steep function for H<Hp (hyperbolic tangent)
	    step_fun = 0.5 + 0.5*tanh(1000*(y[3]-Hp));        

      // derivatives
	    dydt[1]=pA-pC;
	    dydt[2]=(kappa*pC-pM)/EG ;
	    dydt[3]=((1-kappa)*pC-pJ)*(1-step_fun);
	    dydt[4]=((1-kappa)*pC-pJ)*step_fun;

      return dydt;
    }
}

data {
  int<lower=0> N;          // number of fish
  int<lower=0> n;          // number of replicated measures per fish
  real ts[n];              // times at which the system is observed
  vector[4] z0;            // initial values of the state variables (assumed to be known and the same for all fish)
  vector[3] obs[n,N];      // Observations; each vector is [length,weight,fecundity]
  int<lower=1> grainsize;  // parameter for reduce_sum function
  
  // parameters for the integrate function
  real Kelvin;      //1
  real Temp_mean;   //2
  real Temp_amp;    //3
  real pi2f;        //4
  real Temp_phi;    //5
  real TA;          //6
  real T1;          //7
  real f;           //8
  real EG;          //9
  real kJ;          //10
  //real kappa;       //11 to be estimated
  // real pAm;        //12 to be estimated 
  real kM;          //13
  //real v;           //14 to be estimated
  real Hp;          //15
  
  // other parameters of the DEB model (conecting state variables and observations)
  real rho;        //16
  real wE;         //17
  real muE;        //18
  real cw;         //19
  real kapR;       //20
  real egg_energy; //21

}

transformed data {
}

parameters {
  // population level means
  real <lower = 100, upper = 170> pAm;
  real <lower = 0.01, upper = 0.03> v;
  real <lower = 2.4, upper = 3.6> logit_kappa;
  
  // Fish-specific level (non-centered parametrization; z refers to standardized normal)
  vector [N] pAm_z;
  vector [N] v_z;
  vector [N] logit_kappa_z;
  
  // between-fish variability
  real <lower = 5, upper = 20> pAm_sd;
  real <lower = 0, upper = 0.01> v_sd;
  real <lower = 0, upper = 0.5> logit_kappa_sd;
  
  // measurement errors
  real <lower = 0, upper = 1> sd_length; 
  real <lower = 0, upper = 3> sd_weight;
  real <lower = 0, upper = 0.2> sd_fecundity; 
}

transformed parameters {
  // fish-specific parameters
  vector <lower = 100, upper = 170> [N] pAm_i;
  vector <lower = 0.01, upper = 0.03> [N] v_i;
  vector  <lower = 2.4, upper = 3.6> [N] logit_kappa_i;
  // fish- and time specific state variables 
  //vector <lower=0> [4] yhat[n,N];
  vector [4] yhat[n,N];
  
  // fish-specific parameters (non-centered parametrization)
  pAm_i = pAm + pAm_sd * pAm_z;
  v_i = v + v_sd * v_z;
  logit_kappa_i = logit_kappa + logit_kappa_sd * logit_kappa_z;
   
  // Numerical integration 
  for (i in 1:N){ // 1 to N fish
    yhat[,i] = ode_rk45_tol(DEB, z0, 0.0 , ts,
      // relative tolerance, absolute tolerance, max number of steps
      1e-5, 1e-3, 500,
      // parameters
      Kelvin, Temp_mean, Temp_amp, pi2f, Temp_phi, // parameters for the temperature forcing function
      TA, T1,                                      // Arrhenius correction
      f,                                           // functional response (food)
      EG,
      kJ,
      logit_kappa_i[i],  // to be estimated (fish specific)
      pAm_i[i],          // to be estimated 
      kM,
      v_i[i],            // to be estimated
      Hp);
  }
}   

model {
  // Population level priors
  pAm ~ normal(130,10);
  v ~ gamma(1,50); //0.0204 quantile(rgamma(10000,1,50),c(0.025,0.5,0.975))
  logit_kappa ~ normal(3,0.5);
  
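  // Fish-level parameters: no separate priors here. In the non-centered parametrization,
  // pAm_i, v_i and logit_kappa_i already get their prior from the population-level
  // parameters and the *_z ~ std_normal() statements below; adding priors on them
  // as well would count the prior information twice.
  // pAm_i ~ normal(130,20);
  // v_i ~ gamma(1,50);
  // logit_kappa_i ~ normal(3,0.5);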
  
  // Standardized (non-centered parametrization) fish-level priors
  pAm_z ~ std_normal();
  v_z ~ std_normal();
  logit_kappa_z ~ std_normal();
  
  // Between-fish variability priors
  pAm_sd ~ gamma(1,0.07);
  v_sd ~ gamma(1,484);
  logit_kappa_sd ~ gamma(1,3);
  
  // Measurement error priors
  sd_length ~ gamma(1,3);  //0.3 quantile(rgamma(10000,1,3),c(0.025,0.5,0.975))
  sd_weight ~ gamma(1,1);  //1  quantile(rgamma(10000,1,1),c(0.025,0.5,0.975))
  sd_fecundity ~ gamma(1,20); //0.05 quantile(rgamma(10000,1,20),c(0.025,0.5,0.975))

  // Likelihood (via reduce_sum)
  for (i in 1:N){ // 1 to N fish
    target += reduce_sum(partial_sum_lpdf, obs[,i], grainsize, yhat[,i],
                        rho, sd_length, cw, wE, muE, sd_weight, kapR, egg_energy, sd_fecundity);
  }
}

generated quantities {
}

" # end of model code
,fill = TRUE)
sink()

#----- 
# 4: compiling model
#-----
mod = cmdstan_model("DEB.stan", cpp_options = list(stan_threads = TRUE)) # reduce_sum only runs in parallel with threading enabled

#----- 
# 5: running
#-----
fit = mod$sample(
  data =list (
    z0 = as.array(z0),
    ts = as.numeric(times.obs),
    obs=obs,
    n=n,
    N=N,
    grainsize=1,
    # parameters
    Kelvin=parms$Kelvin,
    Temp_mean=parms$Temp_mean,
    Temp_amp=parms$Temp_amp,
    pi2f=parms$pi2f,
    Temp_phi=parms$Temp_phi,
    TA=parms$TA,
    T1=parms$T1,
    f=parms$f,
    EG=parms$EG,
    kJ=parms$kJ,
    #kappa=parms$kappa,
    #pAm=parms$pAm,
    kM=parms$kM,
    #v=parms$v,
    Hp=parms$Hp,
    rho=parms$rho,
    wE=parms$wE,
    muE=parms$muE,
    cw=parms$cw,
    kapR=parms$kapR,
    egg_energy=parms$egg_energy
  ),
  #seed = 123,
  chains=4,
  parallel_chains = 4,
  threads_per_chain = 5,  # 4 chains x 5 threads = 20 cores
  iter_warmup = 1000,
  iter_sampling=1000,
  init = inits,
  max_treedepth=10,
  adapt_delta=0.8
)


You need to move the ODE call inside the partial sum function to gain anything from this: you want the ODE solves to run in parallel… and grainsize=1 is fine in this case.

For the zero problem: I'd suggest an absolute tolerance of 1E-6 and then adding a constant of 1E-4. In your problem, a value of 1E-5 is effectively already 0, since that is the order of the tolerance of the ODE solution.
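A small numerical sketch of why the constant should sit well above the solver tolerance: the expected log fecundity of a juvenile is built from UR, which is structurally zero but comes back from the solver only up to its tolerance. With a 1e-20 offset, "zeros" that differ only by solver noise map to wildly different log values; with a 1e-4 offset they collapse to nearly the same value. kapR and egg_energy below are placeholders, and the list of noisy zeros is made up for illustration.

```python
import math

def log_expected_fecundity(UR, kapR, egg_energy, eps):
    # Same form as the likelihood term: log(eps + UR * kapR / egg_energy)
    return math.log(eps + UR * kapR / egg_energy)

def spread(values):
    return max(values) - min(values)

# Structurally zero reproduction energy as it might come back from an ODE solver:
# exactly 0, or tiny positive values of the order of the integration error (made up).
noisy_zeros = [0.0, 1e-12, 1e-9, 1e-7]
kapR, egg_energy = 0.95, 1.0  # placeholder values

for eps in (1e-20, 1e-4):
    logs = [log_expected_fecundity(u, kapR, egg_energy, eps) for u in noisy_zeros]
    print(f"eps={eps:g}: log values span {spread(logs):.4f}")
```

The same offset has to be applied to the observed zeros, so that observed and expected values are compared on the same scale.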

Have you set the hierarchical variances to a known value?


wds15,

It sounds clever, but could you give me some idea of how to implement it in practice?

Yes! I have tried it, and the computation time was very long too. I have even tried making not only the between-fish variances but also the population means constant (i.e., input as data). Note that the priors on the between-fish variances in the attached code can be very informative because the values are known (simulated data); they would be less informative in a real-world case.

Essentially you move the code from the transformed parameters into the partial sum function. So the partial sum function will calculate the mean at any given time-point and at the end return the log-lik implied by that. The fact that you can pass any argument to reduce_sum should make this a feasible thing to code.
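The structure described here can be sketched outside Stan: reduce_sum cuts the sliced argument (here, the fish) into chunks, calls the partial-sum function on each chunk, possibly on different threads, and adds the results. Because the expensive ODE solve for each fish happens inside the partial sum, it is the part that gets parallelized. The Python below only mimics that control flow with 0-based indices; fish_loglik is a hypothetical stand-in for "solve the ODE for fish i and score its observations", and because of Python's global interpreter lock these threads will not speed up CPU-bound work the way Stan's threads do.

```python
import math
from concurrent.futures import ThreadPoolExecutor

def fish_loglik(i, obs, params):
    # Hypothetical stand-in: in the Stan model this is where ode_rk45_tol(...) for fish i
    # and the normal/lognormal log-densities of its observations would go.
    mu = params[i]
    return sum(-0.5 * (y - mu) ** 2 - 0.5 * math.log(2 * math.pi) for y in obs[i])

def partial_sum(start, end, obs, params):
    # Analogue of partial_sum_lpdf: handles fish start, ..., end - 1
    return sum(fish_loglik(i, obs, params) for i in range(start, end))

def reduce_sum(n_fish, grainsize, obs, params, workers=4):
    # Split the fish into chunks of at most `grainsize` and add the partial sums
    slices = [(s, min(s + grainsize, n_fish)) for s in range(0, n_fish, grainsize)]
    with ThreadPoolExecutor(max_workers=workers) as pool:
        parts = pool.map(lambda se: partial_sum(se[0], se[1], obs, params), slices)
        return sum(parts)

obs = [[0.1 * i + j for j in range(3)] for i in range(10)]  # 10 fish, 3 toy observations each
params = [0.1 * i for i in range(10)]
print(reduce_sum(len(obs), 1, obs, params), partial_sum(0, len(obs), obs, params))
```

The result does not depend on how the fish are chunked (up to floating-point rounding), which is why grainsize can be tuned purely for speed.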

Once you get that going, you will actually benefit from parallelism, and in your case you can use as many CPUs as you have fish and still gain speed. The speed gain should be close to the number of CPU cores you use: 4 cores per chain should make this problem run about 4 times faster, because ODEs are so costly to evaluate.


Could you provide the fixed parameter values as a JSON file? Otherwise I won't be able to help you.

Thanks!


This stuff can be very hard, so I can provide no guarantees that the model could in the end be made to work - I hope so, but it is just not obvious this is the case.

Unfortunately, a hyperbolic tangent with large slope will likely behave numerically almost the same as the discontinuity (because it is very close to it). If you use a mild slope, would the model behave better? (putting interpretability/reality of the model aside for a while).
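One way to see this is to look at how wide the transition of 0.5 + 0.5*tanh(k*(H - Hp)) is: it rises from 0.1 to 0.9 over a distance ln(9)/k in H, with maximal slope k/2 at H = Hp. With k = 1000 that is about 0.002 units of H, so whether the switch is "smooth" from the solver's point of view depends on how that width compares with Hp and with typical changes in H per step. The thread does not give Hp, so the numbers printed below are only indicative.

```python
import math

def tanh_step(H, Hp, k):
    # Smooth step used in the posted model: 0.5 + 0.5*tanh(k*(H - Hp))
    return 0.5 + 0.5 * math.tanh(k * (H - Hp))

def transition_width(k):
    # Distance in H over which the step rises from 0.1 to 0.9:
    # tanh(k*x) = 0.8  ->  x = atanh(0.8)/k = ln(9)/(2k), and the width is 2x
    return math.log(9.0) / k

def max_slope(k):
    # Derivative of the step at H = Hp: 0.5 * k, since tanh'(0) = 1
    return 0.5 * k

for k in (1000.0, 0.1):  # slope in the first version vs. the later, milder one
    print(f"k={k:g}: width={transition_width(k):.4g}, max slope={max_slope(k):.4g}")
```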

All the hard bounds on parameters could be a source of problems as well: are you sure the actual parameter values in your data are far enough from the bounds? (Is this simulated data?) In general, those are most likely soft constraints: you have a rough idea about the plausible range of the parameter, but there is no physical/mathematical reason why the values could not be outside those ranges. Soft constraints are usually better expressed as a prior, while bounds are reserved for "hard" constraints (e.g., positivity). So, assuming pAm needs to be positive (hard constraint) and should be roughly between 100 and 170 (soft constraint), one could have:

real <lower = 0> pAm;

pAm ~ lognormal(4.9,0.15); //a priori 95% interval 100 - 180
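The interval in the comment can be checked, and a prior of this kind can be derived from a target range: for X ~ lognormal(mu, sigma), the central 95% interval is exp(mu ± 1.96 sigma), so mu is the midpoint of the log bounds and sigma is their half-width divided by 1.96. A small Python sketch:

```python
import math

Z975 = 1.959963984540054  # 97.5% quantile of the standard normal

def lognormal_interval(mu, sigma, z=Z975):
    """Central 95% interval of lognormal(mu, sigma)."""
    return math.exp(mu - z * sigma), math.exp(mu + z * sigma)

def lognormal_from_interval(lower, upper, z=Z975):
    """(mu, sigma) such that [lower, upper] is the central 95% interval."""
    mu = 0.5 * (math.log(lower) + math.log(upper))
    sigma = (math.log(upper) - math.log(lower)) / (2 * z)
    return mu, sigma

print(lognormal_interval(4.9, 0.15))       # roughly (100, 180)
print(lognormal_from_interval(100, 180))   # roughly (4.90, 0.15)
```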

Additionally, Stan does not enforce constraints in the transformed parameters block; it just checks that they hold (so they are basically a debugging tool). You usually need to make sure, through the way you construct your transformed parameters, that the bounds hold. So when you have:

  vector <lower = 100, upper = 170> [N] pAm_i;
  ...
  pAm_i = pAm + pAm_sd * pAm_z;

you are likely to see problems, as pAm_i can easily violate the constraints. If pAm_i is constrained to be positive (I don't know), then a parametrization on the log scale and removal of the bounds is likely a better solution. Hitting the bounds during warmup could easily be the reason your sampling is super slow (the sampler realizes it cannot take big steps without getting rejected because of the bounds, so it adapts to veeeeery small steps).

Hope that helps at least a bit.

Only some file extensions are allowed; I didn't realize RData is not among them, so I just added it.


wds15, Funko_Unko, martinmodrak
Thanks for all the input. I will try your suggestions and give you feedback as soon as possible.

m

Dear all,

Following the suggestions from the forum, I have implemented new code with the following changes (run via cmdstanr):

  • The new ODE interface (ode_rk45, not integrate_ode_rk45).
  • The reduce_sum function. Slicing was done at the subject (=fish) level (there are several observations from the same fish across its lifespan). Grainsize=1.
  • The ODE call was moved within the partial sum function.
  • Observed zeros for a variable whose observation error is assumed to be log-normal: the tolerances (of ode_rk45_tol) were set to 1e-6, and a small constant (1e-4) was added to the observed zeros and to their corresponding expected values.
  • The slope of the steep tanh function is now much smaller (the hyperbolic tangent is used instead of an inequality within the ODE function).
  • All constraints have been removed (except in a few cases, to ensure positivity). I had misunderstood the role of these constraints; thanks, martinmodrak, for your clarification on that. Informative priors have been set instead.

Some details of the new run: 1000+1000 iterations; 4 chains, 28 cores; 10 fish; 3 observed variables; 7 observations per variable and fish; 3 ODE parameters estimated per fish (+ 3 means + 3 sds) and 3 measurement errors (=> 39 parameters). Convergence seems OK (Rhat close to 1, large effective sample sizes, …). Posteriors were narrower than priors.

Accuracy and precision of the fish-level parameter estimates seem acceptable.

Speed improved by a factor of 4 to 5, which seems close to the expected improvement suggested by wds15. Computation time is now 4-5 hours, which is not on the scale of minutes, but is certainly an improvement over the initial model I posted (16-20 hours).

During the first 100 warmup iterations, I get many warnings. One of the most common concerns lognormal_lpdf (scale parameter is negative or NaN) when summing the log-likelihood (within the partial_sum_lpdf function), although this shouldn't happen, because the prior of the corresponding standard deviation is gamma distributed. Note that the corresponding observed variable is the one with structural zeros. I understand that these warnings can be ignored, but they may be informative for improving speed.

I have also noticed that priors like v_sd ~ lognormal(-3.89,1) produce warnings (during warmup) that disappear when using log_v_sd ~ normal(-3.89,1) and v_sd=exp(log_v_sd). Again, I understand that they could be ignored, but they may be informative for improving speed.
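For reference, the two formulations describe the same prior on v_sd. If u = log(v_sd) ~ normal(mu, sigma), the change of variables v_sd = exp(u) gives lognormal_lpdf(v_sd | mu, sigma) = normal_lpdf(log(v_sd) | mu, sigma) - log(v_sd), and the log-Jacobian of exp is u itself. When v_sd is declared with only <lower=0>, Stan already samples on the log scale internally and adds that term, so in exact arithmetic both versions should give the same posterior; why the warnings differ is not clear from the thread. The identity can be checked numerically in Python:

```python
import math

LOG_SQRT_2PI = 0.5 * math.log(2.0 * math.pi)

def normal_lpdf(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - LOG_SQRT_2PI

def lognormal_lpdf(x, mu, sigma):
    # Written out directly (not via normal_lpdf) so the check below is not circular
    return -0.5 * ((math.log(x) - mu) / sigma) ** 2 - math.log(sigma) - LOG_SQRT_2PI - math.log(x)

mu, sigma = -3.89, 1.0
for u in (-6.0, -3.89, -2.0):
    # density of u = log(v_sd) equals density of v_sd = exp(u) times |d v_sd / d u| = exp(u)
    lhs = normal_lpdf(u, mu, sigma)
    rhs = lognormal_lpdf(math.exp(u), mu, sigma) + u
    print(u, lhs, rhs)
```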

Thus, unless you have some additional suggestion for further reducing computation time, this topic could be closed from my side.

I am adding a full set of (simulated) data and the new code:

The code (the stan model only):

sink("DEB.stan")
cat(" // first line
functions {
  // DEB model (ODE)  
  vector DEB(
    real t,          // time
    vector y,        // state variables
    // ODE Parameters
    real Kelvin,     // 1
    real Temp_mean,  // 2
    real Temp_amp,   // 3
    real pi2f,       // 4
    real Temp_phi,   // 5
    real TA,         // 6
    real T1,         // 7
    real f,          // 8
    real EG,         // 9
    real kJ,         // 10
    real logit_kappa_i,      // 11 to be estimated
    real pAm_i,              // 12 to be estimated
    real kM,         // 13
    real v_i,                // 14 to be estimated
    real Hp          // 15
    ){
      // Derivatives
      // real dEdt;  // Reserve Energy (j)
      // real dVdt;  // Structural volume (cm^3)
      // real dHdt;  // Maturation Energy (j)
      // real dURdt; // Reproduction Energy (j)
      vector[4] dydt;

      // Auxiliary variables
      real Temp;         // Temperature
      real cT;           // Arrhenius temperature correction
      real E_V;          // energy density
      real V23;          // surface
      real pAm_T;        // temperature corrections
      real v_T;
      real kM_T;
      real kJ_T;
      real step_fun;     // step function for inequality H>Hp
      real kappa;
      
      // fluxes (J/day)
      real pA; 
      real pM;
      real pJ;
      real pC;
      
      // temperature correction and auxiliary variables
      Temp = Kelvin + Temp_mean + Temp_amp*sin(pi2f*t + Temp_phi);
      cT = exp(TA/T1 - TA/Temp);  // temperature correction
      E_V = y[1]/y[2];            // energy density
      V23 = pow(y[2],2.0/3.0);    // surface
      pAm_T = cT*pAm_i;           // temperature-corrected parameters
      v_T = cT*v_i;
      kM_T = cT*kM;
      kJ_T = cT*kJ;
      kappa = inv_logit(logit_kappa_i);  // logit_kappa to kappa

      // fluxes
      pA = f*pAm_T*V23;                                   // assimilation rate
      pM = kM_T*y[2];                                     // somatic maintenance rate
      pJ = kJ_T*y[3];                                     // maturity maintenance rate
      pC = E_V*((EG*v_T*V23+pM)/(kappa*E_V+EG));          // mobilization rate

      // step function for H<Hp
      step_fun = 0.5 + 0.5*tanh(.1*(y[3]-Hp));            // hyperbolic tangent

      // derivatives
      dydt[1] = pA-pC;
      dydt[2] = (kappa*pC-pM)/EG;
      dydt[3] = ((1-kappa)*pC-pJ)*(1-step_fun);           //(y[3]<Hp);//
      dydt[4] = ((1-kappa)*pC-pJ)*step_fun;               //(y[3]>=Hp);//

      return dydt;
  }
  // Function for reduce_sum (multicore optimization)
  real partial_sum_lpdf(vector[] obs,         // observations
                        int start, int end,   // reduce_sum internal parameters
                        //vector[] y,         // state variables (from the numerical integration of the ODE)
                        vector y0,            // initial values 
                        real t0,              // initial time
                        real [] ts,           // observation times
                        int n,                // number of observations per fish
                        // parameters linking state variables with observations
                        // length
                        real rho,
                        real sd_length,       // to be estimated 
                        // weight
                        real cw,              
                        real wE,
                        real muE,
                        real sd_weight,       // to be estimated
                        // fecundity
                        real kapR,
                        real egg_energy,
                        real sd_fecundity,    // to be estimated
                        // ODE fixed parameters
                        real Kelvin,
                        real Temp_mean,
                        real Temp_amp,
                        real pi2f,
                        real Temp_phi,
                        real TA,
                        real T1,
                        real f,
                        real EG,
                        real kJ,
                        real kM,
                        //real v,
                        real Hp,
                        // ODE parameters to be estimated
                        real pAm_i,
                        real logit_kappa_i,
                        real v_i
                        
                         ) {
    // numerical integration
    vector[4] y[n] = ode_rk45_tol(DEB, y0, t0 , ts,
      1e-6,1e-6,10000,
      // ODE parameters
      Kelvin, Temp_mean, Temp_amp, pi2f, Temp_phi, TA, T1, f, EG, kJ,
      logit_kappa_i,   // to be estimated 
      pAm_i,           // to be estimated  
      kM,
      v_i,             // to be estimated 
      Hp);                     
    
    // log-likelihood                     
    return normal_lpdf(to_vector(obs[,1]) | pow(to_vector(y[start:end,2]),(1/3.0))/rho , sd_length)
             +
           normal_lpdf(to_vector(obs[,2]) | to_vector(y[start:end,2]) + cw*wE/muE*to_vector(y[start:end,1]) , sd_weight)
             +
           lognormal_lpdf(to_vector(obs[,3]) | log(1e-4 + to_vector(y[start:end,4])*kapR/egg_energy), sd_fecundity)
             ;
  }
}

data {
  int<lower=0> N;          // number of fish
  int<lower=0> n;          // number of replicated measures per fish
  real t0;                 // initial time (0.0)
  real ts[n];              // times at which the system is observed
  vector[4] y0;            // initial values of the state variables
  vector[3] obs[n,N];      // Observations (length, weight and fecundity)
  int<lower=1> grainsize;
  
  // fixed parameters for the ODE function
  real Kelvin;      //1
  real Temp_mean;   //2
  real Temp_amp;    //3
  real pi2f;        //4
  real Temp_phi;    //5
  real TA;          //6
  real T1;          //7
  real f;           //8
  real EG;          //9
  real kJ;          //10
  //real kappa;       //11 to be estimated
  // real pAm;        //12 to be estimated 
  real kM;          //13
  //real v;           //14 to be estimated
  real Hp;          //15
  
  // fixed parameters for linking state variables with observations
  real rho;        //16
  real wE;         //17
  real muE;        //18
  real cw;         //19
  real kapR;       //20
  real egg_energy; //21
}

transformed data {
}

parameters {
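  real<lower=0> pAm;           // fish level mean (a lognormal prior alone does not keep it positive)
  real pAm_z[N];               // standard normal for non-centered parametrization
  real log_pAm_sd;             // between fish variability
  real log_v;                  // the same for v
  real v_z[N];
  real log_v_sd;
  real logit_kappa;            // the same for kappa
  real logit_kappa_z[N];
  real log_logit_kappa_sd;
  real<lower=0> sd_length;     // measurement errors; <lower=0> because a gamma prior
  real<lower=0> sd_weight;     // alone does not stop the sampler proposing negative values
  real<lower=0> sd_fecundity;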
}

transformed parameters {
  real <lower=0> pAm_i[N];
  real <lower=0> v_i[N];
  real logit_kappa_i[N];
  real pAm_sd;
  real v;
  real v_sd;
  real logit_kappa_sd;
  
  // non-centered parametrization
  pAm_sd=exp(log_pAm_sd); 
  v=exp(log_v);
  v_sd = exp(log_v_sd);
  logit_kappa_sd = exp(log_logit_kappa_sd);
  for (i in 1:N){
    pAm_i[i] = pAm + pAm_sd*pAm_z[i];
    v_i[i] = v + v_sd*v_z[i];
    logit_kappa_i[i] = logit_kappa + logit_kappa_sd*logit_kappa_z[i];
  }
}   

model {
  // priors
  pAm ~ lognormal(4.91,0.1); 
  log_pAm_sd ~ normal(2.63,1);
  pAm_z ~ std_normal();
  
  log_v ~ normal(-3.89,1);
  log_v_sd ~ normal(-6.18,1);
  v_z ~ std_normal();
    
  logit_kappa ~ normal(3,0.5); 
  log_logit_kappa_sd ~ normal(-1.16,1);
  logit_kappa_z ~ std_normal();
    
  sd_length ~ gamma(1,3);
  sd_weight ~ gamma(1,1);
  sd_fecundity ~ gamma(1,20);

  // Likelihood
  for (i in 1:N){
    target += reduce_sum(partial_sum_lpdf, obs[,i], grainsize,
                        y0, t0,ts,n,
                        rho, sd_length, cw, wE, muE, sd_weight, kapR, egg_energy, sd_fecundity,
                        Kelvin, Temp_mean, Temp_amp, pi2f, Temp_phi,
                        TA, T1, f, EG, kJ, kM, Hp,
                        pAm_i[i],logit_kappa_i[i],v_i[i]
                        );
  }
}

generated quantities {
}

" # end of model code
,fill = TRUE)
sink()

The data:
input.RData (4.9 KB)

Funko_unko: I am very grateful for your interest. I have zero expertise with JSON files, but apparently they can be produced from R (library jsonlite; function write_json). I am adding a file with the parameter values.
tmp.txt (545 Bytes)



Thanks for the data!

(I might have accidentally deselected the “Solution” of this thread, but I’m not sure whether I had accidentally selected it in the first place. Sorry!)

I selected it. On second thought, though, I should leave it to you people to decide whether the problem is “solved”.

Thanks for the clarification!

I think @palmer is satisfied :)


I am on my cell phone, so I can’t really do any tests, but I agree that the warnings might suggest additional improvements. Note that even for parameters with a lower bound of zero, one can observe actual zeroes due to numerical inaccuracies (Stan represents lower-bounded parameters via their logarithm internally, so the exp call in the transformation may result in an exact zero).
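
The underflow mentioned above is easy to reproduce: Stan maps an unconstrained value u to a parameter with lower bound lb as lb + exp(u), and in double precision exp(u) is exactly 0.0 once u falls below roughly -745. A minimal Python illustration, where lower_bound_zero is a hypothetical helper mirroring that transform for lb = 0:

```python
import math

def lower_bound_zero(u):
    """Hypothetical helper: transform for a <lower=0> parameter, u -> 0 + exp(u)."""
    return 0.0 + math.exp(u)

print(lower_bound_zero(-3.0))    # about 0.0498, strictly positive
print(lower_bound_zero(-800.0))  # 0.0: underflow, so a "positive" sd can be exactly zero
```

An exact zero would then fail the positivity check on a scale parameter such as sd_fecundity, which would plausibly produce the same kind of warning as a negative value.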

Interesting. I am not sure what exactly is happening here, but in general different functions in the Stan math library have received different amounts of scrutiny for numerical problems, so it is not implausible that an extreme case that is handled properly in normal_lpdf is not handled well in lognormal_lpdf (or the other way around).

I still find it likely that the extreme slowdown is because of some issue with the model. What does a pairs plot for the global params + params for a single fish look like?

Some additional ā€œlines of attackā€:

  • use print to see intermediate results in the computation and try to better understand why you see the warnings
  • check for prior x data conflict: is the posterior roughly consistent with the prior? If not, what happens if you widen the prior contain the posterior?( this is potentially problematic way to choose prior, but good as a debugging tool)
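
A crude numerical version of the prior x data conflict check: compare the posterior draws of a parameter with its prior, e.g. logit_kappa ~ normal(3, 0.5) in the posted model. prior_data_conflict is a hypothetical helper, and the draws in the example are made up to show the output format, not results from this model. Many draws outside the prior's central 99% interval, or a posterior median several prior standard deviations away from the prior mean, would hint at a conflict worth investigating.

```python
from statistics import NormalDist, median

def prior_data_conflict(draws, prior_mu, prior_sd, level=0.99):
    """Hypothetical helper: how far posterior draws sit from a normal prior.

    Returns (z, frac_outside): the posterior median in prior-sd units from the
    prior mean, and the fraction of draws outside the prior's central `level` interval.
    """
    prior = NormalDist(prior_mu, prior_sd)
    tail = (1 - level) / 2
    lo, hi = prior.inv_cdf(tail), prior.inv_cdf(1 - tail)
    frac_outside = sum(d < lo or d > hi for d in draws) / len(draws)
    z = (median(draws) - prior_mu) / prior_sd
    return z, frac_outside

# Made-up draws, only to show the output format.
print(prior_data_conflict([2.9, 3.1, 3.0, 3.2, 2.8], 3.0, 0.5))  # (0.0, 0.0)
print(prior_data_conflict([5.8, 6.0, 6.2], 3.0, 0.5))            # (6.0, 1.0)
```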

This looks highly suspicious. If there are zeroes, the error is not lognormal. What exactly do the data represent? If a zero means "lower than the detection limit", then maybe treat those values as censored (see the Stan User's Guide). Changing zeroes to small numbers is sometimes a good approximation for censoring, but in all cases I’ve seen this used, the replacements were not so small. Or maybe you can trace the value to a discrete observable (e.g. egg count)?
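
If the zeroes in fecundity do mean "below what can be measured", left censoring can be written as follows. Observations at or above a detection limit contribute the lognormal log density as before, while those below it contribute the log probability of falling below the limit, i.e. the lognormal log CDF at the limit (lognormal_lcdf in Stan). A Python sketch under that assumption; `limit` is a hypothetical detection limit and the numbers in the example are invented, since the thread does not say what such a limit would be.

```python
import math

def lognormal_lpdf(y, mu, sigma):
    z = (math.log(y) - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi) - math.log(y)

def lognormal_lcdf(y, mu, sigma):
    """log P(Y <= y) for Y ~ lognormal(mu, sigma), via the normal CDF of log(y)."""
    z = (math.log(y) - mu) / sigma
    return math.log(0.5 * math.erfc(-z / math.sqrt(2)))

def censored_loglik(obs, mu, sigma, limit):
    """Left-censored lognormal log-likelihood; `limit` is a hypothetical detection limit."""
    total = 0.0
    for y, m in zip(obs, mu):
        if y < limit:
            # Recorded zero: we only know the true value is below `limit`.
            total += lognormal_lcdf(limit, m, sigma)
        else:
            total += lognormal_lpdf(y, m, sigma)
    return total

# Invented example: two recorded zeros and two positive fecundity values.
obs = [0.0, 0.0, 1500.0, 4200.0]
mu = [2.0, 4.0, 7.3, 8.3]
print(censored_loglik(obs, mu, sigma=0.3, limit=10.0))  # finite, no 1e-4 replacement needed
```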


Perhaps this is just due to the difference in initialization?
