One chain not moving: sampling for complex nonlinear mixed effects model

Short summary of the problem:

Hi… I am fitting a nonlinear multilevel model to analyze the survival probability of individuals (0/1 response) in relation to two continuous predictors, CON and HET. We know that the relationship follows the form:

y[i,j] = b0[j] + b1[j] * CON[i,j]^c1[j] + b2[j] * HET[i,j]^c2[j]

Here, y is the log odds of survival and b0, b1, b2, c1 and c2 are parameters to be estimated from the data. Individuals (i) are distributed among j groups, and all coefficients vary with group. In addition, all coefficients are linear functions of another predictor ‘AB’, i.e., each b/c ~ AB + random effects for groups.

This is a large dataset of ~150,000 data points and j = 190 groups. I formulated the model as shown below:

# set priors
priors = c(
  set_prior("normal(0, 1)", nlpar = "b0"),
  set_prior("normal(0, 1)", nlpar = "b1"),
  set_prior("normal(0, 1)", nlpar = "b2"),
  set_prior("normal(0, 1)", nlpar = "c1"),
  set_prior("normal(0, 1)", nlpar = "c2")
)
cores = 6
chains = 3
control = list(adapt_delta = 0.9)
reuse_models = TRUE

fit_test = brm(
  bf(status ~ b0 + (b1 * dBA.con.15m^c1) + (b2 * dBA.het.15m^c2),
     b0 ~ 1 + abundance + (1|sp),  # + (1|quadrat) + (1|year)
     b1 ~ 1 + abundance + (1|sp), 
     b2 ~ 1 + abundance + (1|sp), 
     c1 ~ 1 + abundance + (1|sp), 
     c2 ~ 1 + abundance + (1|sp),
     nl = TRUE), 
  data = dat.test, family = bernoulli(link = "logit"), 
  prior = priors, save_model = "brm_cndd_bci.txt",
  chains = chains, cores = cores, seed = 123,
  control = control, inits = 0
)

When I tried this model with a small subset of the data, running 4 chains for 2000 iterations, the model converged and seemed to give sensible parameter estimates. However, when I run the model with the full data or larger subsets (30%, 60%, 80% of data points), one of the chains consistently ‘hangs’. Sampling usually stops just after warmup finishes for that chain. I have let the ‘hung’ chain continue for up to 4 hours with no progress. I have not been able to fix the problem even by running simpler versions of the model or using smaller subsets of the data.

  1. Have I specified the nonlinear model correctly in brms?
  2. I just moved to a new laptop (Windows 10) and updated to R version 4.0.2 and brms version 2.13.5. I was initially having problems with even starting the sampling, which got resolved by uninstalling and reinstalling brms. Could this be a sampling problem with rstan?

I hope this information suffices to explain the problem. Would much appreciate your help!

Thanks,
Meghna


When one chain hangs, you can check its target density, as sometimes chains get stuck in local modes that are far from what we’re calling the typical set. A simple solution in such cases is to discard the bad chains; a better approach would be stacking, as here: http://www.stat.columbia.edu/~gelman/research/unpublished/2006.12335.pdf
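
For instance, here is one way to compare the target density across chains once a run finishes, even a bad one (a sketch, assuming a brmsfit object called fit; log_posterior() is the bayesplot helper that brms re-exports):

library(brms)
# Per-iteration lp__ for each chain, as a data frame with columns
# Chain, Iteration, Value:
lp <- log_posterior(fit)
# Chain-wise mean lp__; a chain stuck in a minor mode will typically
# sit at a markedly lower value than the others:
tapply(lp$Value, lp$Chain, mean)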

It would be good to know why the chain is getting stuck. In many cases, the problem can be solved using stronger priors, especially for the group-level variance parameters. If you can include stronger priors in the context of your model, I’d recommend doing that in any case. Actually I’d recommend doing that first, before any of the other steps.

Hi Andrew,

Thanks for your suggestions. I re-ran the models with tighter priors on the parameters, based on preliminary runs with a small subset of the data (the one that had converged). These do better in that the chains now progress to termination. But the chains take vastly different times to finish and do not converge: R-hats are very high (>2.5) even with 2000 iterations. Would more iterations help in this case? How should I identify and discard bad chains? Apologies if this is a super basic question, but I could not find this information in the brms help/vignettes.

Were you also suggesting I use more informative hyperpriors for the group-level variance parameters? Am I right in understanding that brms does not allow this and I would have to do it directly in Stan? I ask because I am not confident of coding this complex model directly in Stan. I could adapt the code generated through brms, but would prefer to implement it through brms directly.

Also, not sure if it is helpful, but I get the following message when I abort the models for which chain does not progress:

Warning message:
In system(paste(CXX, ARGS), ignore.stdout = TRUE, ignore.stderr = TRUE) :
  '-E' not found
  1. You ask, “How should I identify and discard bad chains?” I recommend stacking: http://www.stat.columbia.edu/~gelman/research/unpublished/2006.12335.pdf

  2. It seems that in brms you can set priors on group-level variance parameters; see item 2 on this help page: https://rdrr.io/cran/brms/man/set_prior.html
    I would try this first, as it might resolve your convergence problems. A short sketch follows after this list.

  3. I have no idea about the warning message!
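
For example, item 2 might look like this in brms (a sketch only; the half-normal scales are illustrative, not recommendations):

library(brms)
# class = "sd" targets the group-level standard deviations of each
# nonlinear parameter; brms truncates these priors at 0 automatically,
# so normal(0, s) acts as a half-normal.
priors <- c(
  set_prior("normal(0, 0.5)", class = "sd", nlpar = "b0"),
  set_prior("normal(0, 0.5)", class = "sd", nlpar = "b1"),
  set_prior("normal(0, 0.5)", class = "sd", nlpar = "b2"),
  set_prior("normal(0, 0.25)", class = "sd", nlpar = "c1"),
  set_prior("normal(0, 0.25)", class = "sd", nlpar = "c2")
)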


I agree with Andrew that figuring out why the model has problems should be your first concern; only when you can’t make progress there would I consider using stacking to get some inferences even from a non-converging model.

I have drafted a list of strategies to handle models with divergent transitions that also (partially) apply to the case of high R-hats: Divergent transitions - a primer. It would be cool if you checked it out and let us know whether:

  • You are able to understand the strategies given
  • You are able to determine which of those can be applied to your case
  • Some of them eventually let you move forward

I’ve previously had premature termination of various R computations result in weird warnings, so I wouldn’t worry about the warning much.

Best of luck with your model!


Hi Andrew and Martin,

Thank you for your inputs. I will try these suggestions and report back in a few days. @martinmodrak: I could understand the strategies conceptually, but am still struggling to find the right priors to solve my problem.

Based on a simpler model with no exponents, I tried setting tighter priors for all parameters, including the group-level variances, but that has not helped. The last chain still runs for >24 hrs after the other two chains have finished sampling. I am now looking into whether some distributions work better than others as priors for group-level SDs - please let me know if you have any suggestions.

I was also unclear about whether, for such complex models, it is OK to set informative priors for each variance parameter separately. For instance, should I be defining a prior for each ‘sd’ below, or can I just define one prior for the ‘sd’ of each ‘nlpar’:

get_prior(bf(status ~ b0 + (b1 * BA.con.15m^c1) + (b2 * BA.het.15m^c2),
             b0 ~ 1 + abun.s + (1|sp) + (1|quadrat),  # (1|year) +
             b1 ~ 1 + abun.s + (1|sp),  # (1|X|sp)
             b2 ~ 1 + abun.s + (1|sp),
             c1 ~ 1 + (1|sp),  # (1|X|sp)
             c2 ~ 1 + (1|sp),
             nl = TRUE),
          data = dat.sub, family = bernoulli(link = "logit"))
                  prior class      coef   group resp dpar nlpar bound
1                           b                                b0      
2                           b    abun.s                      b0      
3                           b Intercept                      b0      
4  student_t(3, 0, 2.5)    sd                                b0      
5                          sd           quadrat              b0      
6                          sd Intercept quadrat              b0      
7                          sd                sp              b0      
8                          sd Intercept      sp              b0      
9                           b                                b1      
10                          b    abun.s                      b1      
11                          b Intercept                      b1      
12 student_t(3, 0, 2.5)    sd                                b1      
13                         sd                sp              b1      
14                         sd Intercept      sp              b1      
15                          b                                b2      
16                          b    abun.s                      b2      
17                          b Intercept                      b2      
18 student_t(3, 0, 2.5)    sd                                b2      
19                         sd                sp              b2      
20                         sd Intercept      sp              b2      
21                          b                                c1      
22                          b Intercept                      c1      
23 student_t(3, 0, 2.5)    sd                                c1      
24                         sd                sp              c1      
25                         sd Intercept      sp              c1      
26                          b                                c2      
27                          b Intercept                      c2      
28 student_t(3, 0, 2.5)    sd                                c2      
29                         sd                sp              c2      
30                         sd Intercept      sp              c2   

Would setting initial values for parameters help?

Might the nonlinear model work better if coded as an inverse logit function with status ~ inv.logit(b0 + (b1 * BA.con.20m^c1) + (b2 * BA.het.20m^c2)) and family = bernoulli(link = 'identity')?

Thanks,
Meghna


Hi,

Writing with an update. As suggested by @andrewgelman and @martinmodrak, I reformulated my model using [what I thought were] tighter priors. I obtained these prior expectations from the simpler version of the model, i.e., the one without exponents, which had converged well. For the exponents, I set up priors based on values from similar analyses done previously. I set up the priors such that their modes reflected the expected values. Here are the priors I used:

> prior_summary(fit_1)
               prior class      coef   group resp dpar nlpar
1                        b                                b0
2  normal(-0.2, 0.5)     b    abun.s                      b0
3     normal(2, 0.5)     b Intercept                      b0
4                        b                                b1
5     normal(0, 0.5)     b    abun.s                      b1
6  normal(-0.5, 0.5)     b Intercept                      b1
7                        b                                b2
8     normal(0, 0.5)     b    abun.s                      b2
9   normal(0.1, 0.5)     b Intercept                      b2
10  uniform(0.1,0.5)     b                                c1
11                       b Intercept                      c1
12  uniform(0.1,0.5)     b                                c2
13                       b Intercept                      c2
14        gamma(2,1)    sd                                b0
15        gamma(2,1)    sd                                b1
16        gamma(1,1)    sd                                b2
17      gamma(1,0.5)    sd                                c1
18      gamma(2,0.5)    sd                                c2
19                      sd           quadrat              b0
20                      sd Intercept quadrat              b0
21                      sd                sp              b0
22                      sd Intercept      sp              b0
23                      sd                sp              b1
24                      sd Intercept      sp              b1
25                      sd                sp              b2
26                      sd Intercept      sp              b2
27                      sd                sp              c1
28                      sd Intercept      sp              c1
29                      sd                sp              c2
30                      sd Intercept      sp              c2

… and here is the model:

> summary(fit_1)
 Family: bernoulli 
  Links: mu = logit 
Formula: status ~ b0 + (b1 * BA.con.15m^c1) + (b2 * BA.het.15m^c2) 
         b0 ~ 1 + abun.s + (1 | sp) + (1 | quadrat)
         b1 ~ 1 + abun.s + (1 | sp)
         b2 ~ 1 + abun.s + (1 | sp)
         c1 ~ 1 + (1 | sp)
         c2 ~ 1 + (1 | sp)
   Data: dat.test (Number of observations: 43161) 
Samples: 3 chains, each with iter = 2500; warmup = 1000; thin = 2;
         total post-warmup samples = 2250

Group-Level Effects: 
~quadrat (Number of levels: 1250) 
                 Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(b0_Intercept)     0.41      0.03     0.36     0.46 1.05       77      383

~sp (Number of levels: 73) 
                 Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(b0_Intercept)     0.84      0.27     0.18     1.19 1.19       13       96
sd(b1_Intercept)     1.10      0.29     0.58     1.74 1.03      178      845
sd(b2_Intercept)     0.37      0.31     0.01     0.98 1.29        9       60
sd(c1_Intercept)     0.11      0.06     0.01     0.23 1.04       79      137
sd(c2_Intercept)     0.53      0.37     0.08     1.36 1.14       17      522

Population-Level Effects: 
             Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
b0_Intercept     2.23      0.23     1.78     2.69 1.04      209      214
b0_abun.s       -0.01      0.29    -0.65     0.53 1.03      241      438
b1_Intercept    -0.81      0.24    -1.27    -0.31 1.04      351      273
b1_abun.s        0.31      0.27    -0.24     0.84 1.04      117      236
b2_Intercept    -0.05      0.19    -0.49     0.32 1.11      329      163
b2_abun.s       -0.05      0.25    -0.52     0.51 1.10      194      458
c1_Intercept     0.44      0.05     0.31     0.50 1.04       80      789
c2_Intercept     0.20      0.09     0.10     0.43 1.03      126      748

Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

With this, the chains no longer hang, even with half the dataset in use. However, sampling takes quite long and convergence issues remain, with R-hats as large as 1.3 for some parameters. Trace plots suggest poor mixing of chains for these parameters (as expected), and there are still many divergent transitions (>1000). Increasing adapt_delta (up to 0.99) has not helped. I have by now tried many different specifications of priors for the SDs, but have not been able to achieve full convergence.
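
(For reference, this is how I count divergences per chain; nuts_params() is the bayesplot helper re-exported by brms:)

np <- nuts_params(fit_1)
div <- subset(np, Parameter == "divergent__")
tapply(div$Value, div$Chain, sum)  # number of divergent transitions per chain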

So, my questions now:

  1. Would running more iterations help in this case, or is the problem more structural? Posterior predictive checks seem to show that the model fits the data well.
  2. Are trace plots and posterior densities the best way to discern which priors need to be tightened or is there a heuristic to figure out the best prior, especially for SDs, if the true values are not known? After a bunch of trials, I used gamma and it seems to work better than the suggested half Student-t.
  3. Might the nonlinear model work better if coded as an inverse logit function with
    status ~ inv.logit(b0 + (b1 * BA.con.15m^c1) + (b2 * BA.het.15m^c2)) and family = bernoulli(link = 'identity') ?

Just a note: I have implemented a similar nonlinear model in JAGS with fairly diffuse priors and it had converged successfully with sensible parameter estimates. So, I am not sure what I am doing wrong here.

I would be grateful for your inputs to help resolve this.

Thanks,
Meghna


You could try manually specifying inits to reasonable values? Just a thought. Also, just speaking generally, I don’t like things such as uniform(0.1, 0.5) priors. See the discussion here: https://github.com/stan-dev/stan/wiki/Prior-Choice-Recommendations and search on “uniform”.
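
In case it helps, a minimal sketch of manual inits in brms (the parameter names here are hypothetical; check the Stan code from make_stancode() for your model’s actual names):

library(brms)
# One named list per chain; names must match parameters in the generated
# Stan code. For brms nonlinear models, the population-level coefficients
# of each nlpar form a vector named b_<nlpar>. Parameters you leave out
# keep their random inits.
init_fun <- function() {
  list(
    b_b0 = c(2, 0),        # illustrative: Intercept and abun.s slope for b0
    b_c1 = as.array(0.3)   # illustrative: start the exponent near its mode
  )
}
inits <- list(init_fun(), init_fun(), init_fun())  # one entry per chain
# then: brm(..., chains = 3, inits = inits)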

Hi,
just to set some expectations: those non-linear models can be very tricky to get right. Inspecting priors is generally a good early step, but it is quite possible the model has issues that can’t be resolved with better priors, and a deeper dive into the math of the model and the data at hand may be necessary - little can, AFAIK, be said in full generality.

It is quite likely you are doing nothing wrong. Unfortunately, people converting JAGS models to Stan frequently find that models that converge in JAGS turn out to be problematic in Stan, because Stan has much more sensitive diagnostics (the models were problematic in JAGS as well, but nobody noticed). In some of those cases, the estimates from JAGS turned out to be substantially biased away from the Stan estimates once the computational problems were addressed. It is possible this is one of those cases.

With divergences, the cause is almost always structural. I try to give some hints below, but really, those models are rough terrain to navigate, so I am not very sure of any of this advice. If PP checks are good, then there is always the possibility that the model is too flexible - the data just do not inform some of the parameters (which might manifest as divergences).

I think that in such a complex model, at least a weakly informative prior should be present for all parameters - if I understand this correctly, you do not have a prior for many of the sds in the (1|sp) terms?

Also, I would consider not estimating the correlations by using the (1 || sp) syntax - such models can be simpler to fit.

 uniform(0.1,0.5)     b                                c1

We usually don’t recommend hard bounds on priors. If you think the coefficient is very unlikely to be outside 0.1 and 0.5, then a prior like normal(0.3, 0.1) will soft-constrain the model in a similar way, but is usually much friendlier to the sampler.

Also (potentially unintuitively), brms requires you to accompany such hard-bounded priors with explicit lower and upper bounds via the lb and ub arguments of set_prior (or a similar function). Did you set those? If not, that would definitely cause problems.
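
Concretely, the two options look like this (a sketch; the numbers mirror your prior table above):

# A hard-bounded uniform prior must come with explicit bounds in brms:
set_prior("uniform(0.1, 0.5)", nlpar = "c1", lb = 0.1, ub = 0.5)
# The soft alternative suggested above, which needs no bounds:
set_prior("normal(0.3, 0.1)", nlpar = "c1")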

Generally, no. Priors should be guided by domain knowledge - prior predictive checks are a good tool to extract this knowledge. One heuristic I’ve found useful is to check whether the posterior is “pushing against” the prior - see e.g. how the c1_Intercept posterior hits the 0.5 boundary. But with any prior fiddling you need to be careful not to bring more information into the model than is actually reasonable.

I think this is very unlikely to help.

Best of luck with your model and if this doesn’t help/make sense, you are welcome to ask for further clarifications :-)

Hi Martin,

Thanks for the explanations and suggestions. I do not have the training to go deeper into the math of the model :). If it comes to that, I would very much appreciate help from experts in this forum.

I have read about cases where estimates from JAGS were found to be biased once the model was tried in Stan. However, if my Stan/brms models do not behave, how do I know where the bias lies? Does this go back to your point about delving into the math?

I wasn’t entirely clear about your comment on weakly informative priors for all parameters, hence clarifying: are you suggesting that I assign a separate weak prior to every level of the grouping variable for each parameter? That would be 163 levels x 4 parameters in the full dataset/model. Currently, I specify just one prior per SD, like so:

set_prior("gamma(2,2)", class = "sd", nlpar = "b0"),
  set_prior("gamma(2,1)", class = "sd", nlpar = "b1"),
  set_prior("gamma(1,0.5)", class = "sd", nlpar = "b2"),
  set_prior("gamma(1,0.5)", class = "sd", nlpar = "c1"),
  set_prior("gamma(1,0.5)", class = "sd", nlpar = "c2")

How would I do what you suggested? Sorry if I am missing something obvious, but I did a quick search before I wrote to you and could not find an example or help.

Also, how might I specify initial values via brm(..., inits = list(??))? Should the list of lists have one list per chain, with each of those in turn holding a two-column df of parameters and their associated values?

About hard bounds: I placed the bounds based on the output from JAGS. However, I had only specified lb and ub for parameters whose priors were uniform distributions. Is that what you were asking? I did not set bounds for the other priors.

For priors to be guided by domain knowledge: I can see how that works for population-level effects, but how would one have such information a priori for group-level effects/SDs?

Thanks,
Meghna


Unfortunately, there is little that holds generally. Sometimes you even spend a lot of time resolving some pathology in the model and once you resolve it, the posterior stays almost the same. Other times your conclusions can completely reverse… Understanding the interplay between the mathematical formulation of the model and the data unfortunately can’t be avoided without significant risks.

Oh, sorry, you are right, this should probably be sufficient.

Yes, what I meant was that the uniform prior needs to be accompanied by the bounds; other priors shouldn’t need bounds. As I said, hard bounds are mostly discouraged, unless there is a “physical” bound (e.g. a probability has to be between 0 and 1, height is larger than 0, …).

No easy answers for such a complex model - there are a ton of interactions with the other parameters. You can check out prior predictive checks. But, as I said, the problem is IMHO likely actually NOT in the priors.
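
A prior predictive check is easy to run in brms by sampling from the priors alone (a sketch; model_formula stands for the bf() object from your earlier posts):

# Draw from the priors only (no likelihood) and inspect the implied
# 0/1 outcomes against the observed ones:
fit_prior <- brm(model_formula, data = dat.test,
                 family = bernoulli(link = "logit"), prior = priors,
                 sample_prior = "only", chains = 1, iter = 1000)
pp_check(fit_prior, type = "bars", nsamples = 50)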

Estimating sums of exponentials is actually a known hard problem - sometimes the data just can’t distinguish the two exponentials.

In any case I would start by reducing the model - does it work with just a single exponential? Does it work without the varying intercepts on the parameters? etc.
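
For example, a possible first reduction might look like this (a sketch using the variable names from your earlier posts):

# Single power term, with varying intercepts only on b0:
bf_reduced <- bf(status ~ b0 + b1 * BA.con.15m^c1,
                 b0 ~ 1 + abun.s + (1 | sp),
                 b1 ~ 1 + abun.s,
                 c1 ~ 1,
                 nl = TRUE)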

Hi Martin,

Thanks a lot for your comments. As you suggested, I simplified the models sequentially, starting with no random effects on the exponents and no hierarchical predictors for the parameters. These models behaved well. I then used the posteriors from the simpler models to inform my priors for the full model. I still had to run a few trials to tweak individual priors - it seems the model is very sensitive to the prior specification for the group SDs. I also specified longer runs (3000 iterations, thin = 2).

So, I finally have a working model. The parameter estimates also seem reasonable and PP checks look OK. However, I still have a few divergent transitions (~200). Do let me know if you have any thoughts on this.

It does seem like the model [mis]behavior is determined largely by the data in this case. When I used the model for another data set, I had no issues at all! I tried some prior predictive checks to understand what might be the problem.

Of course, I may understand this better once I work through the math. Have not done so yet, but I hope to get to it at some point. For now, I am just being a grateful consumer of this great product :)

Thank you for your help!

Best wishes,
Meghna


Unfortunately, divergences usually cannot be safely ignored.

This is actually quite frequently the case - not all datasets inform all the parameters of a model well (which can then result in divergences). For example, if BA.con.15m and BA.het.15m are strongly correlated in one dataset, it may become impossible to distinguish how much each of the exponentials contributes to the response - only their sum is identified.
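
A quick first check of this possibility (assuming the data frame from your earlier posts):

# How strongly do the two predictors correlate in this dataset?
cor(dat.test$BA.con.15m, dat.test$BA.het.15m, use = "complete.obs")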

Examining pairs plots of the population-level effects might help us diagnose this, as they should show some weird correlations between the parameters if this is the case.
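
For example (a sketch; the parameter names follow the summary output earlier in the thread):

library(bayesplot)
# Pairs plot of the population-level intercepts, with divergent
# transitions overlaid:
mcmc_pairs(as.array(fit_1),
           pars = c("b_b0_Intercept", "b_b1_Intercept", "b_b2_Intercept",
                    "b_c1_Intercept", "b_c2_Intercept"),
           np = nuts_params(fit_1))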

This might be an important hint (it is definitely a sign that something is wrong). Maybe this dataset has too few observations per sp and quadrat combination to inform the varying intercepts well… If you cannot fit the model without narrowing the priors more than you can justify from domain knowledge, it probably means you need more data or need to simplify the model.

A smaller note:

I would not recommend that. Priors should generally be defensible by referring only to domain knowledge, without any reference to the actual data you happened to collect. The process you describe makes the priors depend on the data and could thus introduce some bias into your inference (it could be small, but it is hard to say).

Hope this helps at least a little bit!

Hi Martin,

Sorry for the long hiatus. Resuming the discussion about dealing with divergences in a non-linear hierarchical model. You might recall that in trying to implement a multi-level non-linear model with brms, I was running into persistent divergences. Assigning tighter priors and correlations for the group-level effects reduced the divergences to ~8-10%, but I am unable to improve on this further. I read through the help documents you shared about divergences, but am at a loss as to how to proceed.

The divergences are not happening in any particular region of the parameter space; they are scattered all over. The pairs() plots showed, however, that some population-level parameters were correlated - might that be why the divergences persist? I was told that it could help to draw the population-level effects from a multivariate normal distribution, but I am unable to implement that using the brms syntax for non-linear formulations, given below:

bf(Y ~ b0 + (b1 * X1^c1) + (b2 * X2^c2),
   b0 ~ 1 + Z.s + (1|p|sp) + (1|q),
   b1 ~ 1 + Z.s + (1|p|sp),
   b2 ~ 1 + Z.s + (1|p|sp),
   c1 ~ 1 + (1|p|sp),
   c2 ~ 1 + (1|p|sp),
   nl = TRUE)

Specifically, I am unable to provide a multi_normal() prior, as the syntax appears to require a separate prior per parameter.

Also, as a general point - what is an acceptable % of divergences?

Thanks,
Meghna

No worries :-)

Yeah, it looks like setting a multivariate normal prior in this way is not possible; you would be able to do this in pure Stan. I however don’t think it is likely to help much. In any case, you are unfortunately in difficult territory - there is no universal algorithm to debug a pathological model… If you could share some of the problematic pairs plots, we might be able to try some other guesses.

Actually, a similar (but simpler) model is discussed as a pathological example in our recent preprint on workflow: [2011.01808] Bayesian Workflow (see Figure 12). A good check would be to see whether your data can be fit well by just a single exponential: if you can’t find a posterior predictive check that shows a problem with the “single exponential” model, it is unlikely the data can inform the larger model…
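
A sketch of that check, reusing the variable names from your earlier posts (priors_single stands for whatever priors you choose for the reduced model; the grouped mean is just one candidate statistic):

# Fit the single-power-term model and look for misfit, e.g. via
# group-wise survival rates:
fit_single <- brm(bf(status ~ b0 + b1 * BA.con.15m^c1,
                     b0 ~ 1 + abun.s + (1 | sp),
                     b1 ~ 1 + abun.s + (1 | sp),
                     c1 ~ 1 + (1 | sp),
                     nl = TRUE),
                  data = dat.test, family = bernoulli(link = "logit"),
                  prior = priors_single)
pp_check(fit_single, type = "stat_grouped", stat = "mean", group = "sp")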

Does that make sense?