Convergence issues with model for predicting survival counts based on increasing treatments

Hi everyone,

I am developing a model for counts of sampled individuals that survive a pesticide application at increasing doses. The model also accounts for differing survival rates to the pesticide based on three genotypes (AA, Aa, aa) that occur in different proportions in the population. I’ve attached the R script and data, and included the model below:

data {
  int<lower=0> n_obs;
  // vector[n_obs] dose;
  // vector[n_obs] n_alive;
  // vector[n_obs] n_tested;
  real dose[n_obs];
  int n_alive[n_obs];
  int n_tested[n_obs];
}

parameters {
  // background survival rate
  real<lower=0> S_0;
  // population parameter
  real p_pop;
  // genotype-specific mortality rates (per capita, per dose)
  vector<lower=0>[n_obs] r_AA;
  vector<lower=0>[n_obs] r_Aa;
  vector<lower=0>[n_obs] r_aa;
}

model {
  // proportion of population with genotype
  real P_AA_pop;
  real P_Aa_pop;
  real P_aa_pop;
  
  // survival probabilities based on genotype and dose
  vector[n_obs] S_AA;
  vector[n_obs] S_Aa;
  vector[n_obs] S_aa;
  
  // average survival probability of tested individuals
  vector[n_obs] S_bar;

  // priors
  S_0 ~ beta(0.5, 0.5);
  p_pop ~ beta(0.5, 0.5);
  log(r_AA) ~ cauchy(0, 1);
  log(r_Aa) ~ cauchy(0, 1);
  log(r_aa) ~ cauchy(0, 1);
  
  // target += -log(r_AA);
  // target += -log(r_Aa);
  // target += -log(r_aa);
  
  // proportion of population with genotype
  P_AA_pop = p_pop^2;
  P_Aa_pop = 2*p_pop*(1 - p_pop);
  P_aa_pop = (1 - p_pop)^2;
  
  // survival to exposure based on genotype and dose
  for (i in 1:n_obs){
    S_AA = S_0*exp(-r_AA*dose[i]);
    S_Aa = S_0*exp(-r_Aa*dose[i]);
    S_aa = S_0*exp(-r_aa*dose[i]);
  }
  
  // average survival probability of tested individuals
  S_bar = S_AA*P_AA_pop + S_Aa*P_Aa_pop + S_aa*P_aa_pop;

  // likelihood
  for (i in 1:n_obs){
    n_alive[i] ~ binomial(n_tested[i], S_bar);
  }
}

While I have gotten this model to run, I receive a large number of warnings about convergence issues that are not fixed by following the suggestions linked in the warnings. Suspecting that the issue may be due to the log-transformed r_AA parameters, I’ve also tried the suggestions linked here, but encountered the same warnings from Stan.

I am not a proficient/advanced Stan user, so maybe I am missing something. Any suggestions on how to reformulate the model for improved inference would be much appreciated. Thank you!

analysis.R (1.8 KB)
pop_data.csv (531 Bytes)

A couple of things I noticed with the model:

  1. S_bar is a probability but it is not constrained to be between 0 and 1. Are the warnings about the parameters of the binomial not being in the 0-1 interval? In addition, in the likelihood you probably mean to use the index i for S_bar, like so:

for (i in 1:n_obs){
    n_alive[i] ~ binomial(n_tested[i], S_bar[i]);
  }

  2. You commented out the Jacobian adjustments but I think you still need them (// target += -log(r_AA) and so on). Do you get warnings about the Jacobian adjustment?

  3. I also don’t believe in those cauchy priors for log(r_AA). I don’t know what the scale of dose is, but if dose[i] = 1, then based on the prior there is a 10% probability that r_AA > 19, which would make survival at that dose essentially zero (exp(-19) ≈ 6e-9); see the quick check below.
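For a quick check, you can simulate the implied prior on r in R:

r <- exp(rcauchy(1e6, location = 0, scale = 1))  # log(r) ~ cauchy(0, 1)
mean(r > 19)  # roughly 0.10
exp(-19)      # survival multiplier at dose = 1: about 5.6e-9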

1 Like

Thanks for your quick reply, @stijn. Based on your suggestions, I added the index to S_bar in the likelihood and uncommented the target += -log(r_AA); lines. This still leads to:

DIAGNOSTIC(S) FROM PARSER:
Info:
Left-hand side of sampling statement (~) may contain a non-linear transform of a parameter or local variable.
If it does, you need to include a target += statement with the log absolute determinant of the Jacobian of the transform.
Left-hand-side of sampling statement:
    stan::math::log(r_AA) ~ cauchy(...)
Info:
Left-hand side of sampling statement (~) may contain a non-linear transform of a parameter or local variable.
If it does, you need to include a target += statement with the log absolute determinant of the Jacobian of the transform.
Left-hand-side of sampling statement:
    stan::math::log(r_Aa) ~ cauchy(...)
Info:
Left-hand side of sampling statement (~) may contain a non-linear transform of a parameter or local variable.
If it does, you need to include a target += statement with the log absolute determinant of the Jacobian of the transform.
Left-hand-side of sampling statement:
    stan::math::log(r_aa) ~ cauchy(...)

hash mismatch so recompiling; make sure Stan code ends with a blank line

which is expected, but should be addressed by the changes I made. Once sampling begins, the chains throw errors similar to:

Chain 2: Rejecting initial value:
Chain 2:   Error evaluating the log probability at the initial value.
Chain 2: Exception: beta_lpdf: Random variable is 1.56586, but must be less than or equal to 1  (in 'modele5c2afeb123c_model' at line 38)

Once sampling is finished, the following warnings are returned:

Warning messages:
1: There were 2987 divergent transitions after warmup. See
https://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
to find out why this is a problem and how to eliminate them. 
2: There were 13 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
https://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded 
3: There were 2 chains where the estimated Bayesian Fraction of Missing Information was low. See
https://mc-stan.org/misc/warnings.html#bfmi-low 
4: Examine the pairs() plot to diagnose sampling problems
 
5: The largest R-hat is 3.04, indicating chains have not mixed.
Running the chains for more iterations may help. See
https://mc-stan.org/misc/warnings.html#r-hat 
6: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
Running the chains for more iterations may help. See
https://mc-stan.org/misc/warnings.html#bulk-ess 
7: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
Running the chains for more iterations may help. See
https://mc-stan.org/misc/warnings.html#tail-ess

I don’t think Stan “knows” that you fixed the problem by including the target += statements. The parser prints that Info message whenever the left-hand side of a ~ is a nonlinear transform, whether or not the Jacobian adjustment is present, so once you have added the adjustment you can ignore it.

I think the problem here is that S_0 and p_pop are not constrained to be between 0 and 1, and Stan is warning you that they need to be for the beta distribution. That would mean that you need

real<lower=0, upper=1> S_0;
real<lower=0, upper=1> p_pop;

I would not be surprised if all the other problems stem from the fact that S_bar is not contained between 0 and 1. Would it make sense to use an inverse logit transformation instead of an exponential transformation in the following part of the code?

  for (i in 1:n_obs){
    S_AA = S_0*exp(-r_AA*dose[i]);
    S_Aa = S_0*exp(-r_Aa*dose[i]);
    S_aa = S_0*exp(-r_aa*dose[i]);
  }
1 Like

I think the problem here is that S_0 and p_pop are not constrained to be between 0 and 1, and Stan is warning you that they need to be for the beta distribution. That would mean that you need

Thanks, I’ve added the lower/upper bounds to avoid this warning now!

I would not be surprised if all the other problems stem from the fact that S_bar is not contained between 0 and 1. Would it make sense to use an inverse logit transformation instead of an exponential transformation in the following part of the code?

Last night I tested the model with inv_logit instead of exp for the survival rates. It still produced ~500 divergent transitions while running Stan with

  m <- stan(file="model.stan", data=data.list, chains=2, iter=20000, warmup=500,
            thin=10, control=list("adapt_delta" = 0.99, "max_treedepth" = 15))

I’m pretty sure the sample size is overkill for HMC, but I think the model needs further optimization to converge more efficiently. I will try running the model next with vector<lower=0, upper=1>[n_obs] S_bar and see how it performs.

So, I think the priors are too wide and I would start with running some simulations to get a better idea of what the model is implying. This R code should give you some idea of what you can do with prior predictive checks. This is just for one observation with dose = 1 and n_tested = 100. These values might not be realistic.


generate_prior_observation <- function(dose = 1, n_tested = 100) {
  # draw one set of parameters from the priors
  S_0 <- rbeta(1, .5, .5)
  p_pop <- rbeta(1, .5, .5)
  logit_rAA <- rnorm(1, 0, 1)
  logit_rAa <- rnorm(1, 0, 1)
  logit_raa <- rnorm(1, 0, 1)
  # genotype-specific survival probabilities at this dose
  SAA <- S_0 * plogis(-logit_rAA * dose)
  SAa <- S_0 * plogis(-logit_rAa * dose)
  Saa <- S_0 * plogis(-logit_raa * dose)
  # population-average survival under Hardy-Weinberg proportions
  S_bar <- p_pop^2 * SAA + 2 * p_pop * (1 - p_pop) * SAa + (1 - p_pop)^2 * Saa
  # simulate one observed count
  n_alive <- rbinom(1, n_tested, S_bar)
  return(list(SAA = SAA, SAa = SAa, Saa = Saa, S_0 = S_0, p_pop = p_pop, n_alive = n_alive))
}

sim <- replicate(1000, generate_prior_observation(), simplify = "matrix")
sim <- as.data.frame(t(apply(sim, 1:2, unlist)))
hist(sim$S_0)
hist(sim$SAA)
hist(sim$n_alive)

There are also some index bugs in other parts of the code. In your model, you can avoid those issues by not using a loop. Based on my very preliminary prior predictive simulations above, I would start with the following model block (with logit_r_AA, logit_r_Aa, and logit_r_aa declared as unconstrained vectors in the parameters block in place of r_AA and friends). However, I think the priors are probably still too diffuse.

model {
  // proportion of population with genotype
  real P_AA_pop = p_pop^2;
  real P_Aa_pop = 2*p_pop*(1 - p_pop);
  real P_aa_pop = (1 - p_pop)^2;
  
  // survival probabilities based on genotype and dose
  vector[n_obs] S_AA;
  vector[n_obs] S_Aa;
  vector[n_obs] S_aa;
  
  // average survival probability of tested individuals
  vector[n_obs] S_bar;

  // priors
  S_0 ~ beta(0.5, 0.5);
  p_pop ~ beta(0.5, 0.5);
  logit_r_AA ~ normal(0, 1);
  logit_r_Aa ~ normal(0, 1);
  logit_r_aa ~ normal(0, 1);

  // survival to exposure based on genotype and dose
  S_AA = S_0*inv_logit(-logit_r_AA .* dose);
  S_Aa = S_0*inv_logit(-logit_r_Aa .* dose);
  S_aa = S_0*inv_logit(-logit_r_aa .* dose);
  
  // average survival probability of tested individuals
  S_bar = S_AA*P_AA_pop + S_Aa*P_Aa_pop + S_aa*P_aa_pop;

  // likelihood
  n_alive ~ binomial(n_tested, S_bar);

}

I don’t think this functional form makes much sense: at dose = 0 there’s no pesticide and survivability should be 100%, but the logit model predicts 50%. Maybe instead of multiplying by dose, add log-dose

  S_aa = S_0*inv_logit(log_LD50_aa - log(dose));

which corresponds to mathematical model

\mathrm{survivability} = \frac{\mathrm{LD}_{50}}{\mathrm{LD}_{50}+\mathrm{dose}}
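
To spell out the algebra:

\mathrm{inv\_logit}(\log \mathrm{LD}_{50} - \log \mathrm{dose}) = \frac{1}{1 + \mathrm{dose}/\mathrm{LD}_{50}} = \frac{\mathrm{LD}_{50}}{\mathrm{LD}_{50} + \mathrm{dose}}

so at dose = LD_{50} the dose-dependent survival factor is exactly 1/2, as the name suggests.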

Alternatively, the original exponential model makes sense mathematically, but Stan implements positive-constrained parameters with an exponential transform, so the model effectively has a very steep double-exponential transform. It may be better to reparametrize in terms of the base of the exponent rather than the rate parameter, something like

parameters {
  vector<lower=0, upper=1>[n_obs] S1_aa; // survival factor at dose = 1
}
model {
  vector[n_obs] log_r_aa = log(-log(S1_aa));
  vector[n_obs] S_aa;
  log_r_aa ~ cauchy(0, 1); // prior, should probably be normal(0,1)
  target += -log_r_aa - log(S1_aa); // Jacobian for log_r_aa
  S_aa = S_0*exp(dose .* log(S1_aa)); // equals S1_aa^dose, element-wise
}
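
For intuition, here is a quick R simulation (my sketch, assuming the normal(0, 1) prior suggested in the comment) of what this implies for S1_aa:

# implied prior on S1 (the survival factor at dose = 1) when log(r) ~ normal(0, 1)
log_r <- rnorm(1e5, 0, 1)
S1 <- exp(-exp(log_r))
hist(S1, breaks = 50)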
2 Likes

Yeah. You are right. I did not think this through at all.

1 Like

Thank you for your replies. Just for others to follow, I tested the model based on the following changes suggested by @stijn and @nhuurre:

  • changed the priors for the mortality rates (r_**) from a Cauchy to a normal distribution (I’m curious why the Cauchy prior would be too wide, as @stijn mentioned?)
  • the mortality rate priors are no longer log-transformed, and the target += -log(r_AA); statements are removed
  • the survival rates are calculated as S_0*inv_logit(-logit_r_AA - log(dose)) instead of the original formulation S_0*exp(-r_AA*dose), and are vectorized instead of being calculated in a loop

The updated model with these changes looks like this:

data {
  int<lower=0> n_obs;  // number of observations
  vector[n_obs] dose;  // dosage per observation
  int n_alive[n_obs];  // number of individuals alive
  int n_tested[n_obs]; // number of individuals tested
}

parameters {
  // background survival rate
  real<lower=0, upper=1> S_0;
  // population parameter
  real<lower=0, upper=1> p_pop;
  // genotype-specific mortality rates (log scale, per unit dose)
  vector<lower=0>[n_obs] logit_r_AA;
  vector<lower=0>[n_obs] logit_r_Aa;
  vector<lower=0>[n_obs] logit_r_aa;
  // // average survival probability of tested individuals
  // vector<lower=0, upper=1>[n_obs] S_bar;
}

model {
  // proportion of population with genotype
  real P_AA_pop = p_pop^2;
  real P_Aa_pop = 2*p_pop*(1 - p_pop);
  real P_aa_pop = (1 - p_pop)^2;
  
  // survival probabilities based on genotype and dose
  vector[n_obs] S_AA;
  vector[n_obs] S_Aa;
  vector[n_obs] S_aa;
  
  // average survival probability of tested individuals
  vector[n_obs] S_bar;

  // priors
  S_0 ~ beta(0.5, 0.5);
  p_pop ~ beta(0.5, 0.5);

  logit_r_AA ~ normal(0, 1);
  logit_r_Aa ~ normal(0, 1);
  logit_r_aa ~ normal(0, 1);

  // survival to exposure based on genotype and dose
  S_AA = S_0*inv_logit(-logit_r_AA - log(dose));
  S_Aa = S_0*inv_logit(-logit_r_Aa - log(dose));
  S_aa = S_0*inv_logit(-logit_r_aa - log(dose));
  
  // average survival probability of tested individuals
  S_bar = S_AA*P_AA_pop + S_Aa*P_Aa_pop + S_aa*P_aa_pop;

  // define likelihood of surviving counts given the number tested and survival probability
  n_alive ~ binomial(n_tested, S_bar);
}

The previous iteration of the model took ~3 h to sample and still reported ~500 divergent transitions with:

  #' https://mc-stan.org/misc/warnings.html
  m <- stan(file="model.stan", data=data.list, chains=2, iter=20000, warmup=500,
            thin=5, control=list("adapt_delta" = 0.99, "max_treedepth" = 17))

The current version of the model needs ~30k iterations to avoid warnings about high R-hat and low bulk ESS, but it now runs in ~30 s without any warnings.
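
For reference, the call now looks roughly like this (a sketch: only the iteration count comes from the numbers above; the rest assumes the earlier two-chain setup with default control settings):

  m <- stan(file="model.stan", data=data.list, chains=2, iter=30000)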

There are smarter people than me on this forum who will be able to explain things better. Nevertheless, my intuition is that a Cauchy distribution has heavy tails which are then exacerbated by the exponential transform (and likewise by the inverse logit). Cauchy + exp()/inv_logit() has become a bit of a red flag for me. These priors always end up putting high probability on extreme values, which is typically not what you want. The best way to see the impact is just to simulate from the prior and look at a histogram.
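
For example, a minimal simulation along those lines:

# pushing normal vs. cauchy draws through exp()
quantile(exp(rnorm(1e5, 0, 1)), c(.5, .9, .99))    # roughly 1, 3.6, 10
quantile(exp(rcauchy(1e5, 0, 1)), c(.5, .9, .99))  # roughly 1, 22, 7e13
# and through the inverse logit: the mass piles up near 0 and 1
hist(plogis(rcauchy(1e5, 0, 1)))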

1 Like