Theoretical lowest runtime – hierarchical linear model

Hi,

I have a hierarchical linear model which takes a rather long time to fit properly. So far, I’ve managed to reduce the runtime to 21-26 hours (while maintaining convergence indicators), but recently I’ve hit a wall in terms of speed improvement. This is after a few months of testing different samples of the data and different code enhancements (mainly from the Stan Reference Manual and User Guide).

I’d love to hear everyone’s ideas on:

  • In theory, what is the lowest runtime I can expect to achieve for this model, using existing Stan software and assuming I have unlimited compute resources?
  • What other enhancements could I make to my code to speed up the model?
  • What sort of compute setup would make this model faster?

This is a textbook example of simple hierarchical design. I am trying to predict a normal continuous outcome, with two sets of random effects and a few fixed effects. Each set of random effects corresponds to a natural grouping level in my data (technically, the groupings are fully nested, although for simplicity, I model them as crossed effects). For each of the random effects, I am modeling a set of intercepts and a set of slopes (partially-pooled, NCP with a Cholesky factorization).

Here is the Stan code, with a description of each input:

code.stan (4.1 KB)

I fit the model as follows:

model_7 <- rstan::stan(file = 'code.stan',
                       data = data_final_5_stan,
                       seed = seed,
                       chains = 4,
                       iter = 4000,
                       warmup = 3000,
                       control = list(adapt_delta = 0.999999999,
                                      max_treedepth = 13))

Convergence diagnostics look OK to me but could be better. Highest Rhat is 1.07 and the lowest n_eff is 53. No divergent iterations and the posterior predictive check looks good. However, adapt_delta needs to be very high in order to keep Rhat within bounds.

I did some preprocessing before sending the data to Stan, so the magnitude of my predictors should not be an issue. The fixed effects are centered, scaled, and transformed to a normal shape. The target variable was logged to reduce skew and has a mean of 10.6 and a standard deviation of 0.95.

There is a lot of computation in the generated quantities block (two sets of predictions plus log_lik), but my understanding is that this block does not add much overhead compared to the model block.

I suspect what is consuming the most time here is the shape of my data, particularly the random effects. My first grouping variable has 2600+ unique levels of all different sizes. Unfortunately, the sizes of the groups are very right-skewed (majority <25 but up to 650 observations in one group). Since a different amount of pooling is happening in each group, I wonder if this is driving up the runtime.

My machine is a Windows server with 256 GB of RAM and a 32-core CPU E5. Rstan was set up with the recommended options:

options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)

The content of Makevars.win is:

CXX14FLAGS=-O3 -march=native -mtune=native
CXX11FLAGS=-O3 -march=native -mtune=native

The output of devtools::session_info("rstan") is:

- Session info -------------------------------------------------------------------------------
 setting  value                          
 version  R version 3.5.2 (2018-12-20)   
 os       Windows Server 2008 R2 x64 SP 1
 system   x86_64, mingw32                
 ui       RStudio                        
 language (EN)                           
 collate  English_United States.1252     
 ctype    English_United States.1252     
 tz       America/Denver                 
 date     2019-07-09                     
 
- Packages -----------------------------------------------------------------------------------
 package      * version   date       lib source        
 assertthat     0.2.0     2017-04-11 [1] CRAN (R 3.5.1)
 backports      1.1.3     2018-12-14 [1] CRAN (R 3.5.2)
 BH             1.69.0-1  2019-01-07 [1] CRAN (R 3.5.2)
 callr          3.1.1     2018-12-21 [1] CRAN (R 3.5.2)
 cli            1.0.1     2018-09-25 [1] CRAN (R 3.5.2)
 colorspace     1.4-0     2019-01-13 [1] CRAN (R 3.5.2)
 crayon         1.3.4     2017-09-16 [1] CRAN (R 3.5.1)
 desc           1.2.0     2018-05-01 [1] CRAN (R 3.5.2)
 digest         0.6.18    2018-10-10 [1] CRAN (R 3.5.2)
 fansi          0.4.0     2018-10-05 [1] CRAN (R 3.5.2)
 ggplot2      * 3.1.0     2018-10-25 [1] CRAN (R 3.5.2)
 glue           1.3.0     2018-07-17 [1] CRAN (R 3.5.1)
 gridExtra      2.3       2017-09-09 [1] CRAN (R 3.5.1)
 gtable         0.2.0     2016-02-26 [1] CRAN (R 3.5.1)
 inline         0.3.15    2018-05-18 [1] CRAN (R 3.5.1)
 labeling       0.3       2014-08-23 [1] CRAN (R 3.5.0)
 lattice      * 0.20-38   2018-11-04 [2] CRAN (R 3.5.2)
 lazyeval       0.2.1     2017-10-29 [1] CRAN (R 3.5.1)
 loo          * 2.0.0     2018-04-11 [1] CRAN (R 3.5.1)
 magrittr       1.5       2014-11-22 [1] CRAN (R 3.5.1)
 MASS           7.3-51.1  2018-11-01 [2] CRAN (R 3.5.2)
 Matrix         1.2-15    2018-11-01 [2] CRAN (R 3.5.2)
 matrixStats    0.54.0    2018-07-23 [1] CRAN (R 3.5.1)
 mgcv           1.8-26    2018-11-21 [2] CRAN (R 3.5.2)
 munsell        0.5.0     2018-06-12 [1] CRAN (R 3.5.1)
 nlme           3.1-137   2018-04-07 [2] CRAN (R 3.5.2)
 pillar         1.3.1     2018-12-15 [1] CRAN (R 3.5.2)
 pkgbuild       1.0.2     2018-10-16 [1] CRAN (R 3.5.2)
 pkgconfig      2.0.2     2018-08-16 [1] CRAN (R 3.5.1)
 plyr           1.8.4     2016-06-08 [1] CRAN (R 3.5.1)
 prettyunits    1.0.2     2015-07-13 [1] CRAN (R 3.5.1)
 processx       3.2.1     2018-12-05 [1] CRAN (R 3.5.2)
 ps             1.3.0     2018-12-21 [1] CRAN (R 3.5.2)
 R6             2.3.0     2018-10-04 [1] CRAN (R 3.5.2)
 RColorBrewer   1.1-2     2014-12-07 [1] CRAN (R 3.5.0)
 Rcpp           1.0.0     2018-11-07 [1] CRAN (R 3.5.2)
 RcppEigen      0.3.3.5.0 2018-11-24 [1] CRAN (R 3.5.2)
 reshape2       1.4.3     2017-12-11 [1] CRAN (R 3.5.1)
 rlang          0.3.1     2019-01-08 [1] CRAN (R 3.5.2)
 rprojroot      1.3-2     2018-01-03 [1] CRAN (R 3.5.1)
 rstan        * 2.18.2    2018-11-07 [1] CRAN (R 3.5.2)
 scales       * 1.0.0     2018-08-09 [1] CRAN (R 3.5.1)
 StanHeaders  * 2.18.1    2019-01-28 [1] CRAN (R 3.5.2)
 stringi        1.2.4     2018-07-20 [1] CRAN (R 3.5.1)
 stringr      * 1.3.1     2018-05-10 [1] CRAN (R 3.5.1)
 tibble       * 2.0.1     2019-01-12 [1] CRAN (R 3.5.2)
 utf8           1.1.4     2018-05-24 [1] CRAN (R 3.5.1)
 viridisLite    0.3.0     2018-02-01 [1] CRAN (R 3.5.1)
 withr          2.1.2     2018-03-15 [1] CRAN (R 3.5.1)
 
[1] C:/Users/dcesar/Documents/R/win-library/3.5
[2] C:/Program Files/R/R-3.5.2/library

Any suggestions for improvement welcome. Thanks in advance.


Hi,
some brief observations:

Needing an adapt_delta or max_treedepth this high to achieve convergence is suspicious: something is probably amiss with your model. An Rhat of 1.07 and an n_eff of 53 are not great either, so I suspect there is some problem with the model and/or data.

I don’t immediately see a big problem with your model, except for the slightly weird a ~ normal(11, 1); but that should match your data, so not sure it is an issue.

It is also possible that the model is not well identified by the data you have (although it may be identified in principle). I’ve written about some ways to recognize this is happening, but for high-dimensional problems this is tricky.

What you may want to try is to fit the same dataset with the brms package - your model seems well within its capabilities and brms is generally well tested and optimized. If the problem persists with brms, it is likely at least partly a mismatch between model and data; if brms works well, it would be worth investigating where your code differs from brms.

While it is generally true that the generated quantities block adds little overhead, it might consume additional memory, and if memory runs out, performance is hit badly. Could you check that you are not running out of memory while fitting?

I think the skewed group sizes are unlikely to be the problem. The main factors driving runtime are in my experience (very roughly ordered):

  1. Issues with the model or mismatch between model and data
  2. Complexity (vaguely defined) of the model (e.g. models involving solving ODEs or algebraic equations have high complexity)
  3. The number of parameters (for you mostly the number of levels in all grouping factors combined)
  4. The amount of data

In a small test I did elsewhere, sampling from 100 000 i.i.d. normals, which is probably the easiest model you can think of, took 40 minutes on my PC. A simple hierarchical model with 25 000 parameters took me about 1.5 hours, and the same model with 100 000 parameters didn’t complete in 16 hours, so those are some empirical lower bounds on runtime.

Hope that helps and best of luck!


Hi Martin,

Thanks for the detailed feedback. The runtime benchmarking is helpful for me to set expectations. It sounds like for a model of my size a day or so is reasonable.

I will work on a brms translation of this model to see if I can spot a difference. I’ll also do a sanity-check to make sure running out of memory is not the issue. This is a shared machine so it’s possible that it maxed out on some of my recent models; although I doubt that that’s the whole story. However, I’d still like to understand why I can’t seem to get better convergence on some of the parameters.

I’m using a ~ normal(11,1); because it matches the empirical distribution of my outcome variable (rounded to a nice whole number). Another alternative here would be to center and scale the outcome, so I can remove a altogether. I don’t see this too often in practice, but I’m wondering if it would help convergence in this situation, since a is always hard to fit.
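
Centering and scaling the outcome, as floated above, might look like the following sketch in base R. The vector log_y is simulated here as a stand-in for the real logged outcome, and to_log_scale is a hypothetical helper for mapping predictions back to the log scale.

```r
# Stand-in for the logged outcome (mean about 10.6, sd about 0.95 in the thread)
set.seed(1)
log_y <- rnorm(500, mean = 10.6, sd = 0.95)

# Standardize, keeping the constants needed to undo the transformation
y_mean <- mean(log_y)
y_sd   <- sd(log_y)
y_std  <- (log_y - y_mean) / y_sd

# Hypothetical helper: map predictions on the standardized scale back to log scale
to_log_scale <- function(pred_std) pred_std * y_sd + y_mean
```

With a standardized outcome, a prior centered at zero on the intercept would play the role of normal(11, 1). Whether this helps convergence for this model is untested.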

My intercept a usually has one of the highest Rhats and the worst-looking traceplot out of all my parameters. When I look at the traceplot for a, it doesn’t have any trouble finding the mean of 11 once warmup is started, but then the traces slowly move around the posterior:

I get this behavior with some of the group-level parameters, too (beta_grouping_2, z_N_grouping_2, sigma_grouping_2[1]). This makes me think maybe there is a non-identifiability issue, like you suggest, that is affecting several groups in my data. The behavior here looks a lot like what you describe in the section of your blog post, “A hopelessly non-identified GP model.” If this is the case for my model, how would I diagnose it? In the blog you suggest putting very tight priors on the ‘true’ parameters to try to figure out what amount of prior info is needed to make them stable. How could I set up a test like this for my model (what would the ‘true’ values be)?

Thanks.

That’s easy (to say - not necessarily to do :-)), you simulate your data exactly matching your model (i.e. draw hyperparams from priors, draw params, draw observed data). Then you know the true values, because you simulated them.
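
A minimal sketch of such a simulator in base R, for a reduced version of the model: a single grouping, one binary predictor, and correlated varying intercepts and slopes. Since code.stan is not reproduced in the thread, every prior below except a ~ normal(11, 1) is an assumption, as are all names (simulate_hlm, n_groups, and so on).

```r
# Simulate one dataset from a simplified hierarchical linear model.
# All priors except a ~ normal(11, 1) are assumptions for illustration.
simulate_hlm <- function(n_groups = 20, n_obs = 200, seed = 1) {
  set.seed(seed)

  # 1. Hyperparameters drawn from their (assumed) priors
  a       <- rnorm(1, 11, 1)     # global intercept, prior as in the thread
  b       <- rnorm(1, 0, 1)      # global slope (assumed prior)
  sigma_g <- rexp(2, 1)          # sds of varying intercepts and slopes (assumed)
  sigma_y <- rexp(1, 1)          # residual sd (assumed)
  rho     <- runif(1, -1, 1)     # intercept-slope correlation (assumed)

  # 2. Group-level parameters, non-centered with a Cholesky factor
  Omega <- matrix(c(1, rho, rho, 1), 2, 2)
  L     <- t(chol(Omega))        # lower-triangular, L %*% t(L) equals Omega
  z     <- matrix(rnorm(2 * n_groups), nrow = 2)
  u     <- t(diag(sigma_g) %*% L %*% z)   # n_groups x 2: intercept, slope offsets

  # 3. Observed data
  group <- sample.int(n_groups, n_obs, replace = TRUE)
  x     <- rbinom(n_obs, 1, 0.5)
  mu    <- a + u[group, 1] + (b + u[group, 2]) * x
  y     <- rnorm(n_obs, mu, sigma_y)

  list(data  = data.frame(y = y, x = x, group = group),
       truth = list(a = a, b = b, sigma_g = sigma_g,
                    sigma_y = sigma_y, rho = rho, u = u))
}

sim <- simulate_hlm()
str(sim$truth[c("a", "b", "sigma_y")])
```

Fitting the model to sim$data and comparing the posterior against sim$truth shows which parameters are recovered and which are not.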

This is IMHO unlikely to happen in a linear model. Provided you don’t have a bug in your code (which is always a possibility, indexing errors are a pain :-( ), I would still mostly suspect some linear relationship in the posterior - how do bivariate (a.k.a. pairs) plots look for the intercept versus the other problematic variables?

Also, when you have a simulator, you can simulate some tiny dataset and then you can easily plot everything vs. everything and diagnose quickly - a simulator is usually well worth the effort.

Also maybe the data are at fault - maybe for some of the groupings X doesn’t really vary making it indistinguishable from the varying intercepts? Many options in such a high dimension :-(

Hi Martin,

Thanks for all the suggestions. I’ve done some brainstorming based on what you said and I’ve come up with a few more ideas for why this model has a particularly hard time converging. At this point, I’m less focused on reducing runtime, and more focused on improving the convergence diagnostics (although the two go hand-in-hand).

I replicated my model in brms, as you suggested, and ran it on a representative sample of my data. The convergence diagnostics came up the same as my Stan model. I’m beginning to believe that there is no error in my Stan code - there’s either a problem with my dataset, or a problem with the questions I’m asking about my dataset.

My data have several characteristics which are less-than-ideal, but I’m not sure what is the severity of each issue. To try to understand which characteristics are most problematic, I’m running tests on different samples of my data. However, it would be better if I could understand the theory behind why some datasets don’t work well in my tests, rather than working through many different samples one-by-one. Perhaps you can spot something:

  • I now have 3029 distinct levels of Grouping 1 and 402 levels of Grouping 2. This works out to 10 observations per group, on average. Maybe the sheer number of levels (and the low N per P) is affecting my model fit.
  • There is a lot of variation in the sizes of the groups. The smallest groups have only one observation, while the largest has 1166. I agree that the variation itself is not an issue, but the very small groups might be.
  • Grouping 1 is fully nested within Grouping 2. In less than 1% of my dataset, Grouping 2 only includes one level of Grouping 1, so they are essentially redundant.
  • 68% of the Grouping 1 levels and 34% of the Grouping 2 levels have no variance along the first predictor. This works out to 17% of my data, for which it doesn’t make sense to estimate a slope (in theory). Because this is such a large portion of my dataset, I’m reluctant to just throw these observations away…but…the variable slopes between Grouping 1 and the first predictor form my hypothesis, so I can’t remove them.
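
Group-level summaries like the ones in the list above can be computed along these lines. This is a sketch in base R on toy values; the data frame d and its columns group and x are hypothetical names.

```r
# Toy data for illustration (hypothetical column names)
d <- data.frame(
  group = c(1, 1, 1, 1, 2, 3, 3, 4, 4),
  x     = c(0, 1, 0, 1, 0, 0, 0, 1, 1)
)

# Number of observations and number of distinct x values per group
n_per_group <- tapply(d$x, d$group, length)
n_distinct  <- tapply(d$x, d$group, function(v) length(unique(v)))

# Groups with no variation in x (this includes single-observation groups)
no_variance <- n_distinct == 1

# Share of groups, and share of observations, with no variation in x
share_groups <- mean(no_variance)
share_obs    <- sum(n_per_group[no_variance]) / nrow(d)
```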

Here are some of my observations so far:

  • You were right about the differing group sizes - that’s not the issue here. I ran a few tests on different samples of my data (sample with large groups, sample with small groups, and mixed sample), and I found that convergence doesn’t seem to suffer when the dataset contains different-sized groups. However, the average group size does seem to matter for convergence (the fewer observations per group, the harder it is for my model to fit).

  • The higher-level grouping (Grouping 2) plays some role in model stability. I tried removing these effects, and it made convergence worse. I don’t know why the Grouping 1 effects aren’t sufficient. Technically, I don’t need the Grouping 2 effects for my hypothesis.

What is my best bet?

  • I could remove the Grouping 2 slopes (but keep the intercepts). Hopefully this would help address the issue of having no/little variance along the first predictor in many groups. At least, I won’t force the model to come up with two separate slopes for these groups.
  • I could be more flexible about which variable effects I include in the model. For example, I could add/substitute a less granular grouping for Grouping 2. I could also break out my existing groupings into more granular components. For example, Grouping 1 can be re-expressed as A + B + A:B. I don’t have much intuition about which designs would be easiest to fit, without testing each one.
  • I could remove groups where there is no variance in first predictor. You mentioned that in these groups, the intercept would be indistinguishable from the slope. In a frequentist model, these effects would be impossible to estimate, but in a Bayesian hierarchical model, wouldn’t we get some reasonable estimate via partial pooling? How many of these problematic groups do you think my model can reasonably handle?
  • Lastly, I could remove some of the small groups. About 4% of my dataset consists of a single-observation group, so it wouldn’t be a huge sacrifice to remove them. My logic for including them is similar to the above - even though all their intercepts and slopes can’t be estimated in theory, couldn’t I get an estimate via partial pooling? Is there a minimum group size that is necessary in order to get reliable estimates in a Bayesian model?

Thanks for your critical thinking and observations…they help a lot.

Diana

Did you have a chance to check the memory usage? If you’re not hitting your max, you should consider using more chains and cores; as I understand it, you’re using a 32-core machine but only running 4 chains. I see you have a relatively high iteration count. Is that because you wanted more post-warmup samples, or did you find you truly needed 3000 iterations for warmup? If the former, you can drop the iteration count when you bump up the chain count. If the latter, you can separately specify the number of warmup iterations to maintain the long warmup on each chain, then only grab a few samples on each of many chains to speed things up a bit. You might also check out within-chain parallelism, in which case you’d want to go back to fewer chains but allocate a bunch of cores per chain. I think that is what’s covered by the Map-Reduce section of the manual.
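
The trade-off between chain count and iterations per chain can be made concrete with a little arithmetic. This is a sketch assuming warmup stays at 3000 per chain and the goal is the same 4000 post-warmup draws in total as the original call; plan_chains is a hypothetical helper, not part of rstan.

```r
# Per-chain settings needed to reach a target number of post-warmup draws
plan_chains <- function(chains, warmup = 3000, target_draws = 4000) {
  sampling <- ceiling(target_draws / chains)  # post-warmup draws per chain
  list(chains      = chains,
       warmup      = warmup,
       iter        = warmup + sampling,       # rstan's iter includes warmup
       total_draws = chains * sampling)
}

plan_chains(4)$iter    # 4000
plan_chains(32)$iter   # 3125
```

Under these assumptions, going from 4 to 32 chains shortens each chain from 4000 to 3125 iterations, roughly 22%, because warmup dominates the cost of every chain.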

Oh, and I see you use the Cauchy distribution a bunch; check out the “Reparameterizing the Cauchy” section of the Reparameterization section of the manual to speed that up a bit.


First - sorry to hear you are still fighting with the model. Hope we can get this sorted out soon :-)

Second, could you please post the brms formula you are working with and a sample of your data? That would probably help us to help you faster.

This is weird - usually removing parameters of the model improves convergence. I suspect this indicates that there is another part of the model that is problematic.

Depending on the exact way your model is set up, this might be an issue. A quick, wild guess: Assuming your predictor is called x, one way to include this data in your model but avoid estimating the slope would be to make a new predictor x_2 which equals x for groups that have some variance in x and 0 for the groups without variance. You would then estimate the slope for (x_2 | Group). Does that make sense?
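
A sketch of how x_2 could be built in base R. The data frame d and its column names are hypothetical, and the values are toy data.

```r
# Toy data for illustration (hypothetical column names)
d <- data.frame(
  group = c(1, 1, 1, 1, 2, 3, 3, 4, 4),
  x     = c(0, 1, 0, 1, 0, 0, 0, 1, 1)
)

# TRUE for rows whose group has more than one distinct value of x
varies <- ave(d$x, d$group, FUN = function(v) length(unique(v))) > 1

# x_2 equals x where the group has variation in x, and 0 elsewhere
d$x_2 <- ifelse(varies, d$x, 0)
```

In a brms formula the varying slope would then presumably enter as something like (1 + x_2 | group), with the exact form depending on the rest of the model.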

Based on the observations above, I still think it is likely that the problem is actually in another part of the model. So maybe none of the above? I’ll however wait for your brms formula and sample data before attempting further judgement.

You should also try to use pairs or ShinyStan to inspect bivariate plots, especially between slopes and intercepts (global or within the same group) and look for suspicious patterns (as in the non-identifiability post).
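
As a numeric complement to the plots, strongly correlated parameter pairs can be flagged directly from the posterior draws. This is a sketch in base R using simulated stand-in draws; the parameter names, the 0.9 threshold, and the draws themselves are assumptions for illustration.

```r
# Stand-in for posterior draws. With a stanfit object, a draws matrix can
# typically be obtained with as.matrix(fit) (not run here).
set.seed(1)
a     <- rnorm(1000, 11, 0.5)
b     <- -0.9 * (a - 11) + rnorm(1000, 0, 0.1)  # deliberately correlated with a
c_par <- rnorm(1000)                            # unrelated parameter
draws <- cbind(a = a, b = b, c = c_par)

# Flag parameter pairs whose draws are strongly linearly correlated
cors <- cor(draws)
flag <- which(abs(cors) > 0.9 & upper.tri(cors), arr.ind = TRUE)
data.frame(par1 = colnames(draws)[flag[, "row"]],
           par2 = colnames(draws)[flag[, "col"]],
           cor  = cors[flag])

# pairs(draws)  # base R scatterplot matrix of the same draws
```

This only catches linear relationships; funnels and other curved shapes still need the plots.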

Best of luck with modelling!

Unfortunately the manual does not mention that the trick only applies to the full unconstrained Cauchy. Here the Cauchys are actually half-Cauchys constrained to positive values. The implicit transform that Stan uses to enforce the lower=0 bound already thins the tail in much the same way the tangent transform would. And a heavy tail for the variance doesn’t seem reasonable anyway: it implies a prior expectation that your data might well be all noise and no signal. I’d say use an exponential(1) prior instead.
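
To get a feel for how much heavier the half-Cauchy tail is, the upper quantiles of the two priors can be compared in base R. The Cauchy scale of 1 is an assumption; the thread does not state the scale used in the model.

```r
# Upper quantiles of two candidate priors for a group-level standard deviation.
# The Cauchy scale of 1 is assumed for illustration.
p <- c(0.5, 0.9, 0.99)

# Half-Cauchy(0, 1): P(X <= q) = p  is equivalent to  q = qcauchy((1 + p) / 2)
half_cauchy_q <- qcauchy((1 + p) / 2, location = 0, scale = 1)
exponential_q <- qexp(p, rate = 1)

round(rbind(half_cauchy = half_cauchy_q, exponential = exponential_q), 2)
```

Under these assumptions the 99th percentile is about 64 for the half-Cauchy and about 4.6 for the exponential, on a log-scale outcome whose standard deviation is 0.95.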


Ah! News to me, thanks for the heads up! Saves me some typing and will make my models more readable.

I tend to agree. I also find the peaked-at-zero priors for variances unreasonable and prefer a Weibull, where you can express that noise can be small, but never zero.

My machine is 512GB, so it’s pretty hard for me to max it out on this problem (especially when I do a test on a smaller sample of my dataset - these are only N<3000 records). Also, the fact that I get similar convergence diagnostics on the small samples and the full dataset makes me think that it’s not an issue of maxing out memory.

What would be the advantage to using more chains, other than getting more posterior samples? Would running more chains help to reduce the Rhat (because Rhat measures the agreement between chains)?

As for the warmup iterations, the reason I’ve specified 3000 is to give the chains a better chance at convergence (at the cost of runtime). For prediction/inference, having 1000 post-warmup iterations is more than enough precision for me.

Thanks for the resources on MPI and Map-Reduce. At this point, I think figuring out how to improve convergence is more of a priority for this model, so I’ll save these tools for down the road. I still find it strange that it’s so tough to get Rhat < 1.01 and ESS > 400 for a simple linear model like this one.

To provide some context, I recently ran a similar model design on the same dataset (or very similar data - I get monthly updates from the same source). This model only included a single set of variable intercepts (Grouping 1), and removed both sets of variable slopes. I also included the same set of fixed effects as the one above.

This model ran in 32 minutes (as compared to the 20+ hours above) and got a max Rhat of 1.009 (compared to 1.07 above). I don’t see why adding variable slopes would result in such a stark difference, but it sure is strange.


I’m a little bit out of my depth with the Reparameterization section, so I’ll trust your judgement here.

I’m not sure I follow…how would that prevent the model from estimating a slope coefficient for the groups which have no variance in x? Even if all the x’s in that group are zero, wouldn’t the model still need to come up with some value for the coefficient (because I’ve specified a parameter for the set of slopes, indexed from 1 to 3029)?

Actually, predictor x is binary, so this is already the case for some groups. For example:

Group     Predictor X   Note
Group 1   0             Group with variance in x
Group 1   1
Group 1   0
Group 1   1
Group 2   0             Group with a single observation
Group 3   0             Group with all zeros
Group 3   0
Group 4   1             Group with all ones
Group 4   1

I will work on a data sample for you and some brms code. Since my dataset is confidential, I have to anonymize and relabel some things first…may take a while.

Thanks

See the Rhat and effective sample size experiments for the Cauchy and half-Cauchy in https://avehtari.github.io/rhat_ess/rhat_ess.html#appendix_c:_more_experiments_with_the_cauchy_distribution
and you can see that the difference is big.


Yes, the parameter will still be in the model, but my idea is that it would not be connected to data in any way, i.e. the posterior will exactly equal the prior for the parameter. This should be easy for the sampler and should not introduce any convergence issues. It would obviously be better to not include those parameters at all, but this trick lets you do this with brms or your original Stan code. Caveat: I’ve never actually used this trick, it just occurred to me that it could possibly help you.
