Max_treedepth saturated, but increasing it slows sampling to a standstill

Hi all,

I’m fitting a big hierarchical nonlinear model on ~18k data points and a bunch of groups: about 12k parameters, and p_loo = 1879. It seems to be doing a great job: the fits look excellent, and the parameters seem reasonable. I’m busy testing it on simulated data now, but based on the fits and the posterior predictive distribution, I’m not feeling concerned. I typically get either 0 or 1 divergences over 6000 iterations (with adapt_delta = 0.90), so that’s not really flagging anything either. Also, I’m only really concerned with my population-level fixed effects, so some degree of bias at the level of individual data points or individual random effects is of little concern to me if it doesn’t greatly affect the fixed-effects estimation.

However, I’m saturating my max_treedepth: 5999 of 6000 iterations reach the max_treedepth. It’s also taking about 4 or 5 days to fit. This is doable, but it’s quite annoying for model building to realise a week later that I need an extra variable. I tried turning max_treedepth up to 15, and sampling became much, much slower: after 24 hours, I hadn’t even reached 100 iterations. I don’t even know how to start looking into what’s happening at the higher treedepth, because it takes so long to sample anything at all.
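For what it’s worth, this is roughly how I’ve been checking the saturation counts with cmdstanr (just a sketch; `fit` is a placeholder name for the CmdStanMCMC object returned by the sampler):

```r
# Sketch: count post-warmup iterations that hit the treedepth cap.
# `fit` is a placeholder for the cmdstanr CmdStanMCMC fit object.
library(posterior)

diags <- as_draws_df(fit$sampler_diagnostics())  # treedepth__, n_leapfrog__, divergent__, ...
table(diags$treedepth__)                         # distribution of treedepths per iteration
sum(diags$treedepth__ >= 10)                     # iterations at the default cap of 10
sum(diags$divergent__)                           # divergent transitions, for comparison
```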

Preliminary questions:

  • Is saturating the treedepth something I should really even be concerned about? I understand that it’s about efficiency, but sampling for a month hardly seems like an efficient use of time either.
  • Are there other ways that I should be attacking this problem without increasing the treedepth? Reparameterisation? Tightening priors? Other control parameters?

I read through the discussion here (Setting Max Treedepth in difficult high-dimensional models), which seems relevant. @dlakelan was artificially reducing the max_treedepth and seeing significant speed-ups, and could then run the full model later with a higher treedepth, or even settle on a compromise value that optimises the number of effective samples for the computing time. I suspect that my model shares some of the same characteristics: several very tightly identified parameters with small variance, for which leaving the typical set takes us to badlands where probability goes to die. The parameters are also quite highly correlated in some cases, which probably complicates the geometry further. I’m actually running inits = 0 to be more sure that the sampler starts within the typical set, as I had issues with an earlier (smaller) model where one of the chains could become completely lost, with 100% divergent transitions out in the wilderness, while the others had none.

The discussion above comes from ~3 years ago, and there was some talk of saving warmup samples for reuse, or a “run for 8 hours” mode. It’s also around the upper limit of my understanding of how to resolve these kinds of issues, so maybe I’m not understanding something there correctly. Does artificially lowering the max_treedepth, or just leaving it where it is, come with any other serious downsides that I should be considering? And, given the points above, are my saturated treedepths really so problematic after all if everything else is doing ok?

Thanks so much in advance for any help!

By the way, I’m using cmdstanr with CmdStan v2.24.1.

Check out (Another) Slow Hierarchical Model in Stan

I’ve seen this behaviour in models where I’ve forgotten to put priors on some parameters, and the default uniform priors give some tricky posterior geometry that the sampler struggles with.

However, this kind of behaviour can come from a lot of different places. If you can post your model code, we’ll (hopefully!) be able to provide some more concrete answers and help.

Sorry for taking so long to get back to you both! Haven’t had a second free the last couple of days. Thank you so much for the suggestions!

@spinkney : Looking into the QR decomposition, it sounds like this could be a really amazing quick fix! I’m using brms actually, and I see that there’s an option to automatically perform the QR decomp, so I’m trying that out now, and I’ll see how it goes. Fingers crossed! I’ll update when it’s done.
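For reference, I’m specifying it roughly like the sketch below (the formula, data, and variable names are placeholders; decomp is an argument of bf()/brmsformula()):

```r
# Sketch: requesting the QR decomposition of the design matrix in brms.
# The formula, data, and variable names here are placeholders.
library(brms)

f <- bf(y ~ x1 + x2 + (1 | group), decomp = "QR")
fit_qr <- brm(f, data = dat, chains = 4, cores = 4)
```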

@andrjohns : As mentioned, I’m using brms for this model, and most of the default priors are really inappropriate here, so I was pretty sure I’d set priors over everything, but it could well be that I’ve missed something somewhere. Thank you so much for the incredibly generous offer to look it over! Below I’ve included the Stan code of the model (though I removed the function call at the top), as well as the brms function calls, the prior specification, and the output of get_prior().
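In case it’s useful context, the get_prior() output in the attachment comes from a call roughly like the sketch below (where `form`, `dat`, and `fam` are placeholders for the actual formula, data, and family); blank prior rows are the ones that fall back to brms defaults:

```r
# Sketch: list parameters that would fall back to the brms defaults.
# `form`, `dat`, and `fam` are placeholders for the actual formula, data, and family.
library(brms)

dp <- get_prior(form, data = dat, family = fam)
dp[dp$prior == "", ]  # blank rows either inherit their class default or are left flat
```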

stancode_slow_treedepth_question.txt (13.4 KB)
brmscode_slow_treedepth_question.txt (17.5 KB)

Thank you so much in advance for any suggestions at all! Oh, and if you notice anything odd, please do feel free to mention it. I’m very much on a learning path here :).


Update on the QR decomposition: I started it running in a hurry as I didn’t have much time last week. I just used the decomp = "QR" argument in brms and set it in motion. Seeing over the weekend that it was not going any quicker than usual, I looked at the Stan code and saw that it was completely unchanged; I even ran a diff, and it was exactly the same. Looking more closely at the case study you linked, I notice it says " … it will be applicable regardless of the choice of priors and for any general linear model." I presume this does not apply to nonlinear models, then? In any case, the posterior samples that will be highly correlated are those for the parameters of the nonlinear model itself, rather than for the linear predictors of the nonlinear parameters (i.e. nonlinear model = function(a, b, c), with linear predictors a ~ age + factor1, b ~ age + sex + factor1, etc.; so in the model code it’s logK1, logVnd, logBPnd, and logk4 which are highly correlated). So perhaps the QR decomposition might not be such a quick-fix solution after all.
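To make the structure concrete, the model is roughly of this shape (a schematic only, not my actual code; `nlfun`, `t`, and `id` are stand-ins):

```r
# Schematic only: a nonlinear function of a few parameters, each of which gets
# its own linear predictor. `nlfun`, `t`, and `id` are stand-ins, not my real model.
library(brms)

form <- bf(
  y ~ nlfun(logK1, logVnd, logBPnd, logk4, t),   # nonlinear model in its parameters
  logK1   ~ age + factor1 + (1 | id),            # linear predictors for the
  logVnd  ~ age + sex + factor1 + (1 | id),      # nonlinear parameters
  logBPnd ~ age + factor1 + (1 | id),
  logk4   ~ 1 + (1 | id),
  nl = TRUE
)
```

So even if decomp = "QR" did apply here, as far as I understand it would only orthogonalise the design matrices of those linear predictors, not the nonlinear parameters themselves.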

Though I did learn about checking the number of leapfrog steps per iteration, which seems to be 1023 for just about every iteration. I suppose this is limited by the treedepth and would increase with higher values. But it’s becoming increasingly clear that I really need to do more homework and take a deeper dive into the fundamentals of HMC. I’ve been meaning to find the time to do this properly, and I guess this week will be it, so I’ll update if I have any insights along that path.
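If I’m understanding correctly, the arithmetic behind this (assuming a saturated NUTS tree of depth d takes 2^d - 1 leapfrog steps, which is my reading rather than something I’ve verified in the source) would be:

```r
# If a saturated tree of depth d costs 2^d - 1 leapfrog steps per iteration,
# then 1023 steps matches the default max_treedepth of 10, and raising the cap
# to 15 allows roughly 32x more work per iteration in the worst case.
2^10 - 1  # 1023
2^15 - 1  # 32767
```

Which would at least explain why the max_treedepth = 15 run slowed to a crawl.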

Regardless, if anyone has any suggestions, or spots anything in the model code that I should check out or try, I’d be more than happy to dig in and give it a go. Thanks so much in advance!