One chain not moving in a nonlinear mixed-effects model (initial values issue)

Hello,

I’m trying to model individual viral dynamics using a simple hierarchical model (piecewise linear). What’s particular about my data is that the time of symptom onset is observed, and I want to use this information to determine the start of infection by estimating the incubation period (Tincub). The virus is quantified through a proxy, the number of PCR cycles needed for detection (cycle threshold, Ct). The lower the Ct, the fewer RNA amplification cycles are needed to detect and quantify the virus, and therefore the more virus-rich the sample. To simplify the modelling, I work with the delta relative to the viral quantity at infection (Ct = 50).
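To make the shape explicit, the piecewise-linear function implemented in CtVinffun below is

$$
\Delta\mathrm{Ct}(t) =
\begin{cases}
0 & t \le t_{\mathrm{inf}} \\
V_p \,(t - t_{\mathrm{inf}})/T_p & t_{\mathrm{inf}} < t \le t_p \\
V_p + (V_{\mathrm{inf}} - \mathrm{lod} - V_p)\,(t - t_p)/T_c & t > t_p
\end{cases}
$$

with $t_p = t_{\mathrm{inf}} + T_p$.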

Before running this model on my real data, I simulated data very close to my real data using the model presented below. However, I constantly run into a problem when fitting this dataset: for most sets of initial values, some chains have enormous difficulty moving and remain flat, taking 20x longer to run than the chains that move and explore the space well. I don’t understand this instability, given that the initial values of my 4 chains are randomly sampled from my prior distributions.

Have you ever encountered a similar problem, and do you know where it could be coming from?

Here’s the Stan model. I’ve tried to comment it well, but don’t hesitate to ask for clarification!

Thank you in advance,

Maxime

functions {
  // CtVinffun returns dCt, the delta Ct below Vinf:
  real CtVinffun(real t, real tinf, int lod, int Vinf, real tp, real Tp, real Tc, real Vp){

    // Viral load rises before the peak of viral load:
    if (t >= tinf && t <= tp)
      return(Vp*(t-tinf)/Tp); // Tp = tp-tinf

    // Viral load falls after the peak of viral load:
    else if (t > tp)
      return(Vp+(Vinf-lod-Vp)*(t-tp)/Tc);

    // No viral load before infection (t < tinf):
    else
      return(0);
  }
}


data{
  int<lower=0> N; // Number of concatenated data points
  int<lower=0> n_id; // Number of individuals 
  int lod; // Limit of detection of Ct
  int Vinf; // Ct delta from lod at infection (Ct=50) 
  int<lower=0> id[N]; // Vector marking which datum belongs to which id
  real t[N];  // Vector marking the time for each data point 
  real<lower=10, upper=Vinf> y[N];  // Concatenated data vector 
  real tSS[n_id]; // Vector of individual time of symptom onset 
  int nb_random; // Total number of parameters with a random effect
}

transformed data {
  
  real log_is_lod[N]; // log of the indicator "is the datum at the LOD value?" (log(0) or log(1))
  real log_is_obs[N]; // log of the indicator "is the datum a quantified (non-LOD) observation?" (log(0) or log(1))
  
  for(i in 1:N){
    if(y[i]==10){ // datum at the LOD (hard-coded as y == 10, presumably Vinf - lod)
      log_is_obs[i]=log(0);
      log_is_lod[i]=log(1);
    } 
    else{ 
      log_is_obs[i]=log(1);
      log_is_lod[i]=log(0);
    }
  }
}


parameters{
  
  // parameters of Vp 
  real<lower=10, upper=50> mu_Vp;
  
  // parameters of proliferation phase
  real<lower=0> mu_Tp;
  
  // parameters of clearance phase
  real<lower=0> mu_Tc;
  
  // parameters incubation period
  real<lower=0, upper=14> mu_Tincub;
  
  real<lower=0> sigma;
  
  matrix[n_id,nb_random] eta_tilde; // random effect for Vp, Tp, Tc, and Tincub
  
  real<lower=0> eta_omega[nb_random]; // standard deviations of the random effects
}


transformed parameters{
  
  real<lower=0, upper=14> Tincub[n_id]; // Incubation period cannot exceed 14 days
  // (Note: if a draw of a transformed parameter violates its declared bounds, here or
  // for Vp and Ct below, the sampler rejects the proposal.)
  
  matrix[n_id,nb_random] eta; // matrix of random effects for Vp, Tp, Tc, and Tincub
  
  real<lower=10, upper=50> Vp[n_id];
  real<lower=0> Tp[n_id];
  real<lower=0> Tc[n_id];
  real tp[n_id];
  real tinf[n_id];
  
  real<upper=50> Ct[N]; // model-predicted delta Ct for each observation
  
  real num_arg[N,2]; // per-observation log-likelihood components: [1] observed value, [2] at the LOD
  
  // Non-centered parameterization: scale the standard-normal draws by their SD
  for(j in 1:nb_random){
    eta[,j] = eta_tilde[,j]*eta_omega[j];
  }
  
  for(i in 1:n_id){
    
    Vp[i] = exp(log(mu_Vp) + eta[i,1]);
    
    Tp[i] = exp(log(mu_Tp) + eta[i,2]);
    
    Tc[i] = exp(log(mu_Tc) + eta[i,3]);
    
    Tincub[i] = exp(log(mu_Tincub) + eta[i,4]);
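    
    // (Equivalently, each individual parameter is lognormal around its hierarchical
    // mean, e.g. Vp[i] ~ lognormal(log(mu_Vp), eta_omega[1]).)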
    
    tinf[i] = tSS[i] - Tincub[i];

    tp[i] = tinf[i] + Tp[i];
    
  }
  
  for(i in 1:N){
    
    Ct[i] = CtVinffun(t[i], tinf[id[i]], lod, Vinf, tp[id[i]], Tp[id[i]], Tc[id[i]], Vp[id[i]]);
    
    if (t[i] < tinf[id[i]]){
      num_arg[i,1] = log_is_obs[i] + log(0.0002); // likelihood of an observed Ct value before infection (false-positive PCR test)
      num_arg[i,2] = log_is_lod[i] + log(0.9998); // likelihood of a Ct value at the LOD before infection (true-negative PCR test)
    } else{
      num_arg[i,1] = log_is_obs[i] + normal_lpdf(y[i] | Ct[i], sigma); // likelihood of an observed Ct value (positive PCR test)
      num_arg[i,2] = log_is_lod[i] + normal_lcdf(10 | Ct[i], sigma); // likelihood of a Ct value at the LOD (negative PCR test)
    }
    // If the observation is not at the LOD value, num_arg[i,2] equals -inf and drops out
    // of the likelihood target += log_sum_exp(num_arg[i]),
    // because log(exp(log(0)) + exp(num_arg[i,1])) = num_arg[i,1].
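    // (A sketch of an equivalent, more direct formulation would branch on the data
    // instead of adding log(0)/log(1) weights, e.g. inside the model block:
    //   if (y[i] == 10) target += normal_lcdf(10 | Ct[i], sigma);
    //   else            target += normal_lpdf(y[i] | Ct[i], sigma);
    // with the pre-infection branch handled analogously.)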
  }
}


model{
  // Priors //
  
  mu_Vp ~ normal(25, 5) T[10,50]; // hierarchical mean (mu)
  mu_Tp ~ normal(6, 2) T[0,]; // hierarchical mean (mu)
  mu_Tc ~ normal(15, 5) T[0,]; // hierarchical mean (mu)
  mu_Tincub ~ normal(5, 1); // hierarchical mean (mu)
  
  sigma ~ normal(0, 1) T[0,]; // measurement error
  
  to_vector(eta_tilde) ~ normal(0,1);
  
  eta_omega ~ normal(0,1) T[0,]; // standard deviation of the random effects
  
  // Likelihood (looped on each observation) //
  for(i in 1:N){
    target += log_sum_exp(num_arg[i]);
  }
}

generated quantities{
  
  vector[N] log_lik;
  real tclear[n_id]; // individual time at which the viral-load decline reaches the LOD
  
  for (i in 1:N){
    log_lik[i] = log_sum_exp(num_arg[i]);
  }
  
  for (j in 1:n_id){
    tclear[j] = tp[j] + Tc[j];
  }
}

Hello,

I don’t have a ready-made solution, but I have two points of inquiry:

  1. You seem to be saying that the problem comes from the initial values, which are different for each of your chains. If you take the initial values that led to a flat chain and impose them on all 4 chains in a new run, do you end up with 4 flat chains?

  2. Have you tried loosening your priors? Your residual error has a pretty restrictive prior (sigma ~ normal(0, 1) T[0,]; // measurement error). Maybe that’s appropriate for your data, but a wider prior on the residual error might help convergence; see the sketch after this list.
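For instance, a minimal sketch (the scale of 5 is purely illustrative):

  sigma ~ normal(0, 5) T[0,]; // wider half-normal prior on the measurement error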

And out of curiosity, what do the “T[0,];” or “T[10,50];” after your priors mean?

Those are truncated distributions. But I believe they are redundant here, since Stan already applies the necessary transformations based on the bounds declared in the parameters block.
For example, since sigma is declared strictly positive with lower=0, Stan already restricts the normal(0, 1) prior to positive values (and automatically applies the Jacobian adjustment for the transform induced by that constraint). Because the truncation point is fixed, T[0,] only adds a constant to the target, so it is redundant.
You only really need an explicit truncation if it differs from the constraints declared in the parameters block.
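A minimal sketch of the two equivalent ways to write the sigma prior:

parameters {
  real<lower=0> sigma; // the declared bound already constrains sigma to be positive
}
model {
  sigma ~ normal(0, 1); // half-normal prior, induced by the lower=0 bound
  // sigma ~ normal(0, 1) T[0,]; // same posterior: the truncation only adds a constant here
}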