Chains getting stuck/not mixing issues

Hi, I am fitting an ODE model (Toggle Switch) with Rstan and I am having issues with the performance of some chains: there are always one or two chains (I am running 4) that get stuck at the initial value and do not progress or mix at all.
The model that I am using is:


functions{
  // Function containing the ODE to be used for the inference
  real[] Toogle_one(real t, real[] y, real[] p, real[] x_r, int[] x_i){
    // Inputs
    real u_IPTG = x_r[1];
    real u_aTc = x_r[2];
    // Parameters
    real k_in_IPTG = p[1];
    real k_out_IPTG = p[2];
    real k_in_aTc = p[3];
    real k_out_aTc = p[4];
    real k_L_pm0 = p[5];
    real k_L_pm = p[6];
    real theta_T = p[7];
    real theta_aTc = p[8];
    real n_aTc = p[9];
    real n_T = p[10];
    real k_T_pm0 = p[11];
    real k_T_pm = p[12];
    real theta_L = p[13];
    real theta_IPTG = p[14];
    real n_IPTG = p[15];
    real n_L = p[16];
    
    //Equations
    real dInd_dt[4];
    
    if (x_r[1]>y[1]){
      dInd_dt[1]=k_in_IPTG*(x_r[1]-y[1]);
    }
    else{
      dInd_dt[1]=k_out_IPTG*(x_r[1]-y[1]);
    }
    
    if (x_r[2]>y[2]){
      dInd_dt[2]=k_in_aTc*(x_r[2]-y[2]);
    }
    else{
      dInd_dt[2]=k_out_aTc*(x_r[2]-y[2]);
    }

    dInd_dt[3] = ((1/0.1386)*(k_L_pm0+(k_L_pm/(1+(y[4]/theta_T*1/(1+(y[2]/theta_aTc)^n_aTc))^n_T))))-0.0165*y[3];
    dInd_dt[4] = ((1/0.1386)*(k_T_pm0+(k_T_pm/(1+(y[3]/theta_L*1/(1+(y[1]/theta_IPTG)^n_IPTG))^n_L))))-0.0165*y[4];
    
    // RESULTS
    return dInd_dt;
  }
  
  // Function type vector containing the equations where the root needs to be calculated for the steady states
  vector SteadyState(vector init, vector p, real[] x_r, int[] x_i){
    vector[2] alpha;
    // Parameters
    real k_in_IPTG = p[1];
    real k_out_IPTG = p[2];
    real k_in_aTc = p[3];
    real k_out_aTc = p[4];
    real k_L_pm0 = p[5];
    real k_L_pm = p[6];
    real theta_T = p[7];
    real theta_aTc = p[8];
    real n_aTc = p[9];
    real n_T = p[10];
    real k_T_pm0 = p[11];
    real k_T_pm = p[12];
    real theta_L = p[13];
    real theta_IPTG = p[14];
    real n_IPTG = p[15];
    real n_L = p[16];
    // Equations
    alpha[1] = ((1/0.1386)*(k_L_pm0+(k_L_pm/(1+(init[2]/theta_T*1/(1+(x_r[2]/theta_aTc)^n_aTc))^n_T))))/0.0165;
    
    alpha[2] = ((1/0.1386)*(k_T_pm0+(k_T_pm/(1+(init[1]/theta_L*1/(1+(x_r[1]/theta_IPTG)^n_IPTG))^n_L))))/0.0165;
    // Results
    return alpha;
  }
  
}

data {
    // Observables
    int m; // Total number of data series
    int stslm; // Maximum number of rows for all the observable matrices
    int stsl[1,m]; // Number of elements at each time series for each series m
    int sts[stslm,m]; // Sampling times for each series m
    real GFPmean[stslm,m]; // estimated observables for tetR+GFP at each sampling time
    real RFPmean[stslm,m]; // estimated observables for LacI+RFP at each sampling time
    real GFPstd[stslm,m]; // standard error for tetR+GFP at each sampling time
    real RFPstd[stslm,m]; // standard error for LacI+RFP at each sampling time
    
    // Inputs
    int elm; // Number of rows in the matrices for IPTG and aTc, half for the inputs and -1 for the total number of events
    int tml; // Maximum length of the rows for the sampling times matrix
    int Nsp[1,m]; // Number of event switching points (including final time) times for each series m
    real ts[tml, m]; // Time series for each series m
    int tsl[1,m]; // Length of sampling time series per series m
    real preIPTG[1,m]; // Values of inputs for each series m for the ON incubation
    real preaTc[1,m];
    real IPTG[elm,m]; // Values of inputs at each event for each series m
    real aTc[elm,m];
    real inputs[(elm*2),m]; // Input values for each event ordered as IPTG, aTc, IPTG, aTc, ...
    int evnT[(elm+1),m]; // Event change time points for each series m
    
    // Over night incubation times
    int tonil;
    real toni[tonil];
    // real theta[16];
}

transformed data {
  int nParms = 16; // Number of parameters of the model
  int Neq = 4; // Total number of equations of the model
  int x_i[0]; // Empty x_i object (needs to be defined)
  real x_r[(elm*2),m]=inputs; // Input values for each event ordered as IPTG, aTc, IPTG, aTc, ...
  real ivss[Neq-2,m]; // Initial experimental values for the calculation of the steady state ordered as LacI+RFP, TetR+GFP
  real pre[2,m]; // Input values during the ON incubation ordered as IPTG, aTc
  
  for(i in 1:m){
    ivss[1,i] = RFPmean[1,i];
    ivss[2,i] = GFPmean[1,i];
    pre[1,i] = preIPTG[1,i];
    pre[2,i] = preaTc[1,i];
   };
}

parameters {
    // Parameters to be inferred in the model
    
    real k_in_IPTG_raw;
    real k_out_IPTG_raw;
    real k_in_aTc_raw;
    real k_out_aTc_raw;
    real k_L_pm0_raw;
    real k_L_pm_raw;
    real theta_T_raw;
    real theta_aTc_raw;
    real n_aTc_raw;
    real n_T_raw;
    real k_T_pm0_raw;
    real k_T_pm_raw;
    real theta_L_raw;
    real theta_IPTG_raw;
    real n_IPTG_raw;
    real n_L_raw;
}

transformed parameters {
  // Transformation to the true values from non-centred parameterisation
  real theta[nParms];

  theta[1] = exp(((k_in_IPTG_raw)*(1.15129254649702))+(-3.2188758248682));
  theta[2] = exp(((k_out_IPTG_raw)*(1.15129254649702))+(-3.2188758248682));
  theta[3] = exp(((k_in_aTc_raw)*(1.15129254649702))+(-2.30258509299405));
  theta[4] = exp(((k_out_aTc_raw)*(1.15129254649702))+(-2.30258509299405));
  
  theta[5] = exp(((k_L_pm0_raw)*(1.15129254649702))+(-3.50655789731998));
  theta[6] = exp(((k_L_pm_raw)*(1.15129254649702))+(2.30258509299405));
  theta[7] = exp(((theta_T_raw)*(1.15129254649702))+(3.40119738166216));
  theta[8] = exp(((theta_aTc_raw)*(1.15129254649702))+(2.30258509299405));
  theta[9] = (((n_aTc_raw)*(1.25))+(2.5));
  theta[10] = (((n_T_raw)*(1.25))+(2.5));
  theta[11] = exp((((k_T_pm0_raw))*(1.15129254649702))-2.30258509299405);
  theta[12] = exp(((k_T_pm_raw)*(1.15129254649702))+(2.22044604925031e-16));
  theta[13] = exp(((theta_L_raw)*(1.15129254649702))+3.40119738166216);
  theta[14] = exp(((theta_IPTG_raw)*(1.15129254649702))-2.30258509299405);
  theta[15] = ((n_IPTG_raw)*(1.25))+2.5;
  theta[16] = ((n_L_raw)*(1.25))+2.5;
}

model {
  
  // Intermediate parameters
  int i; // Increasing index for the inputs
  vector[2] ing; // Vector that will hold the algebraic steady-state solution of the model
  real ssv[tonil,Neq]; // Real that will include the solution of the ODE for the ON incubation (24h)
  real Y0[Neq,m]; // Initial values for the ODEs variables at the first event
  
  // Priors definition (test)
  k_in_IPTG_raw ~ normal(0,1);
  k_out_IPTG_raw ~ normal(0,1);
  k_in_aTc_raw ~ normal(0,1);
  k_out_aTc_raw ~ normal(0,1);
  k_L_pm0_raw ~ normal(0,1);
  k_L_pm_raw ~ normal(0,1);
  theta_T_raw ~ normal(0,1);
  theta_aTc_raw ~ normal(0,1);
  n_aTc_raw ~ normal(0,1);
  n_T_raw ~ normal(0,1);
  k_T_pm0_raw ~ normal(0,1);
  k_T_pm_raw ~ normal(0,1);
  theta_L_raw ~ normal(0,1);
  theta_IPTG_raw ~ normal(0,1);
  n_IPTG_raw ~ normal(0,1);
  n_L_raw ~ normal(0,1);
  
  // Likelihood
  for (j in 1:m){
    real ivst[Neq]; // Initial value of the states 
    real y_hat[(tsl[1,j]),Neq];
    // Calculation of initial guesses
    ing = SteadyState(to_vector(ivss[1:2,j]), to_vector(theta), pre[1:2,j], x_i); // Calculation of initial guesses for steady state
    Y0[1,j] = preIPTG[1,j];
    Y0[2,j] = preaTc[1,j];
    Y0[3,j] = ing[1];
    Y0[4,j] = ing[2];
    ssv = integrate_ode_bdf(Toogle_one, Y0[,j],0,toni,theta,pre[1:2,j], x_i, 1e-9, 1e-9, 1e7); // ON incubation calculation for the steady state
    
    Y0[,j] = ssv[tonil];
    i = 1;
    
    // Loop (over the number of events) to solve the ODE for each event stopping the solver and add them to the final object y_hat
    for (q in 1:Nsp[1,j]-1){
      
      int itp = evnT[q,j];  // Initial time points of each event
      int lts = num_elements(ts[(evnT[q,j]+1):(evnT[q+1,j]+1),j]);  // Length of the time series for each event
      real part1[lts,Neq]; // Temporary object that will include the solution of the ODE for each event at each loop
      // Solve the ODEs for this event. For every event after the first one, the integration starts just before the first time point of the event (itp - 1e-7), so that it overlaps with the last point of the previous event and reuses its state values
      if (q == 1){
        ivst = Y0[,j];
        part1 = integrate_ode_bdf(Toogle_one,ivst,itp,ts[(evnT[q,j]+1):(evnT[q+1,j]+1),j],theta,to_array_1d(inputs[i:(i+1),j]), x_i, 1e-9, 1e-9, 1e7);
      }
      else{
        part1 = integrate_ode_bdf(Toogle_one, ivst,(itp-1e-7),ts[(evnT[q,j]+1):(evnT[q+1,j]+1),j],theta,to_array_1d(inputs[i:(i+1),j]), x_i, 1e-9, 1e-9, 1e7);
      }

      // Modification of the initial state values for the next event
      ivst = part1[lts];
      // Increase index for inputs
      i=i+2;
      
      // Introduction of the result of part1 into the object y_hat
      for (y in (itp+1):(itp+lts)){        
        y_hat[(y),]=(part1)[(y-itp),];        
      };      
    };

    // Likelihood at each sampling time
    for (t in 1:stsl[1,j]){
      RFPmean[t,j] ~ normal(y_hat[(sts[t,j]+1),3],RFPstd[t,j]);
      GFPmean[t,j] ~ normal(y_hat[(sts[t,j]+1),4],GFPstd[t,j]);      
    }  
  };
}

Right now I am working with pseudo-data generated from a known set of parameters to validate the model and see whether it recovers the “true parameter values” during inference before scaling up to experimental data, but I am having issues with the validation. In every inference, one or two chains get stuck at the initial value and do not mix with the rest, as seen in the image:



There are usually no divergent iterations, and when there are some they are in the chain that behaves poorly. However, I have noticed that for the chains that mix, almost all iterations hit the max_treedepth, but the flat chain does not. That made me think that the problem might come from the step size adapted by NUTS, which is really low and constant over all the iterations, but I do not understand why, or whether there is a problem during the adaptation process due to my model. I also thought that it might be a problem with the acceptance rate, but that does not seem to be the case, since modifying it has not improved the results, and neither has changing the initial step size. I do not think that it is a case of multimodality either, since parameter samples drawn from the flat chain do not reproduce the pseudo-data. As for the priors used, these are on a centred parameterisation to ease the sampling for NUTS. The parameter ranges, however, cover 0.1 to 10 times the “true” parameter value, which I hope is not the cause of the problem, since once working with real data this is a reasonable range for them (except for 4 parameters that cover a range from 0 to 5).
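One way to put a number on "one chain is stuck" is to compare within-chain and between-chain variance. The sketch below is in plain Python rather than R so that it is self-contained; the draws are synthetic, the function names `rhat` and `chain_sds` are made up for illustration, and the formula is the classic (non-split) Gelman-Rubin statistic, which is simpler than the split-chain version Stan reports.

```python
import random
import statistics


def rhat(chains):
    """Classic (non-split) Gelman-Rubin R-hat for a single parameter.

    `chains` is a list of equal-length lists of post-warmup draws.
    """
    n = len(chains[0])
    chain_means = [statistics.fmean(c) for c in chains]
    # W: average within-chain variance
    within = statistics.fmean([statistics.variance(c) for c in chains])
    # B: between-chain variance, scaled by the number of draws
    between = n * statistics.variance(chain_means)
    var_plus = (n - 1) / n * within + between / n
    return (var_plus / within) ** 0.5


def chain_sds(chains):
    """Per-chain standard deviation; a near-zero value flags a flat chain."""
    return [statistics.stdev(c) for c in chains]


rng = random.Random(42)

# Synthetic example: four well-mixed chains ...
good = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(4)]
# ... versus three mixed chains plus one that barely moves from its start.
stuck = good[:3] + [[3.0 + rng.gauss(0.0, 1e-3) for _ in range(1000)]]

print("R-hat, mixed chains:", rhat(good))
print("R-hat, one stuck chain:", rhat(stuck))
print("per-chain sd:", chain_sds(stuck))
```

A per-chain standard deviation that is orders of magnitude smaller than the others, together with an R-hat well above 1, matches the picture of a chain that never leaves its initial value.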

At the moment I do not know what else it could be, given my knowledge of Stan (I am quite new to all this), so if anyone has experienced something similar or has any idea I would really appreciate it.
Thanks,
David

Is the exp in there to make the parameter positive? This is gonna make your parameters tough to interpret.

The lingo might be a bit weird to get used to. Centered is:

x \sim \mathcal{N}(\mu, \sigma)

Non-centered is:

z \sim \mathcal{N}(0, 1)
x = z * \sigma + \mu

So I think you’re doing non-centered parameterizations. I’d stay away from them unless you have divergences and find out you need them. They can make the code hard to read.
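To see that the two forms describe the same distribution, here is a small Python sketch (standard library only). The values `mu = 2.5` and `sigma = 1.25` are taken from the Hill-coefficient lines of the model above; the function names are just for illustration.

```python
import random
import statistics


def draw_centered(rng, mu, sigma):
    # Centered: sample x directly from N(mu, sigma)
    return rng.gauss(mu, sigma)


def draw_non_centered(rng, mu, sigma):
    # Non-centered: sample z from N(0, 1), then shift and scale
    z = rng.gauss(0.0, 1.0)
    return z * sigma + mu


rng = random.Random(1)
mu, sigma = 2.5, 1.25  # as in theta[9] = n_aTc_raw * 1.25 + 2.5

centered = [draw_centered(rng, mu, sigma) for _ in range(20000)]
non_centered = [draw_non_centered(rng, mu, sigma) for _ in range(20000)]

for name, draws in (("centered", centered), ("non-centered", non_centered)):
    print(name, statistics.fmean(draws), statistics.stdev(draws))
```

Both sets of draws have the same mean and standard deviation up to Monte Carlo error. The difference only matters to the sampler, which explores `z` instead of `x` in the non-centered version.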

I think you’re doing the right thing. Run prior predictives to make sure your model isn’t making insane predictions. Run generated data fits till you get a model that is as reliable as possible. The routine is to tighten your priors/constraints until the solution of the ODE doesn’t go off into crazy-land, and you’re getting reliable fits.

Can you take values from the stuck chain and run them in an external ODE solver to see what is happening?
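A minimal sketch of such an external check, in plain Python with a fixed-step RK4 integrator (an assumption made for simplicity; the Stan model uses the BDF solver). The right-hand side is a line-by-line port of `Toogle_one`. The parameter vector below is just the set of prior centres implied by the `transformed parameters` block, and the inputs and initial state are made-up illustrative values; replace them with a draw from the stuck chain and the actual experimental inputs.

```python
import math


def toggle_rhs(t, y, p, u):
    """Port of the Stan function Toogle_one (0-based indices here)."""
    (k_in_IPTG, k_out_IPTG, k_in_aTc, k_out_aTc,
     k_L_pm0, k_L_pm, theta_T, theta_aTc, n_aTc, n_T,
     k_T_pm0, k_T_pm, theta_L, theta_IPTG, n_IPTG, n_L) = p
    u_IPTG, u_aTc = u
    d = [0.0] * 4
    # Inducer exchange: different rate depending on the direction of the flow
    d[0] = (k_in_IPTG if u_IPTG > y[0] else k_out_IPTG) * (u_IPTG - y[0])
    d[1] = (k_in_aTc if u_aTc > y[1] else k_out_aTc) * (u_aTc - y[1])
    # LacI+RFP and TetR+GFP: production minus first-order degradation
    d[2] = (1 / 0.1386) * (k_L_pm0 + k_L_pm / (1 + (y[3] / theta_T * 1 / (1 + (y[1] / theta_aTc) ** n_aTc)) ** n_T)) - 0.0165 * y[2]
    d[3] = (1 / 0.1386) * (k_T_pm0 + k_T_pm / (1 + (y[2] / theta_L * 1 / (1 + (y[0] / theta_IPTG) ** n_IPTG)) ** n_L)) - 0.0165 * y[3]
    return d


def rk4(f, y0, t0, t1, n_steps, p, u):
    """Classic fixed-step fourth-order Runge-Kutta; returns the final state."""
    h = (t1 - t0) / n_steps
    t, y = t0, list(y0)
    for _ in range(n_steps):
        k1 = f(t, y, p, u)
        k2 = f(t + h / 2, [yi + h / 2 * k for yi, k in zip(y, k1)], p, u)
        k3 = f(t + h / 2, [yi + h / 2 * k for yi, k in zip(y, k2)], p, u)
        k4 = f(t + h, [yi + h * k for yi, k in zip(y, k3)], p, u)
        y = [yi + h / 6 * (a + 2 * b + 2 * c + e)
             for yi, a, b, c, e in zip(y, k1, k2, k3, k4)]
        t += h
    return y


# Prior centres from the transformed parameters block (exp of the offsets);
# swap in one draw of theta[1..16] from the stuck chain here.
prior_centres = [0.04, 0.04, 0.1, 0.1, 0.03, 10.0, 30.0, 10.0,
                 2.5, 2.5, 0.1, 1.0, 30.0, 0.1, 2.5, 2.5]

u = (1.0, 0.0)                # hypothetical IPTG and aTc input levels
y0 = [1.0, 0.0, 100.0, 100.0]  # hypothetical initial state

# 24 h "overnight" integration in minutes, step of 0.1 min
y_end = rk4(toggle_rhs, y0, 0.0, 1440.0, 14400, prior_centres, u)
print("state after 24 h:", y_end)
```

Comparing the trajectory from a stuck-chain draw against one from a well-mixed chain should show whether the stuck chain sits in a region where the ODE solution is wildly off or numerically awkward.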

The exp is in there because I am using a logarithmic range around the “true” parameter values (from 0.1 to 10 times the value), so that the transformed value does not get too close to the left boundary, since I observed some problems with that.
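For reference, the constants in the `transformed parameters` block can be unpacked as follows. This is a small Python check, not part of the model: 1.15129254649702 is log(10)/2 and each offset is the log of the central value, so a raw value of ±2 (two prior standard deviations) lands on 0.1 and 10 times the centre. The helper name `to_natural` is made up for illustration.

```python
import math

# Scale used for all log-scale parameters in the Stan code
LOG_SCALE = math.log(10) / 2  # approximately 1.15129254649702


def to_natural(raw, centre):
    """Map a standard-normal raw value to the natural scale."""
    return math.exp(raw * LOG_SCALE + math.log(centre))


centre = 0.04  # exp(-3.2188758248682), the centre used for k_in_IPTG
for raw in (-2, 0, 2):
    print(raw, to_natural(raw, centre))
```

The four Hill coefficients do not use this transform. They are linear, `raw * 1.25 + 2.5`, so ±2 on the raw scale covers 0 to 5.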

And true, I meant to say non-centred, sorry, my bad. And indeed, I had to introduce a non-centred parameterisation because otherwise I had problems with divergences (always in more than 50% of the iterations).

I ran values from the stuck chains through an external ODE solver, and the results depend on the chain, but mostly they give ODE solutions that do not represent the data, either in value/level, in the speed of the increases/decreases, or both. So far I have not seen a set of parameters from these chains that reproduces the data, or even comes close. That is why I thought that it wasn’t a multimodality issue (hopefully).

Thank you so much for your help Ben!

Hmm, it’d be nice to parameterize this on the natural scale so you don’t have all these transforms here. But if those non-centered parameterizations fix a lot of divergences then I guess it’s just tough potatoes.

You can ask Stan to try harder when it adapts. Give it a longer warmup, and set adapt_delta = 0.99 or something (this’ll be in the Rstan manual). It’s the thing to try if all else fails.

Yes, we started on the natural scale, but due to the high number of divergences I tried the non-centred parameterisations and things seemed to get much better.

I have tried setting adapt_delta to 0.9, 0.95, 0.99 and 0.999, but things did not seem to improve at all. I did not try increasing the warmup, though. I am going to try both together and have another look at the Rstan manual to see if there is any other adaptation setting that could be helpful to modify.

Thanks!