Theoretical lowest runtime – hierarchical linear model

Hi,

I have a hierarchical linear model that takes a rather long time to fit properly. So far, I've managed to reduce the runtime to 21-26 hours (while keeping the convergence diagnostics acceptable), but recently I've hit a wall in terms of speed improvement. This is after a few months of testing different samples of the data and different code enhancements (mainly from the Stan Reference Manual and User Guide).

I’d love to hear everyone’s ideas on:

  • In theory, what is the lowest runtime I can expect to achieve for this model, using existing Stan software and assuming I have unlimited compute resources?
  • What other enhancements could I make to my code to speed up the model?
  • What sort of compute setup would make this model faster?

This is a textbook example of a simple hierarchical design. I am trying to predict a continuous, approximately normal outcome, with two sets of random effects and a few fixed effects. Each set of random effects corresponds to a natural grouping level in my data (technically, the groupings are fully nested, although for simplicity I model them as crossed effects). For each set of random effects, I model a set of intercepts and a set of slopes (partially pooled; non-centered parameterization with a Cholesky factorization).
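Because code.stan is only attached, here is a rough sketch of the model as described, written as equations. The symbols are assumptions for illustration only: $x_i$ is the predictor carrying the varying slopes (taken to be the same for both groupings), $X_i\beta$ collects the fixed effects, and $g_1[i]$, $g_2[i]$ give the level of each grouping for observation $i$.

$$
y_i \sim \mathrm{Normal}\left(a + X_i\beta + u_{g_1[i],1} + u_{g_1[i],2}\,x_i + v_{g_2[i],1} + v_{g_2[i],2}\,x_i,\; \sigma\right)
$$

$$
u_j = \operatorname{diag}(\tau_u)\,L_u\,z^{u}_j,\qquad z^{u}_j \sim \mathrm{Normal}(0, I_2),\qquad j = 1,\dots,J_1
$$

$$
v_k = \operatorname{diag}(\tau_v)\,L_v\,z^{v}_k,\qquad z^{v}_k \sim \mathrm{Normal}(0, I_2),\qquad k = 1,\dots,J_2
$$

Here $L_u$ and $L_v$ are Cholesky factors of 2 × 2 correlation matrices and $\tau_u$, $\tau_v$ are the group-level standard deviations, so that $u_j \sim \mathrm{MVNormal}\left(0,\ \operatorname{diag}(\tau_u)\,L_u L_u^\top \operatorname{diag}(\tau_u)\right)$. Sampling the $z$'s rather than $u$ and $v$ directly is what makes this the non-centered parameterization.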

Here is the Stan code, with a description of each input:

code.stan (4.1 KB)

I fit the model as follows:

model_7 <- rstan::stan(file = "code.stan",
                       data = data_final_5_stan,
                       seed = seed,
                       chains = 4,
                       iter = 4000,     # total iterations per chain
                       warmup = 3000,   # first 3000 are warmup, so 1000 kept per chain
                       control = list(adapt_delta = 0.999999999,
                                      max_treedepth = 13))
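To put the settings above in perspective, the arithmetic below (values copied from the `stan()` call) shows how many draws are kept and what `max_treedepth = 13` allows in the worst case. It relies on the fact that a NUTS tree of depth d takes up to 2^d - 1 leapfrog steps, each needing one gradient evaluation; it is only an upper bound, not a prediction of the actual cost.

```r
# Budget implied by the stan() call (values copied from it)
chains        <- 4
iter          <- 4000
warmup        <- 3000
max_treedepth <- 13

draws_per_chain <- iter - warmup             # post-warmup draws kept per chain
total_draws     <- chains * draws_per_chain  # draws available for inference

# A tree that saturates max_treedepth takes 2^max_treedepth - 1 leapfrog steps,
# each requiring one gradient evaluation of the log density.
max_leapfrog_per_iter <- 2^max_treedepth - 1

# Worst case per chain, if every iteration (warmup included) saturated the tree
worst_case_gradients_per_chain <- iter * max_leapfrog_per_iter

cat("post-warmup draws:", total_draws, "\n")
cat("max leapfrog steps per iteration:", max_leapfrog_per_iter, "\n")
cat("worst-case gradient evaluations per chain:",
    format(worst_case_gradients_per_chain, big.mark = ","), "\n")
```

In rstan, `get_sampler_params(model_7, inc_warmup = FALSE)` returns one matrix per chain whose `treedepth__` and `n_leapfrog__` columns show how close the real runs come to this bound.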

Convergence diagnostics look OK to me but could be better. The highest Rhat is 1.07 and the lowest n_eff is 53 (out of 4,000 post-warmup draws). There are no divergent iterations and the posterior predictive check looks good. However, adapt_delta needs to be very high to keep Rhat within bounds.
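A target acceptance rate this close to 1 generally forces a very small adapted step size and therefore long trajectories. A small helper like the hypothetical `summarise_sampler()` below can quantify that per chain. It assumes the input is shaped like the output of `rstan::get_sampler_params(fit, inc_warmup = FALSE)`: a list with one matrix per chain and columns such as `stepsize__`, `n_leapfrog__`, `treedepth__` and `divergent__`.

```r
# Per-chain NUTS diagnostics from a list of sampler-parameter matrices
# (one matrix per chain, one row per post-warmup iteration).
summarise_sampler <- function(sampler_params, max_treedepth) {
  per_chain <- lapply(sampler_params, function(sp) {
    c(mean_stepsize  = mean(sp[, "stepsize__"]),
      mean_leapfrog  = mean(sp[, "n_leapfrog__"]),
      # share of iterations whose tree hit the depth limit
      frac_max_depth = mean(sp[, "treedepth__"] >= max_treedepth),
      n_divergent    = sum(sp[, "divergent__"]))
  })
  do.call(rbind, per_chain)  # one row per chain
}
```

Usage on the fit above would be `summarise_sampler(rstan::get_sampler_params(model_7, inc_warmup = FALSE), max_treedepth = 13)`; a large `frac_max_depth` together with a tiny `mean_stepsize` would be consistent with the long runtimes.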

I did some preprocessing before sending the data to Stan, so the magnitude of my predictors should not be an issue. The fixed effects are centered, scaled, and transformed to a normal shape. The target variable was log-transformed to reduce skew and has a mean of 10.6 and a standard deviation of 0.95.

There is a lot of computation in the generated quantities block (two sets of predictions plus log_lik), but my understanding is that this block does not add much overhead compared to the model block.

I suspect what is consuming the most time here is the shape of my data, particularly the random effects. My first grouping variable has more than 2,600 levels of widely varying size. Unfortunately, the group sizes are very right-skewed (most groups have fewer than 25 observations, but the largest has 650). Since a different amount of pooling happens in each group, I wonder whether this is driving up the runtime.
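A quick way to quantify that skew is to tabulate the group index vector. The function name `group_size_summary()` is hypothetical, and the 25-observation cut-off simply mirrors the figure quoted above.

```r
# Summarise how unbalanced a grouping factor is. `group` has one entry per
# observation, e.g. the index vector passed to Stan for a grouping level.
group_size_summary <- function(group) {
  sizes <- as.vector(table(group))        # observations per group
  c(n_groups    = length(sizes),
    median      = median(sizes),
    p90         = unname(quantile(sizes, 0.9)),
    max         = max(sizes),
    share_lt_25 = mean(sizes < 25))       # fraction of groups with < 25 obs
}
```

For example, `group_size_summary(data_final_5_stan$grouping_1)`, where `grouping_1` is a placeholder for whatever the index is called in the Stan data list.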

My machine is a Windows server with 256 GB of RAM and a 32-core E5 CPU. RStan was set up with the recommended options:

options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)

The content of Makevars.win is:

CXX14FLAGS=-O3 -march=native -mtune=native
CXX11FLAGS=-O3 -march=native -mtune=native

The output of devtools::session_info("rstan") is:

- Session info -------------------------------------------------------------------------------
 setting  value                          
 version  R version 3.5.2 (2018-12-20)   
 os       Windows Server 2008 R2 x64 SP 1
 system   x86_64, mingw32                
 ui       RStudio                        
 language (EN)                           
 collate  English_United States.1252     
 ctype    English_United States.1252     
 tz       America/Denver                 
 date     2019-07-09                     
 
- Packages -----------------------------------------------------------------------------------
 package      * version   date       lib source        
 assertthat     0.2.0     2017-04-11 [1] CRAN (R 3.5.1)
 backports      1.1.3     2018-12-14 [1] CRAN (R 3.5.2)
 BH             1.69.0-1  2019-01-07 [1] CRAN (R 3.5.2)
 callr          3.1.1     2018-12-21 [1] CRAN (R 3.5.2)
 cli            1.0.1     2018-09-25 [1] CRAN (R 3.5.2)
 colorspace     1.4-0     2019-01-13 [1] CRAN (R 3.5.2)
 crayon         1.3.4     2017-09-16 [1] CRAN (R 3.5.1)
 desc           1.2.0     2018-05-01 [1] CRAN (R 3.5.2)
 digest         0.6.18    2018-10-10 [1] CRAN (R 3.5.2)
 fansi          0.4.0     2018-10-05 [1] CRAN (R 3.5.2)
 ggplot2      * 3.1.0     2018-10-25 [1] CRAN (R 3.5.2)
 glue           1.3.0     2018-07-17 [1] CRAN (R 3.5.1)
 gridExtra      2.3       2017-09-09 [1] CRAN (R 3.5.1)
 gtable         0.2.0     2016-02-26 [1] CRAN (R 3.5.1)
 inline         0.3.15    2018-05-18 [1] CRAN (R 3.5.1)
 labeling       0.3       2014-08-23 [1] CRAN (R 3.5.0)
 lattice      * 0.20-38   2018-11-04 [2] CRAN (R 3.5.2)
 lazyeval       0.2.1     2017-10-29 [1] CRAN (R 3.5.1)
 loo          * 2.0.0     2018-04-11 [1] CRAN (R 3.5.1)
 magrittr       1.5       2014-11-22 [1] CRAN (R 3.5.1)
 MASS           7.3-51.1  2018-11-01 [2] CRAN (R 3.5.2)
 Matrix         1.2-15    2018-11-01 [2] CRAN (R 3.5.2)
 matrixStats    0.54.0    2018-07-23 [1] CRAN (R 3.5.1)
 mgcv           1.8-26    2018-11-21 [2] CRAN (R 3.5.2)
 munsell        0.5.0     2018-06-12 [1] CRAN (R 3.5.1)
 nlme           3.1-137   2018-04-07 [2] CRAN (R 3.5.2)
 pillar         1.3.1     2018-12-15 [1] CRAN (R 3.5.2)
 pkgbuild       1.0.2     2018-10-16 [1] CRAN (R 3.5.2)
 pkgconfig      2.0.2     2018-08-16 [1] CRAN (R 3.5.1)
 plyr           1.8.4     2016-06-08 [1] CRAN (R 3.5.1)
 prettyunits    1.0.2     2015-07-13 [1] CRAN (R 3.5.1)
 processx       3.2.1     2018-12-05 [1] CRAN (R 3.5.2)
 ps             1.3.0     2018-12-21 [1] CRAN (R 3.5.2)
 R6             2.3.0     2018-10-04 [1] CRAN (R 3.5.2)
 RColorBrewer   1.1-2     2014-12-07 [1] CRAN (R 3.5.0)
 Rcpp           1.0.0     2018-11-07 [1] CRAN (R 3.5.2)
 RcppEigen      0.3.3.5.0 2018-11-24 [1] CRAN (R 3.5.2)
 reshape2       1.4.3     2017-12-11 [1] CRAN (R 3.5.1)
 rlang          0.3.1     2019-01-08 [1] CRAN (R 3.5.2)
 rprojroot      1.3-2     2018-01-03 [1] CRAN (R 3.5.1)
 rstan        * 2.18.2    2018-11-07 [1] CRAN (R 3.5.2)
 scales       * 1.0.0     2018-08-09 [1] CRAN (R 3.5.1)
 StanHeaders  * 2.18.1    2019-01-28 [1] CRAN (R 3.5.2)
 stringi        1.2.4     2018-07-20 [1] CRAN (R 3.5.1)
 stringr      * 1.3.1     2018-05-10 [1] CRAN (R 3.5.1)
 tibble       * 2.0.1     2019-01-12 [1] CRAN (R 3.5.2)
 utf8           1.1.4     2018-05-24 [1] CRAN (R 3.5.1)
 viridisLite    0.3.0     2018-02-01 [1] CRAN (R 3.5.1)
 withr          2.1.2     2018-03-15 [1] CRAN (R 3.5.1)
 
[1] C:/Users/dcesar/Documents/R/win-library/3.5
[2] C:/Program Files/R/R-3.5.2/library

Any suggestions for improvement welcome. Thanks in advance.


Hi,
some brief observations:

An adapt_delta of 0.999999999 with max_treedepth = 13 is suspicious: if you need values this high to achieve convergence, something is probably amiss with your model. An Rhat of 1.07 and an n_eff of 53 are not great either; I suspect there is some problem with the model and/or the data.

I don't immediately see a big problem with your model, except for the slightly weird prior a ~ normal(11, 1), but that should match your data, so I'm not sure it is an issue.

It is also possible that the model is not well identified by the data you have (although it may be identified in principle). I've written about some ways to recognize when this is happening, but for high-dimensional problems this is tricky.

What you may want to try is fitting the same dataset with the brms package: your model seems well within its capabilities, and brms is generally well tested and optimized. If the problem persists with brms, it is likely at least partly a mismatch between model and data; if brms works well, it would be worth investigating where your code differs from brms.
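A minimal sketch of such a brms translation, assuming hypothetical column names: `y` (log outcome), `x1`, `x2` (fixed effects), `x` (the predictor with varying slopes) and `grouping_1`, `grouping_2` (group IDs). The `(1 + x | g)` terms request correlated varying intercepts and slopes, and writing them as two separate terms treats the groupings as crossed, as in the original model (this assumes the grouping_1 IDs are unique across grouping_2). The brms calls only run if the package is installed.

```r
# Crossed, correlated varying intercepts and slopes for both groupings
f <- y ~ x1 + x2 + x + (1 + x | grouping_1) + (1 + x | grouping_2)

if (requireNamespace("brms", quietly = TRUE)) {
  # Tiny fake data so brms can generate the Stan program; nothing is fitted.
  set.seed(1)
  toy <- data.frame(y = rnorm(40, 10.6, 0.95), x1 = rnorm(40), x2 = rnorm(40),
                    x = rnorm(40),
                    grouping_1 = rep(1:8, each = 5),   # nested within grouping_2
                    grouping_2 = rep(1:4, each = 10))
  # Print the generated Stan code to compare it line by line with code.stan
  print(brms::make_stancode(f, data = toy))
  # Full fit on the real data (not run here; `my_data` is a placeholder):
  # fit_brms <- brms::brm(f, data = my_data, chains = 4, cores = 4, seed = seed)
}
```

The code brms generates for such terms uses a non-centered parameterization with a Cholesky-factored correlation matrix, so it should be structurally comparable to code.stan.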

It is generally true that generated quantities add little overhead compared to the model block, but they may consume additional memory, and if memory runs out, performance is hit badly. Could you check that you are not running out of memory while fitting?
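A back-of-the-envelope estimate of how much memory the stored generated quantities need: two prediction vectors plus log_lik means three values per observation per draw, stored as 8-byte doubles. The number of observations is not given in the thread, so the 100,000 in the example is hypothetical, and the estimate ignores any copies R makes afterwards.

```r
# Approximate size (GiB) of stored generated-quantity draws, assuming one
# double (8 bytes) per observation, per quantity, per saved draw.
gq_memory_gib <- function(n_obs, n_quantities, n_draws_per_chain, n_chains) {
  n_obs * n_quantities * n_draws_per_chain * n_chains * 8 / 1024^3
}

# Hypothetical N = 100,000; 3 quantities; 1000 post-warmup draws x 4 chains
post_warmup_only <- gq_memory_gib(1e5, 3, 1000, 4)
# If warmup draws are also kept in the fit object, all 4000 iterations count
with_warmup <- gq_memory_gib(1e5, 3, 4000, 4)
cat(sprintf("%.1f GiB post-warmup, %.1f GiB including warmup\n",
            post_warmup_only, with_warmup))
```

Plugging in the real number of observations gives a figure to compare against the memory actually free on the shared server.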

I think the skewed group sizes are unlikely to be the problem. In my experience, the main factors driving runtime are (very roughly ordered):

  1. Issues with the model or mismatch between model and data
  2. Complexity (vaguely defined) of the model (e.g. models involving solving ODEs or algebraic equations have high complexity)
  3. The number of parameters (for you mostly the number of levels in all grouping factors combined)
  4. The amount of data
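To apply factor 3 to this model, one can count its parameters. For one non-centered, Cholesky-factored varying-effects term with J groups and K correlated coefficients, there are J × K standard-normal z's, K scales and K(K - 1)/2 free elements in the Cholesky factor of the correlation matrix. The 2,600 levels come from the post; K = 2 (one intercept, one slope), 100 grouping-2 levels and 5 fixed effects are placeholders.

```r
# Unconstrained parameters of one NCP varying-effects term
n_params_varying <- function(J, K) {
  J * K +            # standard-normal z's
    K +              # group-level standard deviations
    K * (K - 1) / 2  # free elements of the correlation Cholesky factor
}

n_fixed <- 5  # placeholder number of fixed-effect coefficients
total <- n_params_varying(2600, 2) +  # grouping 1 (2600 levels, from the post)
  n_params_varying(100, 2) +          # grouping 2 (placeholder level count)
  n_fixed + 1 + 1                     # + intercept a and residual sigma
cat("approximate parameter count:", total, "\n")
```

Substituting the real level count for grouping 2 and the real number of fixed effects gives the figure to set against the benchmarks in the next paragraph.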

In a small test I did elsewhere, sampling 100 000 i.i.d. normal variables (probably the easiest model you can think of) took 40 minutes on my PC. A simple hierarchical model with 25 000 parameters took me about 1.5 hours, and the same model with 100 000 parameters didn't complete in 16 hours, so those are some rough empirical lower bounds on runtime.

Hope that helps and best of luck!

Hi Martin,

Thanks for the detailed feedback. The runtime benchmarking is helpful for me to set expectations. It sounds like for a model of my size a day or so is reasonable.

I will work on a brms translation of this model to see if I can spot a difference. I'll also do a sanity check to make sure running out of memory is not the issue. This is a shared machine, so it's possible that memory maxed out during some of my recent runs, although I doubt that's the whole story. Either way, I'd still like to understand why I can't seem to get better convergence on some of the parameters.

I'm using a ~ normal(11, 1) because it matches the empirical distribution of my outcome variable (rounded to a nice whole number). An alternative would be to center and scale the outcome so I can remove a altogether. I don't see this too often in practice, but I'm wondering whether it would help convergence here, since a has consistently been hard to fit.
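If the outcome is standardized, the fitted parameters can be mapped back to the original (log) scale exactly, because the model is linear: with y = center + scale × z, the intercept becomes center + scale × a_z, while every coefficient, varying effect, residual SD and group-level SD is simply multiplied by scale. The sketch below assumes the predictors are left unchanged; the function names are hypothetical, and whether this improves convergence is an empirical question.

```r
# Standardise the log outcome so the intercept can get a normal(0, 1) prior
# (or be dropped), keeping the constants needed to undo the transformation.
standardise <- function(y) {
  list(z = (y - mean(y)) / sd(y), center = mean(y), scale = sd(y))
}

# Map estimates (or individual posterior draws) from the z scale back to y:
#   a_y     = center + scale * a_z
#   beta_y  = scale * beta_z    (fixed and varying coefficients)
#   sigma_y = scale * sigma_z   (residual and group-level SDs)
back_transform <- function(a_z, beta_z, sigma_z, center, scale) {
  list(a = center + scale * a_z, beta = scale * beta_z, sigma = scale * sigma_z)
}
```

Applied draw by draw, this recovers the posterior on the original scale; correlations between varying effects are unaffected by the rescaling.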

My intercept a usually has one of the highest Rhats and the worst-looking traceplot of all my parameters. When I look at the traceplot for a, the chains have no trouble finding the mean of 11 once warmup starts, but afterwards they drift slowly around the posterior (traceplot not reproduced here).

I get this behavior with some of the group-level parameters too (beta_grouping_2, z_N_grouping_2, sigma_grouping_2[1]). This makes me think there may be a non-identifiability issue, as you suggest, that affects several groups in my data. The behavior looks a lot like what you describe in the section of your blog post "A hopelessly non-identified GP model." If this is the case for my model, how would I diagnose it? In the blog you suggest putting very tight priors on the 'true' parameters to figure out how much prior information is needed to make them stable. How could I set up a test like this for my model (what would the 'true' values be)?

Thanks.

That's easy (to say, not necessarily to do :-)): you simulate your data exactly from your model (i.e., draw the hyperparameters from their priors, draw the parameters, then draw the observed data). Then you know the true values, because you simulated them.
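A minimal simulator along those lines, for a single grouping factor with correlated varying intercepts and slopes plus one fixed effect. Apart from a ~ normal(11, 1), all priors, sizes and names are hypothetical placeholders to be replaced with the ones in code.stan; the group-level effects are generated exactly as in the non-centered parameterization.

```r
# Simulate one dataset (and its true parameter values) from a hierarchical
# linear model with correlated varying intercepts and slopes.
simulate_hlm <- function(N, J, seed = 1) {
  set.seed(seed)
  # 1. Hyperparameters and population-level parameters from (assumed) priors
  a     <- rnorm(1, 11, 1)        # intercept, a ~ normal(11, 1)
  beta  <- rnorm(1, 0, 1)         # fixed-effect coefficient
  sigma <- abs(rnorm(1, 0, 1))    # residual SD, half-normal(0, 1)
  tau   <- abs(rnorm(2, 0, 1))    # group-level SDs, half-normal(0, 1)
  rho   <- runif(1, -1, 1)        # correlation; uniform = LKJ(1) for 2 x 2
  L     <- t(chol(matrix(c(1, rho, rho, 1), 2)))  # lower Cholesky factor
  # 2. Group-level effects, non-centered: u_j = diag(tau) L z_j
  z <- matrix(rnorm(2 * J), nrow = 2)   # 2 x J standard normals
  u <- diag(tau) %*% L %*% z            # row 1: intercepts, row 2: slopes
  # 3. Observed data
  g  <- sample.int(J, N, replace = TRUE)  # group index per observation
  x  <- rnorm(N)                          # predictor with varying slope
  x1 <- rnorm(N)                          # fixed-effect predictor
  mu <- a + beta * x1 + u[1, g] + u[2, g] * x
  y  <- rnorm(N, mu, sigma)
  list(data  = list(N = N, J = J, g = g, x = x, x1 = x1, y = y),
       truth = list(a = a, beta = beta, sigma = sigma,
                    tau = tau, rho = rho, u = u))
}
```

For example, `sim <- simulate_hlm(N = 200, J = 10)` gives a small dataset whose `sim$data` can be passed to Stan (after matching the element names in code.stan's data block) and whose `sim$truth` can be compared against the posterior.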

A hopeless non-identifiability like that GP example is IMHO unlikely in a linear model. Provided you don't have a bug in your code (which is always a possibility; indexing errors are a pain :-( ), I would still mostly suspect some linear relationship in the posterior. How do bivariate (a.k.a. pairs) plots look for the intercept versus the other problematic variables?
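With thousands of parameters it is impractical to inspect every pairs plot, so a numeric screen for strongly correlated pairs can help decide which ones to look at. The helper `strong_pairs()` and the 0.9 threshold are hypothetical choices; the input is a matrix of posterior draws with one named column per parameter.

```r
# Flag parameter pairs whose posterior draws are strongly correlated,
# as a quick numeric complement to pairs plots.
strong_pairs <- function(draws, threshold = 0.9) {
  r <- cor(draws)
  # upper.tri() keeps each pair once and skips the diagonal
  idx <- which(abs(r) > threshold & upper.tri(r), arr.ind = TRUE)
  data.frame(par1 = colnames(draws)[idx[, 1]],
             par2 = colnames(draws)[idx[, 2]],
             cor  = r[idx],
             stringsAsFactors = FALSE)
}
```

With rstan, the draws could come from `as.matrix(model_7, pars = c("a", "beta_grouping_2", "sigma_grouping_2"))`, and any flagged pair can then be plotted with `pairs(model_7, pars = ...)`.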

Also, once you have a simulator, you can simulate a tiny dataset, plot everything vs. everything and diagnose quickly; a simulator is usually well worth the effort.

Also, maybe the data are at fault: maybe within some of the groups X doesn't really vary, which makes the group's slope indistinguishable from its varying intercept? Many options in such a high dimension :-(
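One way to check this is to compute the within-group standard deviation of the predictor that carries the varying slope: in groups where it is (nearly) constant, the slope term is collinear with the intercept. The function name is hypothetical, and singleton groups are treated as having zero spread.

```r
# Within-group standard deviation of predictor x for each level of `group`.
within_group_sd <- function(x, group) {
  s <- tapply(x, group, sd)                 # NA for single-observation groups
  out <- setNames(as.numeric(s), names(s))  # plain named numeric vector
  out[is.na(out)] <- 0                      # singletons: no within-group spread
  out
}
```

For instance, `s <- within_group_sd(x, grouping_1)` followed by `mean(s < 1e-8)` gives the share of groups (in this hypothetical naming) where the slope is essentially unidentified by the data.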