Speed up for GAMs

Hi,

I’d greatly appreciate any tips that would make my multilevel GAMs run faster (all available cores should be in use thanks to “options(mc.cores = parallel::detectCores())”). I have >92,000 observations and 429 levels of “c1” and “c2” defining the multimembership structure. I have tried more informative priors (e.g. normal priors for the “b” parameters), but in general I am unsure what sort of priors are commonly used in multilevel GAMs (I suspect the prior on the “sds” parameters might matter most here). Below is the brms code I am using (otherwise with default settings). Is there anything I could do other than just using a sample from my original data?

m1 <- brm(y ~ s(x1) + s(x2) + (1|mm(c1,c2)))

Thanks,
Samuli


  • Operating System: Win10 Pro
  • brms Version: brms 2.10.0

I’m not sure, but the spline terms could be expensive. If you expect these to be not-rapidly-changing functions, you might be able to use approximate splines by providing a k argument to the s() terms; s(x1, k = 20) says to approximate the spline with 20 basis functions.

If you want to set your priors, you can use brms::get_prior to figure out what your options are (I think ?brms::set_prior has lots of documentation on actually setting them). If you’re curious how your model is implemented, you can generate the Stan code with make_stancode; that might reveal where the big calculations are. 92,000 data points is a lot.
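
A minimal sketch of that workflow, reusing the formula from the original post (the data name dat and the normal(0, 1) prior are only placeholders, not recommendations):

library(brms)
# list the parameter classes in this model and their default priors
get_prior(y ~ s(x1) + s(x2) + (1 | mm(c1, c2)), data = dat)
# set a prior on a chosen class, then inspect the Stan code brms would generate
priors <- set_prior("normal(0, 1)", class = "b")
make_stancode(y ~ s(x1) + s(x2) + (1 | mm(c1, c2)), data = dat, prior = priors)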

I don’t think brms can do this yet, but for a dataset like that, you would want to integrate the group-specific intercepts out of the likelihood and use the optimized form of the likelihood, which means you would have to write your own Stan program to do it.

Hey @bgoodri! You mentioned this approach before and I was wondering if there was something like a case study or some other resource that describes how to

integrate the group-specific intercepts out of the likelihood

in Stan. Is that using the integrate_1d function? Thanks!

You could use integrate_1d to check your work, but in this case there is an analytic solution. The conditional variance of the outcome is equal to the between-group variance in the intercepts plus the variance of the measurement error.
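
For concreteness, a minimal sketch of the algebra (the notation here is mine, not from the thread): with $y_{ij} = x_{ij}^\top \beta + u_j + \varepsilon_{ij}$, $u_j \sim N(0, \sigma_u^2)$ and $\varepsilon_{ij} \sim N(0, \sigma_\varepsilon^2)$, integrating out $u_j$ gives

$$\mathbf{y}_j \sim \mathrm{MVN}\!\left(X_j \beta,\ \sigma_u^2 \mathbf{1}\mathbf{1}^\top + \sigma_\varepsilon^2 I\right),$$

so the marginal variance of each observation is $\sigma_u^2 + \sigma_\varepsilon^2$, and no per-group intercepts need to be sampled.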

Ah, I think I now get what you mean. That only really works for linear models, because of the properties of the normal distribution, right? I thought there was some secret trick to do this for arbitrary GLMs.

It only works analytically for normal-normal models. For other models with group-specific intercepts only, you could use the integrate_1d function to accomplish the same thing numerically.
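
To illustrate the numerical version in plain R rather than Stan (a sketch of the idea only; the Poisson example, function name and values are made up):

marg_lik_group <- function(y, eta, sigma_u) {
  # marginal likelihood of one group's observations, with the group intercept u integrated out numerically
  integrate(function(u) {
    sapply(u, function(ui) prod(dpois(y, exp(eta + ui)))) * dnorm(u, 0, sigma_u)
  }, lower = -10 * sigma_u, upper = 10 * sigma_u)$value
}
marg_lik_group(y = c(2, 3, 1), eta = 0.5, sigma_u = 1)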

Now I get what you meant, thanks! :)

If you get this working, be sure to post an example as I’d love to understand the approach.


Well, I think the idea is essentially laid out, for example, in the GP chapter of the user guide, where you have the marginalized GP and the latent-variable GP. The former only works with a normal likelihood; the latter is used, for example, for the Poisson case (and GLMs generally). I don’t have an implementation for “simple random effects” (varying intercepts), since I rarely work with normally distributed outcomes.

That’s likely only going to make the compute time worse, as it will increase the size of the model matrix that needs to be generated and then used for modelling (the default for k is 10).

I am afraid that @ucfagls is correct here since, as @ecologics pointed out and I have since been able to confirm in my own analyses, the brms implementation of GAMs does not support the automatic knot-count selection as implemented in mgcv (which uses generalized cross-validation for this purpose). To get around this, you could either manually reduce the number of knots (e.g., s(x1, k = 5)) if you believe this is appropriate, or you could follow @ecologics’ suggestion to use mgcv to estimate the ‘correct’ number of knots (see his reply to my previous question for details/caveats).

If you make any progress on this issue I’d be most interested in what worked - I am currently fitting approximately (20 x ) 100 GAMs on about 5k data points each, largely because the larger models are computationally intractable (at least for me at my current level of expertise). Even then fitting the 100 models takes 4 days on 4 physical (8 virtual) cores…


There are some misunderstandings here.

mgcv doesn’t do knot selection or counting, whether you use GCV smoothness selection or REML/ML. If you don’t specify k, the default of -1 means the basis mgcv creates will take the defaults for whatever smoother you request. For s(x), this will be a thin plate regression spline (TPRS) with k = 10, but by default you end up with 9 (k - 1) basis functions because of the identifiability constraints that are applied. You’ll get a different number of basis functions depending on the options passed to s() or te(), such as bs for the smoother type, m (which controls things like the order of the penalty), etc.

For the default TPRS basis, Simon developed a low-rank version which preserves many of the good properties of thin plate splines without needing one basis function per unique data/covariate value. What Simon does is create the full basis, subject it to an eigen decomposition, and then select the first k (or k - 1 after the identifiability constraints) eigenvectors, such that the new basis is of the requested dimension whilst preserving as much of the information in the full-rank basis as possible.

With other bases, like the CRS (bs = 'cr'), k controls the number of knots, whose locations are placed at the boundaries of the covariate data and spread evenly in between, unless specified by the user.

What mgcv does do is penalize the coefficients of the basis functions via a wiggliness penalty. So you might start with 9 TPRS basis functions (given the defaults for s(), with a 2nd-order penalty), but the smoothness parameter controls how much of the wiggliness in the basis expansion you end up using, and the EDF of the resulting smoother will typically be lower than the maximum possible (9 by default). The smoothness parameter (which controls how strongly the penalty applies) is what is determined by GCV or REML/ML, but note that you should not use GCV by default, for various reasons, including the undersmoothing it is prone to.

In summary, you do have to choose the dimensionality of the basis expansion you want to use. mgcv will apply identifiability constraints, so you may end up with fewer than the default k = 10 basis functions. A smoothness penalty penalizes the coefficients of the basis functions, so the resulting smoother may use fewer EDFs than implied by the number of basis functions.

My point above was that setting k = 20 would result in the model using 19 basis functions for that smoother, whereas the code the OP showed would have resulted in only 9 basis functions being used.
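
To make that concrete, here is the kind of quick check you can run (the variable name and k values are just illustrative):

library(mgcv)
set.seed(1)
dat <- data.frame(x1 = runif(500))
# columns of the design matrix = basis functions actually used, after the identifiability constraint is absorbed
ncol(smoothCon(s(x1), data = dat, absorb.cons = TRUE)[[1]]$X)          # 9  (default k = 10, minus 1)
ncol(smoothCon(s(x1, k = 20), data = dat, absorb.cons = TRUE)[[1]]$X)  # 19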

I didn’t mention this earlier, as the OP was asking about speeding up fitting and I assumed they actually meant sampling. However, setting up the TPRS basis can take a significant amount of time with large n, because creating all the basis functions for the full TPRS basis and then eigen-decomposing it can be costly. In such situations, you lose little by setting, for example, bs = 'cr' in the s() call.

So far I haven’t even mentioned brms. As brms uses mgcv::smoothCon() to set up the basis for each s(), everything I say above applies to GAMs fitted using brms. Even the bit about smoothness parameters, which are proportional to the inverse of the variance parameter of the random-effect representation of the smoother, which is how brms fits these models. As such, brms fits exactly the same model as mgcv::gam(), except for the priors on parameters (implied in the case of mgcv). brms does do smoothness selection because it fits a penalized spline model - the variance parameter(s) associated with a smooth are related to the smoothness parameter(s) that mgcv selects. mgcv just happens to be ridiculously fast because of the work that Simon and others have put into developing the bespoke algorithms implemented in mgcv. GAMs are complex models and they are relatively slow to fit using MCMC.

One thing you can do to help with model fitting, if mgcv can fit the model you want, is prototype the model using mgcv and work out how big you need k to be using the tools mgcv provides. Then switch to brms and fit the models with the values of k for each smooth that you identified as being required from the mgcv fits. Or you could just think about how much wiggliness (in terms of degrees of freedom) you expect each smooth relationship to possess, and set k to that number of degrees of freedom plus a bit; you want to make sure the basis expansion either contains the true function or a close approximation to it, so if you expect a ~20 df smooth relationship, you might want to set k = 30 so that the span of the basis hopefully contains a function that either is the true function or closely approximates it.
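
A rough sketch of that prototyping workflow (the formula reuses the one from the original post; dat and the k values are placeholders, and the multimembership term is left out of the mgcv prototype):

library(mgcv)
library(brms)
# prototype in mgcv (fast) and check whether the basis dimension is adequate
proto <- gam(y ~ s(x1, k = 20) + s(x2, k = 20), data = dat, method = "REML")
gam.check(proto)  # the k-index / edf diagnostics suggest whether k needs increasing
# then fit the full model in brms with the k values you settled on
m1 <- brm(y ~ s(x1, k = 20) + s(x2, k = 20) + (1 | mm(c1, c2)), data = dat)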


And rstanarm does too, so I’m glad you provided so much information as to how mgcv works.


Wow, thanks a lot for that detailed explanation! My apologies for spreading misinformation due to my misunderstanding of the process :-/

Among other things, your explanation has made me realize that I have not fully understood the differences between gam, bam and gamm. In particular, I don’t understand what this sentence from the documentation of gamm (to which brms refers) means

It is assumed that the random effects and correlation structures are employed primarily to model residual correlation in the data and that the prime interest is in inference about the terms in the fixed effects model formula including the smooths.

in the context of this sentence (also from the gamm documentation):

Smooths are specified as in a call to gam as part of the fixed effects model formula, but the wiggly components of the smooth are treated as random effects.

If I am interested in modeling/understanding the shape of the smooth, is it then inappropriate to use gamm and therefore brm?

Regarding the OP’s question, do I understand correctly that with large datasets it may be worth trying whether a different basis (e.g., cr) in combination with appropriate choice of k leads to faster fitting with brm, but that for bs = 'tp' changing k should have little or no appreciable influence on fitting time since the full TPRS basis will have to be set up regardless of the choice of k?

In a mgcv::gamm() model, fitting is done via nlme::lme() (possibly via MASS::glmmPQL() if fitting a non-Gaussian family). In that model we might have three places where formulas are used:

  1. a model formula where you put the response and covariates,
  2. a random-effects formula where you specify pure random effects via the argument random, and
  3. a formula that describes the clustering or ordering of samples if one is modelling the correlation matrix via the argument correlation.

The sentence you quote refers to the final two formulas, and I think it stresses that the relative weight in terms of inference is on the smooths and any parametric effects in the model formula, and that any parameters estimated for the pure random-effect terms or correlation terms are of secondary importance.

By “random effects” in this sentence Simon is, I believe, trying to create a distinction between pure random effects of the sort that you might estimate in lmer or brms in a typical GLMM setting (random slopes, intercepts), and the smooth-related random effects that are introduced as a result of the mixed-model representation of a penalised spline.

The second quote is just explaining how smooths get decomposed in the mixed model representation; so there’ll be extra random effects in the fitted_model$lme representation/component that relate to the wiggly bits of each smooth as well as any pure random effects that you specify in random. Because of the way these have to be represented in lme(), these random effect terms are quite complex if you look at the summary(fitted_model$lme) for example.

In summary, what Simon means, I think, by the first quote is that when using gamm(), you should really be primarily interested in the estimated smooth functions and any other parametric effects specified in the model formula, and not so much interested in the pure random effects or correlation parameters. The latter are there to mop up any unmodelled structure to help with the i.i.d. assumptions.

gamm() is going to be quite approximate, especially if you are fitting a non-Gaussian response distribution, as it uses a penalised quasi-likelihood algorithm for GLMMs. gam() shouldn’t have these problems as it uses very different algorithms, but it isn’t as efficient at representing pure random effects (e.g. via s(xx, bs = 're')) as lme() is, so it doesn’t scale well if you have lots of levels in your grouping variables. gamm4::gamm4() estimates GAMMs via lme4::glmer() and hence has better support for non-Gaussian responses and fits models using better algorithms than gamm(), but it doesn’t support as wide a range of smoothers as gam(); also, you can only use t2() for tensor products. brms represents the model in the same way that gamm4() does, but estimates it using HMC via Stan.

It is entirely appropriate therefore to fit GAMMs via brm(). The differences between functions are in terms of implementations and fitting algorithms.

I cringed every time I wrote “pure random effects” above; there’s no difference between a random effect introduced to model subject-specific mean responses and ones introduced to model a smooth. But seeing as calling them “random effects” potentially causes confusion I used “pure” to describe the non-smooth-related random effects. Apologies if this terminology offends anyone :-)

I don’t think using bs = 'cr' will result in faster sampling in large n problems, but it will result in faster set-up of the model as the cost of creating the basis is much less with the CRS basis than with the low-rank TPRS basis that is the default.

Changing k will affect sampling times, regardless of the basis you use, as increasing k gives you more coefficients whose distributions you need to sample from. It will also affect, but to a lesser extent, basis set-up for bs = 'tp', because mgcv doesn’t find all the eigenvalues and eigenvectors of the full TPRS basis. Instead it uses an algorithm that finds only the first k eigenvalues and their eigenvectors, where that k is the basis dimension you asked for. So, yes, mgcv creates all the basis functions for the full TPRS (or some subset of 2000 unique knots when n_u > 2000, where n_u is the number of unique covariate values involved in the smooth), but it doesn’t need to do the full eigen decomposition. That said, for very large n, I’ve had models that I’ve fitted using a CRS basis with bam() that converged in less time than it took just to create the TPRS basis for the same model.

This is why I think switching to bs = 'cr' might help in the larger n setting because basis set-up will be far quicker, but once the basis is created, sampling is going to be roughly the same whether you use CRS or TPRS.
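
If you want to see the set-up cost difference for yourself, a comparison along these lines (the sizes are arbitrary) makes the point:

library(mgcv)
n <- 50000
dat <- data.frame(x = runif(n))
system.time(smoothCon(s(x, bs = "tp"), data = dat))  # low-rank TPRS: builds a large basis and eigen-decomposes it
system.time(smoothCon(s(x, bs = "cr"), data = dat))  # CRS: far cheaper to set up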


Thanks all for contributing! The mgcv documentation suggests switching the smoother to bs = “cr” for large data sets (it should produce more or less the same results as bs = “tp”). I have no strong preference for Bayesian estimation here, but I also have a multimembership structure in my data, so a GAM with brms seems to be my only option. I have been experimenting with a small sample (n = 5000) from my original data, and bs = “cr” does seem to speed things up. But I also get a warning about maximum treedepth, and increasing it (max_treedepth = 15) again slows things down considerably (compared to max_treedepth = 12).
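
For reference, a call along the lines of what I have been trying looks roughly like this (the data name is just a placeholder):

m1 <- brm(y ~ s(x1, bs = "cr") + s(x2, bs = "cr") + (1 | mm(c1, c2)),
          data = dat, control = list(max_treedepth = 12))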

Thanks again for the clarification. It was indeed unclear to me whether the documentation did or did not intend to distinguish between two kinds of random effects.

Regarding speeding up fitting, my problem indeed appears to be with the sampling (including divergent transitions, treedepth warnings, etc.) rather than slow model set-up. Given your advice, that suggests to me that I mainly need to look at my choice of k, as well as priors and model specification (including perhaps re-specifying the model in Stan directly). It’s all a work in progress…