Additive random effect models with large data sets

  • Operating System: CentOS
  • rstanarm Version: 2.15.3

Hi,

I am trying to fit a Stan model on a rather large data set with ~3 million observations. The model itself is relatively simple, with one smoothed variable, three binary variables, and one random effect:

Outcome ~ spline(age) + A + B + C + 1|Year

However, I’m quickly running out of memory when trying to run this model (the server has 1.1 TB of RAM, but that’s not enough). I’ve tracked the problem down to the computations done in the mgcv::jagam function.

Is there any hope of running this model, or am I better off sub-sampling, say, 200,000 data points? Looking at the data, age is highly non-linear, hence the need for some smoothing.

Thanks in advance.


I recommend sub-sampling to find out the largest n you can run now; then it will be easier to discuss the possibilities. Also run with cores=1 to save memory. It’s likely to be very slow with the current rstanarm, and it’s possible that you don’t need MCMC for such a simple model and this much data. gamm4 would be a faster option, but it uses the same mgcv package to create the basis functions for the splines, so it’s likely you would run out of memory with that, too.

With the GitHub version of rstanarm I can run 1 million observations and 100 predictors on my laptop with 32 GB of memory (about 2 s per chain, plus a minute to pass data around and compute convergence diagnostics), and if there are not too many years, your formula would correspond to the same order of magnitude of predictors. Installing from GitHub is not enough, though, as that speedier computation is not yet implemented for random effects, but you could build the design matrix outside of rstanarm. There would still be the problem of forming the regression spline basis functions, but you could form them with a subsample of the data and add them to your design matrix.
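Something along these lines with mgcv might work (just a sketch, not something I have tested; dat is a placeholder for your data and k = 30 is an arbitrary choice):

library(mgcv)
# subsample used only to set up the spline basis / knot locations
sub <- dat[sample(nrow(dat), 2e5), ]
sm  <- smoothCon(s(age, k = 30), data = sub, absorb.cons = TRUE)[[1]]
# evaluate the same basis functions for all observations
X_smooth <- PredictMat(sm, data = dat)
# combine with the parametric part of the design matrix
X_full <- cbind(model.matrix(~ A + B + C, data = dat), X_smooth)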

Whoa, that sounds like some serious optimization in the rstanarm branch! Do those 1 million observations and 100 predictors also include a smoothing term?

I actually got it running with brms (each chain is using ~6 GB of memory), but it sounds like it could be faster to run the model in rstanarm.

Not really. It’s an old Gaussian trick that has been in stan_lm for a long time. Recently @bgoodri added it to stan_glm as well (conditional on the trick being applicable), and I enabled it for stan_gamm4, but not yet for random effects, as they are handled a bit differently; it would be possible, though, to enable the trick for certain random effect models. The trick is to precompute X'X, so that during each leapfrog step only a p-dimensional multivariate normal density needs to be computed. If p is small, this is much faster than computing the linear predictor (an n*p matrix times a p vector) and n normal densities.
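Just to illustrate the idea (this is not the actual rstanarm code, only a toy version of the trick for a Gaussian likelihood):

set.seed(1)
n <- 1e5; p <- 100
X <- matrix(rnorm(n * p), n, p)
beta <- rnorm(p)
y <- drop(X %*% beta) + rnorm(n)

# precomputed once, outside the sampler
XtX <- crossprod(X)      # p x p
Xty <- crossprod(X, y)   # p x 1
yty <- sum(y^2)

# full-data log likelihood: O(n*p) per evaluation
loglik_full <- function(b, sigma)
  sum(dnorm(y, mean = drop(X %*% b), sd = sigma, log = TRUE))

# sufficient-statistics version: O(p^2) per evaluation
loglik_suff <- function(b, sigma) {
  rss <- yty - 2 * sum(b * Xty) + drop(t(b) %*% XtX %*% b)
  -0.5 * n * log(2 * pi * sigma^2) - rss / (2 * sigma^2)
}

all.equal(loglik_full(beta, 1), loglik_suff(beta, 1))  # TRUE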

It’s just simulated data with 100 predictors, but that is then 100 columns in X. Smoothing terms add a bunch of basis functions to the matrix X_smooth, and we can just combine these matrices. It’s likely that your spline is using 20–40 basis functions. The speedup would already work (in the GitHub version) for

Outcome ~ spline(age) + A + B + C

with stan_gamm4, but as I said, supporting the random effects would require a bit more coding. You could add + D, and define D to be a design matrix corresponding to 1|Year, for example along the lines sketched below.
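A possible sketch (dat is a placeholder for your data frame):

D <- model.matrix(~ 0 + factor(Year), data = dat)   # one indicator column per year
colnames(D) <- paste0("year_", sort(unique(dat$Year)))
dat2 <- data.frame(dat, D)
f <- reformulate(c("s(age)", "A", "B", "C", colnames(D)), response = "Outcome")
# fit <- stan_gamm4(f, data = dat2, ...)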

Can you share some timing results?

And to confirm, I just ran simulated data with 1 million observations,

fit <- stan_gamm4(y ~ s(x) + xn, data=fake, mean_PPD=FALSE,
                 refresh=100, seed=SEED, cores=1, control = list(max_treedepth=6))

Each chain takes 2.5 s. In addition, the time to move data around and to compute rstanarm’s default diagnostics is long, because the rstanarm spline implementation uses a lot of memory: the R session is using 35 GB and I have only 32 GB. With a better spline basis function computation the other parts would be fast, too. I guess on my laptop I could handle half a million observations for one spline.

I’m a bit slow today, and only now realized that it’s a bit more complicated due to the priors, but it should still be possible in principle to make this fast.

I don’t know how to do this. If you take the first year to be the baseline, then it is fine to include a year dummy variable for each of the remaining years in the design matrix. Or exclude the intercept. But if you include an intercept and 1 | Year, then all of the year variables get included and there is no longer a unique least squares solution to write the likelihood with. You could use a non-unique least squares solution, but I don’t think the same math goes through for the likelihood.

My guess is that we could do something with the least squares solution that has the minimum L2 norm, but I haven’t worked out the details.
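To make that concrete, a toy illustration (not worked out for the actual trick) of both the rank deficiency and the minimum-norm solution:

library(MASS)  # for ginv()
set.seed(2)
year <- factor(rep(2000:2002, each = 4))
X <- cbind(1, model.matrix(~ 0 + year))   # intercept plus a dummy for every year
y <- rnorm(nrow(X))
qr(X)$rank                   # 3, but ncol(X) is 4: the dummies sum to the intercept column
b_min_norm <- ginv(X) %*% y  # minimum-L2-norm solution among all least squares solutions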

Thanks for the reminder. How different would it be to do

Outcome ~ spline(age) + A + B + C + factor(Year)

if there is a lot of data compared to the number of years?

That is fine for stan_gamm4 (on GitHub) or for stan_lm (now) if you do the basis expansion yourself.
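Roughly along these lines, for example (a sketch only; here I use a natural cubic spline basis from the splines package, with dat as a placeholder for the data):

library(splines)
S <- ns(dat$age, df = 10)                       # n x 10 spline basis for age
colnames(S) <- paste0("age_ns", seq_len(ncol(S)))
dat2 <- data.frame(dat, S)
f <- reformulate(c(colnames(S), "A", "B", "C", "factor(Year)"), response = "Outcome")
# fit <- stan_lm(f, data = dat2, prior = R2(location = 0.3))  # stan_lm needs an R2 prior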

Maybe I should start a new thread, but I’ll add here that the OLS trick gets quite slow when the number of columns in X (or X and X_smooth) is much larger than 100. For example, with p=1000 it seems it would be faster to switch to normal_id_glm_lpdf (the current comparison is based on bernoulli_logit_glm_lpmf, but I can also do a direct test).

With or without QR? The likelihood for stan_lm takes almost no time to evaluate but the treedepth can get larger with many predictors.

Without QR. Going from p=100 to p=1000 is 100x slower, i.e., the computation cost is O(p^2) per iteration. When using the hs prior, the computation for 4 chains with 1000 post-warmup draws goes from 4 minutes to 7 hours. normal_id_glm_lpdf would also be useful for stan_glm when n < p and, e.g., the hs prior is used.

I don’t think the normal_id_glm_lpdf function in Stan Math is quite general enough, but we should write C++ versions of all the likelihoods in rstanarm using adj_jac_apply.

Not general enough to make everything fast in stan_glm, but it covers one more case.

It’s not just the likelihood, as the glm functions also include part of the model. But yes, more speed would be great, and adj_jac_apply should make it easier.

I added normal_id_glm_lpdf and tested stan_glm with p=1000, both with n=1001 (the OLS trick is used) and with n=999 (normal_id_glm_lpdf is used). normal_id_glm_lpdf is 2x faster. Same with p=500 and n=501/499. So it seems the OLS trick starts to be better when n > 2*p.

Very interesting discussion. I’ll try to come back with some timings and a comparison to rstanarm when it’s updated on the compute cluster. It seems like it will take at least a few days to run, though.

The model runs with no problems using the mgcv::gam function. I would, however, still prefer a Bayesian estimation, both to do model validation (LOO/calibration) and to have an uncertainty interval for the predictions. I think the memory problem is inherently tied to the use of mgcv::jagam.
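For the record, what I have in mind is roughly the following (assuming fit is a fitted rstanarm object; brms has equivalent functions, and newdat is a placeholder for new data):

library(rstanarm)
loo_fit <- loo(fit)        # PSIS-LOO for model validation
pp_check(fit)              # posterior predictive check / calibration
pred_int <- predictive_interval(fit, newdata = newdat, prob = 0.9)  # 90% prediction intervals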

Isn’t max_treedepth=6 cheating a bit? I.e., I would need to know a priori whether six is enough post-warmup for a particular data set. What’s really killing the computation time, with brms at least, are the warm-up steps where it goes to the maximum tree depth.

I’ve grouped year into five-year intervals, so there are only 9 groups.

The reason to include year as a random effect is actually that I would like to do predictions independently of the year. Maybe there’s a better way of achieving this, but a random effect was the first solution that came to mind.

+1

Not really. There is not much difference in this case; it was there because I was running repeated experiments and at the same time collecting information on the typical treedepths needed. Furthermore, an important point is that exceeding max_treedepth doesn’t invalidate the Markov chains, and it’s enough to check that Rhat, ess_bulk, and ess_tail are good. I think the current default in rstanarm is way too high. If ess_bulk and ess_tail are low, then it’s useful to know whether there are max_treedepth exceedances.

Yes, it seems it would be useful to start with a much lower max_treedepth in the warm-up adaptation phase and use better adaptation. While waiting for these to be implemented, you can experiment with reducing max_treedepth.
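For example with brms, something like this (just a sketch with placeholder names):

library(brms)
fit <- brm(Outcome ~ s(age) + A + B + C + (1 | Year), data = dat,
           chains = 4, cores = 1,
           control = list(max_treedepth = 6))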

I’m running it now using gamm4 and it’s actually running with no problems. It’s using about 5 GB of RAM.

I see. I tried running brms with max_treedepth=6, and that took about 12 hours. Unfortunately it did not converge, so now I’m trying again with max_treedepth=8.

I’m quite certain that even with max_treedepth = 15,
Outcome ~ spline(age) + A + B + C + factor(Year)
in the GitHub version of rstanarm would take less than a few minutes if you can get it to fit in memory.

This is based on comparing performance with 1e6 and 1e5 observations without the OLS trick, which corresponds to what brms is using.