One chain not moving in a nonlinear mixed-effects model (initial values issue)

Hello,

I’m trying to model individual viral dynamics using a simple hierarchical model (piecewise linear). What’s interesting about my data is that the time of symptom onset is known, and we want to use this information to determine the start of infection by estimating the incubation period (Tincub). The virus is quantified through a proxy, the number of PCR cycles (cycle threshold, Ct). The lower the Ct, the fewer RNA amplification cycles are needed to detect and quantify the virus, and therefore the more virus-rich the sample. To simplify the modelling, I model the delta Ct relative to the quantity of virus at infection (equal to 50 Ct).

Before running this model on my real data, I simulated data very close to my real data using the model presented below. However, I consistently run into a problem when fitting this dataset. For most sets of initial values, some chains have enormous difficulty moving and remain flat, taking 20x longer to run than the chains that move and explore the space well. I don’t understand this instability, given that the initial values of my 4 chains are randomly sampled from my prior distributions.

Have you ever encountered a similar problem, and do you know where it could be coming from?

Here’s the Stan model. I’ve tried to comment it well, but don’t hesitate to ask for clarification!

Thank you in advance,

Maxime

functions {
  // CtVinffun returns dCt, the delta Ct below Vinf:
  real CtVinffun(real t, real tinf, int lod, int Vinf, real tp, real Tp, real Tc, real Vp){

    // Viral load rises before the peak of viral load:
    if (t >= tinf && t <= tp)
      return(Vp*(t-tinf)/Tp); // Tp = tp-tinf

    // Viral load falls after the peak of viral load:
    else if (t > tp)
      return(Vp+(Vinf-lod-Vp)*(t-tp)/Tc);

    // No viral load before infection (t < tinf):
    else
      return(0);
  }
}


data{
  int<lower=0> N; // Number of concatenated data points
  int<lower=0> n_id; // Number of individuals 
  int lod; // Limit of detection of Ct
  int Vinf; // Ct delta from lod at infection (Ct=50) 
  int<lower=0> id[N]; // Vector marking which datum belongs to which id
  real t[N];  // Vector marking the time for each data point 
  real<lower=10, upper=Vinf> y[N];  // Concatenated data vector 
  real tSS[n_id]; // Vector of individual time of symptom onset 
  int nb_random; // Total number of parameters with a random effect
}

transformed data {
  
  real log_is_lod[N]; // log of indicator: is the observation at the LOD value? (0 or 1, on the log scale)
  real log_is_obs[N]; // log of indicator: is the observation below the LOD value? (0 or 1, on the log scale)
  
  for(i in 1:N){
    if(y[i]==10){
      log_is_obs[i]=log(0);
      log_is_lod[i]=log(1);
    } 
    else{ 
      log_is_obs[i]=log(1);
      log_is_lod[i]=log(0);
    }
  }
}


parameters{
  
  // parameters of Vp 
  real<lower=10, upper=50> mu_Vp;
  
  // parameters of proliferation phase
  real<lower=0> mu_Tp;
  
  // parameters of clearance phase
  real<lower=0> mu_Tc;
  
  // parameters incubation period
  real<lower=0, upper=14> mu_Tincub;
  
  real<lower=0> sigma;
  
  matrix[n_id,nb_random] eta_tilde; // random effect for Vp, Tp, Tc, and Tincub
  
  real<lower=0> eta_omega[nb_random];
}


transformed parameters{
  
  real<lower=0, upper=14> Tincub[n_id]; // Incubation period cannot exceed 14 days
  
  matrix[n_id,nb_random] eta; // matrix of random effects for Vp, Tp, Tc, and Tincub
  
  real<lower=10, upper=50> Vp[n_id];
  real<lower=0> Tp[n_id];
  real<lower=0> Tc[n_id];
  real tp[n_id];
  real tinf[n_id];
  
  real<upper=50> Ct[N];
  
  real num_arg[N,2];
  
  for(j in 1:nb_random){
    eta[,j] = eta_tilde[,j]*eta_omega[j];
  }
  
  for(i in 1:n_id){
    
    Vp[i] = exp(log(mu_Vp) + eta[i,1]);
    
    Tp[i] = exp(log(mu_Tp) + eta[i,2]);
    
    Tc[i] = exp(log(mu_Tc) + eta[i,3]);
    
    Tincub[i] = exp(log(mu_Tincub) + eta[i,4]);
    
    tinf[i] = tSS[i] - Tincub[i];

    tp[i] = tinf[i] + Tp[i];
    
  }
  
  for(i in 1:N){
    
    Ct[i] = CtVinffun(t[i], tinf[id[i]], lod, Vinf, tp[id[i]], Tp[id[i]], Tc[id[i]], Vp[id[i]]);
    
    if (t[i] < tinf[id[i]]){
      num_arg[i,1] = log_is_obs[i] + log(0.0002); // likelihood that Ctobs value < LOD (false positive PCR test)
      num_arg[i,2] = log_is_lod[i] + log(0.9998); // likelihood that Ctobs value == LOD (true negative PCR test)
    } else{
    num_arg[i,1] = log_is_obs[i] + normal_lpdf(y[i] | Ct[i], sigma); // likelihood that Ctobs value < LOD (positive PCR test)
    num_arg[i,2] = log_is_lod[i] + normal_lcdf(10 | Ct[i], sigma); // likelihood that Ctobs value == LOD (negative PCR test)
      
    }
    // If the observation is not at the LOD value, num_arg[i,2] equals -inf and does not contribute to the likelihood term target += log_sum_exp(num_arg[i]),
    // because log(exp(log(0)) + exp(num_arg[i,1])) = num_arg[i,1].
  }
}


model{
  // Priors //
  
  mu_Vp ~ normal(25, 5) T[10,50]; // hierarchical mean (mu)
  mu_Tp ~ normal(6, 2) T[0,]; // hierarchical mean (mu)
  mu_Tc ~ normal(15, 5) T[0,]; // hierarchical mean (mu)
  mu_Tincub ~ normal(5, 1); // hierarchical mean (mu)
  
  sigma ~ normal(0, 1) T[0,]; // measurement error
  
  to_vector(eta_tilde) ~ normal(0,1);
  
  eta_omega ~ normal(0,1) T[0,]; // standard deviation of the random effects
  
  // Likelihood (looped on each observation) //
  for(i in 1:N){
    target += log_sum_exp(num_arg[i]);
  }
}

generated quantities{
  
  vector[N] log_lik;
  
  for (i in 1:N){
    log_lik[i] = log_sum_exp(num_arg[i]);
  }
  
  real tclear[n_id];
  
  for (j in 1:n_id){
    
    tclear[j] = tp[j] + Tc[j];
  }
}

Hello,

I don’t have a ready-made solution, but I have two points of inquiry:

  1. You seem to be saying that the problem comes from the initial values, which are different for each of your chains. If you take the initial values that led to a flat chain and impose them on a new run for all 4 chains, do you end up with 4 flat chains? (See the sketch after this list for one way to do this.)

  2. Have you tried loosening your priors? Your residual error has a pretty restrictive prior (sigma ~ normal(0, 1) T[0,]; // measurement error). Maybe that’s appropriate for your values, but it might help convergence to use a wider prior for your residual error.
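
On point 1, a minimal cmdstanr sketch of imposing the same suspect initial values on all chains (mod, stan_data, and the numbers in bad_init are placeholders, not values from your run):

library(cmdstanr)

# Initial values copied from the chain that stayed flat (illustrative numbers only)
bad_init <- list(
  mu_Vp = 12, mu_Tp = 0.5, mu_Tc = 40, mu_Tincub = 13,
  sigma = 0.1,
  eta_omega = rep(2, stan_data$nb_random),
  eta_tilde = matrix(0, nrow = stan_data$n_id, ncol = stan_data$nb_random)
)

# Same init list for every chain: if all 4 chains then stay flat,
# the problem really is tied to that region of parameter space.
fit <- mod$sample(
  data = stan_data,
  chains = 4,
  init = rep(list(bad_init), 4)
)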

And out of curiosity, what do your “T[0,];” or “T[10,50];” mean after your priors?

Those are truncated distributions. But I believe these are redundant since Stan already applies the necessary transformations based on the parameters’ declared bounds in the parameters block.
So for example, since sigma is correctly declared to be strictly positive with lower=0, Stan already truncates the normal(0, 1) prior at zero (and automatically applies a Jacobian adjustment for the transformation induced by this constraint), so the T[0, ] should be redundant.
You only really need the explicit truncation if you’d like to apply truncation that differs from the constraints in the parameters block.


Not sure if this will solve your problems, but your model might benefit from reparameterising the population-level means, especially mu_Vp and mu_Tincub which have quite specific bounds. Instead of directly estimating the population-level means in their constrained space, you could first estimate an unconstrained parameter and then transform that to the desired bounds in the transformed parameters block, along the following lines:

parameters {
  real mu_Vp_pr;
  real mu_Tp_pr;
  real mu_Tc_pr;
  real mu_Tincub_pr;
}
transformed parameters {
  real mu_Vp = 10 + (50 - 10) * inv_logit(mu_Vp_pr);
  real mu_Tp = exp(mu_Tp_pr);
  real mu_Tc = exp(mu_Tc_pr);
  real mu_Tincub = 14 * inv_logit(mu_Tincub_pr);
}
model {
  mu_Vp_pr ~ std_normal();
  mu_Tp_pr ~ std_normal();
  mu_Tc_pr ~ std_normal();
  mu_Tincub_pr ~ std_normal();
}

The transformations that I’ve used are explained in the Stan reference manual’s chapter on constraint transforms, both for lower-bounded parameters and for lower- and upper-bounded parameters.

The main downside of this approach is that the priors are defined in the unconstrained space, which is less intuitive. Here I’ve just used the standard normal distribution as a generic prior as it’s easy to sample from. But if the exact shape of the prior is important, you’d need to carefully examine what parameter values are implied by the prior following the transformation.
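
If it helps, one quick way to see what those std_normal() priors imply after the transformation is to push prior draws through the same transforms in R (plogis() is R's inverse logit; the names just mirror the snippet above):

# Prior pushforward check: what do std_normal() priors on the unconstrained
# parameters imply for mu_Vp and mu_Tincub after the transforms above?
n_draws <- 1e5
mu_Vp_pr     <- rnorm(n_draws)
mu_Tincub_pr <- rnorm(n_draws)

mu_Vp     <- 10 + (50 - 10) * plogis(mu_Vp_pr)   # plogis() = inverse logit
mu_Tincub <- 14 * plogis(mu_Tincub_pr)

quantile(mu_Vp, c(0.025, 0.5, 0.975))
quantile(mu_Tincub, c(0.025, 0.5, 0.975))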


Hello,

Thank you for your reply. This transformation does seem to free me from the strong constraints on my parameters.
However, I still have one chain that diverges from the other 3 and slows the run down considerably (20x longer), even when all 4 chains start from the simulation values. I don’t think this is due to a misspecification of the model, because the other 3 chains very quickly explore a space around the true (simulation) values.
Do you have any further clues?

Thanks again,
Maxime

Nothing springs to mind I’m afraid - the context of your project is very different from my area of research. Hopefully others can pitch in as well!

Thank you anyway for your advice and your time!

Maxime

This almost always arises from varying curvature in the posterior. That is, one region of the posterior will have very high curvature and require very small steps, while another region will have very low curvature and allow bigger steps. If you don’t set a small enough step size, when the sampler gets into the region requiring small steps, it will get stuck. The long runs that don’t get stuck are probably biased due to not exploring the high-curvature region.

This can also happen when initializing too far out in the tails, where nothing works well. You might try initializing with init = 0.5 instead of the default init = 2, which will make the initialization uniform on (-0.5, 0.5) rather than (-2, 2) on the unconstrained scale.
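
For example, with cmdstanr (mod and stan_data are placeholder names for your compiled model and data list):

# Tighter random inits on the unconstrained scale:
# init = 0.5 draws initial values uniformly from (-0.5, 0.5) instead of (-2, 2).
fit <- mod$sample(data = stan_data, chains = 4, init = 0.5)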

The other way this can happen is if your transformed parameters violate constraints, such as assigning a value to Vp that isn’t between 10 and 50.

The best thing to do in these situations is reparameterize the model to reduce the curvature.

That’s right in this case because the distributions being truncated have constant arguments like normal(0,1) T[0,]. In that case, the truncation contributes a constant term and can be dropped in the Stan code.

Whenever you see exp(log(a) + b), it’s probably better to replace it with a * exp(b); the exception is when exp(b) overflows or underflows but exp(log(a) + b) does not.

Is the + log(2e-4) for numerical stability? If so, I’d recommend removing that. If you need to do this kind of thing to stabilize a model, there are usually larger problems.

Doing straight comparisons of floating point numbers like y[i] == 10 is risky because of the way they’re represented internally. Usually you want to test whether the value is within epsilon of 10. Ideally, you’d include a boolean flag so as not to have to compare floating-point numbers.

Unless t, tinf, and tp are constants in the call to CtVinffun, you run into the problem of introducing discontinuities, which can also cause problems for HMC sampling.

Thank you Bob for all your remarks.

I don’t understand how the posterior geometry can be so complex; this model is just a ‘simple’ piecewise linear mixed-effects model.
How do you check the curvature of the posterior, by plotting the log-likelihood of each sample?

I initialised my chains by sampling from my priors, and as you can see from the traces, the 4 chains (even the divergent ones) are in very close proximity to the true values (where the chains that mix well gravitate).


I don’t get any messages telling me that certain parameters or transformed parameters are violating their constraints (which I had before but which I managed to resolve).

Thanks for the advice, I’m going to clean up my code if it doesn’t have an impact on inference.

We introduced this part because, when there were several negative tests at the very start of the predicted viral kinetics, Stan was unable to correctly determine the time of onset of infection tinf. With this formulation, the likelihood explicitly accounts for observations that fall before the estimated time of infection: a quantifiable Ct there is treated as a false positive, and a value at the LOD as a true negative.

This is the same as checking whether a variable called ‘censor’ is equal to 1 or 0, because if the observation is not quantifiable/detectable it is hard coded as being equal to the censor value (in this case 10).

It is true that piecewise linear models give rise to discontinuity problems. But it seems to me that the use of log_sum_exp() gets around this problem, am I wrong?

Maxime

An approach I would suggest is to use Pathfinder to generate initialization values. With cmdstanr, you first run a Pathfinder fit and then use the resulting object as the init argument for sampling.
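
Roughly along these lines (a sketch assuming a compiled model mod and data list stan_data; recent cmdstanr versions accept the Pathfinder fit object directly as init, otherwise you can build init lists from its draws yourself):

library(cmdstanr)

# Quick Pathfinder run to get approximate posterior draws
pf <- mod$pathfinder(data = stan_data, seed = 123)

# Use the Pathfinder fit to initialize the MCMC chains
fit <- mod$sample(data = stan_data, chains = 4, init = pf, seed = 123)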

Those are the worst. The random effects prior \alpha_i \sim \text{normal}(0, \sigma) introduces seriously varying curvature—as \sigma \rightarrow 0 it drives all the \alpha_i to low scale and vice versa as \sigma \rightarrow \infty. In the limiting case of a hierarchical model with no data, Stan (HMC, NUTS) can’t even fit this with a fixed step size.

The easiest thing to do is to use BridgeStan to calculate Hessians. From those, you can assess both positive-definiteness everywhere and, where positive definite, the condition number (the ratio of the largest to the smallest eigenvalue). When the condition number is bad, sampling with HMC is hard.
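
A rough sketch of that check, assuming the BridgeStan R interface (StanModel$new() and log_density_hessian(); the exact constructor arguments and return-value names may differ by version, so check the BridgeStan docs):

library(bridgestan)

# Build/load the model together with its data (file paths are placeholders)
model <- StanModel$new("model.stan", data = "data.json", seed = 1)

# Hessian of the log density at one point on the unconstrained scale
theta_unc <- rnorm(model$param_unc_num())
H <- model$log_density_hessian(theta_unc)$hessian  # assumed return-value name

# Eigenvalues of -H: all positive means the log density is concave (Hessian
# negative definite) at this point; the condition number is the ratio of the
# largest to the smallest curvature.
ev <- eigen(-H, symmetric = TRUE)$values
all(ev > 0)
max(ev) / min(ev)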

I’m suggesting that going forward you initialize with draws from chains that mixed. This may lead to bias, but you’re not seeing chains start mixing and then lock up. You might want to plot lp__ to see whether the stuck chains are stuck out in the tails.
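
For the lp__ check, something like this with cmdstanr and bayesplot (fit being the CmdStanMCMC object from your run):

library(bayesplot)

# Trace of lp__ per chain: a chain sitting at a much lower lp__ than the
# others is probably stuck out in the tails.
mcmc_trace(fit$draws("lp__"))

# Per-chain mean of lp__ for a quick numeric comparison
apply(fit$draws("lp__"), 2, mean)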

I think I buried my point. The issue is that comparing floating point numbers to integers is a perilous operation in Stan or any other programming language because floating point doesn’t perfectly capture integers. It’s probably OK if you are only rounding a floating point to an integer in the same way every time, but even that can be problematic. Here’s an example I found online (you can look up why you don’t compare floating point numbers); this example uses R, but it behaves the same way in C++ or Python:

> 0.3 * 3 + 0.1 == 1
[1] FALSE

So I wasn’t suggesting changing the model, just supplying an integer you could test against, which should be much more reliable (what you have may be OK, but you really don’t want to rely on integer vs. floating point comparisons).
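
Concretely, one way to do that is to build the flag outside Stan and pass it as data (a sketch; raw and negative_test are hypothetical names for your original records and their censoring indicator):

# The raw records already say which tests were negative/censored, so pass that
# as an integer flag rather than re-deriving it from y == 10 inside Stan.
stan_data$is_lod <- as.integer(raw$negative_test)

# In the Stan data block you would then declare
#   int<lower=0, upper=1> is_lod[N];
# and branch on is_lod[i] == 1 instead of comparing y[i] to 10.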

I’m not sure how log-sum-exp could fix a discontinuity problem. Log-sum-exp is just the operation for adding values when they’re given on the log scale. That is,

log_sum_exp(log p1, log p2) = log(p1 + p2).