Gaussian hierarchical factor spline model: increasing efficiency

I have fit a longitudinal Gaussian model in brms on a data set of about 400,000 individual observations, with a couple of hierarchical/penalized terms. See the formula below.

brmsformula(
  y ~ s(time, k = 5) + s(time, state, k = 5, bs = "fs") + subgroup + (1 | subgroup:state),
  decomp = "QR"
)

So we have:

  • a global average time trend
  • state-specific average time trends, shrunken towards the global trend
  • state varying intercepts
  • global subgroup effects
  • subgroup varying intercepts within states.

There are about 70 states, and within each state there are exactly 6 nested subgroups. Time is recorded in 30 equally spaced discrete steps but modeled continuously. All continuous variables are scaled to mean 0 and SD 1.

The model does what it should and converged well, but fitting it with the configuration …

  chains = 2,
  iter = 5e3,
  warmup = 1e3,
  control = list(adapt_delta = 0.825),
  prior = c(
    set_prior("exponential(1)", class = "sigma"),
    set_prior("exponential(1)", class = "sd"),
    set_prior("exponential(1)", class = "sds"),
    set_prior("normal(0, 2)", class = "Intercept"),
    set_prior("normal(0, 2)", class = "b")
  ),
  backend = "cmdstanr",
  normalize = FALSE,
  threads = threading(16)

… took about 5 days on 32 CPUs.

Are there obvious ways to speed that up, e.g., by setting up the model more cleverly?

An initial idea I had:

Exploiting sufficient statistics. Since the model uses only an identity link and Gaussian errors, is this a way to go here, even though other aspects of the model are rather complex? Are there examples of how to set up such a model in brms (I guess based only on means and SEs by stratum)?
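What I have in mind is something like the following untested sketch (dat stands for my raw data frame), which collapses the data to cell means and reweights the Gaussian likelihood. As far as I understand, this leaves the posterior of the mean structure unchanged given sigma, but drops the within-cell sums of squares that also inform sigma:

library(dplyr)
agg <- dat |>
  group_by(state, subgroup, time) |>
  summarise(y_bar = mean(y), n = n(), .groups = "drop")

# weight each cell mean by its cell size so that sigma stays on the
# individual-observation scale
brmsformula(
  y_bar | weights(n) ~ s(time, k = 5) + s(time, state, k = 5, bs = "fs") +
    subgroup + (1 | subgroup:state),
  decomp = "QR"
)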

There are a few things that you might consider for improving efficiency. First, the fs basis implicitly contains an intercept for each state, so you have some redundancy with respect to your nested varying intercepts. Simon Wood fairly recently added a new factor smooth basis to {mgcv}, the sz basis, to make this sort of model easier to set up. It removes the intercepts and is useful when you want to allow nonlinear effects to be regularized toward a shared “global” effect. But I believe @ucfagls has mentioned that he’s heard of problems with this basis in {brms}, so he may want to chime in. Perhaps it is worth fitting this model to a small subset of your data to see if it behaves well or not; a sketch is below.
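In {mgcv} syntax, the sz variant of your formula would look something like this (a sketch only, and see the next reply for whether it can work in {brms} at all):

y ~ s(time, k = 5) + s(time, state, k = 5, bs = "sz") + subgroup + (1 | subgroup:state)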

Stan code is also more efficient when using std_normal(), so if you could perform any transformations that allow you to use this as a prior, then that may give you some speedup; see the sketch below.
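Since brms pastes the prior string directly into the generated Stan code, you can request it like this (untested, and note this is not the same prior as normal(0, 2) — it only makes sense if a unit-scale prior matches your beliefs after rescaling):

  prior = c(
    set_prior("std_normal()", class = "b"),
    set_prior("std_normal()", class = "Intercept")
  )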

Also, I’m not familiar with using the QR decomposition for smooths or random effects; I thought it was typically useful for parametric effects. Perhaps it isn’t necessary here and just adds computation? But I could be very wrong about that, so feel free to ignore.


Because of the way the sz basis is implemented, you can’t use it in brms, gamm4, etc., as it is not supported by smooth2random(), which is what these packages use to convert the spline into fixed- and random-effect form. The reason for this is that sz is implemented as a tensor product, and IIRC, like standard tensor products, those particular smooths don’t have the separable penalties required for them to be represented as random effects in nlme, lme4, brms, etc.


Ah OK, I must have been mistaken in that case. Do you know, by chance, whether id is supported in {brms} smooths? If so, this could offer a sensible workaround, i.e.:

y ~ s(time, k = 5) + s(time, by = state, k = 5, id = 1) + subgroup + (1 | subgroup:state)

This should give zero-centred deviation smooths of time that don’t include intercepts for state and that force the state-varying smooths to share the same amount of regularization.


Unfortunately brms doesn’t seem to respect the id convention:

library("brms")
library("gratia")
su_eg4 <- data_sim("eg4", n = 400, dist = "normal", scale = 2, seed = 1)
m <- brm(
  y ~ fac + s(x2, by = fac, id = 1) + s(x0),
  data = su_eg4
)
summary(m)
 Family: gaussian 
  Links: mu = identity; sigma = identity 
Formula: y ~ fac + s(x2, by = fac, id = 1) + s(x0) 
   Data: su_eg4 (Number of observations: 400) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Smoothing Spline Hyperparameters:
               Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sds(sx2fac1_1)     3.35      1.72     1.17     7.57 1.00     2386     2723
sds(sx2fac2_1)     3.50      1.80     0.99     7.94 1.00     2083     1685
sds(sx2fac3_1)    15.45      4.51     8.85    26.44 1.00     1890     2454
sds(sx0_1)         1.08      1.08     0.03     4.00 1.00     1632     2166

Regression Coefficients:
           Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept      1.25      0.17     0.91     1.59 1.00     6382     2755
fac2          -2.17      0.26    -2.69    -1.66 1.00     6469     2959
fac3           2.28      0.24     1.80     2.76 1.00     6291     3218
sx2:fac1_1    -3.23      6.71   -18.07     8.77 1.00     2692     2198
sx2:fac2_1    16.77      7.13     4.44    32.27 1.00     3096     2768
sx2:fac3_1    46.05     16.14    16.00    78.81 1.00     3327     2807
sx0_1          0.80      2.87    -4.23     8.21 1.00     2568     1493

Further Distributional Parameters:
      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma     2.05      0.07     1.91     2.20 1.00     7310     2721

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
Warning message:
There were 2 divergent transitions after warmup. Increasing adapt_delta above 0.8 may help. See http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup

It still seems that three separate smoothing parameters / ranef SDs have been estimated in this quick example. Perhaps @paul.buerkner can comment on the possibility that brms might allow for a shared smoothing parameter across smooths via the id mechanism.


Interesting. I didn’t even know this option existed. Can you open an issue on GitHub so that I don’t forget to implement this at some point?


Thanks so much for your responses! Indeed, going with by in the smooth term has advantages. Among others, I can use thin plate regression splines and include another hierarchical level for the spline (above state), which is actually informative in my case (as far as I can see, this would not be possible with the "fs" basis). The computational costs are heavy, though, due to the many smoothing parameters. I will attempt a run in the cloud.

Perhaps as a start, while you are continuing model expansion and exploration, you could fit this model in {mgcv} instead. There are ways to accomplish this type of random-effect nesting in gam(), despite the fact that it doesn’t use lmer() random-effect syntax: r - How does random variable nesting in GAMs work (mgcv)? - Cross Validated. Obviously {brms} is far more flexible and ideal for including informed regularization, but waiting that long just to inspect a model variant is a big ask. A rough sketch is below.
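Assuming your data frame is called dat, something like this untested sketch might work, with bam() and discrete = TRUE for speed at this data size and an "re" smooth standing in for the nested varying intercepts:

library(mgcv)
# nested grouping factor, equivalent to (1 | subgroup:state) in lme4 syntax
dat$state_subgroup <- interaction(dat$state, dat$subgroup, drop = TRUE)
m_gam <- bam(
  y ~ s(time, k = 5) + s(time, state, k = 5, bs = "fs") +
    subgroup + s(state_subgroup, bs = "re"),
  data = dat, method = "fREML", discrete = TRUE
)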


Will do!

With that much data, don’t use thin plate splines; the eigendecomposition needed just to create the basis expansion for the low-rank TPRS smooths will be costly enough for very little gain. Use bs = "cr" except where you need something more exotic, like the "re" smooths, etc.; for your formula, that would look something like the sketch below.
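A sketch, reusing the by/id version of the formula from earlier in the thread:

y ~ s(time, k = 5, bs = "cr") + s(time, by = state, k = 5, bs = "cr", id = 1) +
  subgroup + (1 | subgroup:state)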


Thank you, I will do that. I suppose that much could also be gained by stronger priors (more penalty) on the wiggliness. For the prior class sds, I am thinking of moving a bit more towards zero, e.g., with exponential(3/2).
Probably the best way is to simulate some data and see where the prior brings us? A sketch of what I have in mind is below.
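Something like this untested sketch, i.e., sampling from the prior only via sample_prior = "only" and inspecting the implied outcomes (dat again stands for my raw data frame):

m_prior <- brm(
  y ~ s(time, k = 5, bs = "cr") + s(time, by = state, k = 5, bs = "cr", id = 1) +
    subgroup + (1 | subgroup:state),
  data = dat,
  prior = c(
    set_prior("exponential(1)", class = "sigma"),
    set_prior("exponential(1)", class = "sd"),
    # exponential(3/2), written as 1.5 to avoid Stan integer division
    set_prior("exponential(1.5)", class = "sds"),
    set_prior("normal(0, 2)", class = "Intercept"),
    set_prior("normal(0, 2)", class = "b")
  ),
  sample_prior = "only"  # ignore the likelihood, draw from the prior
)
pp_check(m_prior, ndraws = 50)  # are the implied trajectories plausible?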