Have you been using some of the latest features of Stan?

Note: This same question was posted on Twitter; I’m just making a Discourse thread for those who prefer this channel.

For a few different reasons, mostly related to grants and papers, we are interested in whether you, the users of Stan and brms, are using some of the performance-related features that were recently added to Stan and brms. If you are, could you tell us a bit more about how and where you are using them (the model type, what the gain was, etc.)?

We are primarily interested in whether you are using:

  • within-chain parallelization with the reduce_sum function in Stan or brms (threads argument to brm)
  • profiling in Stan
  • OpenCL backend in Stan or brms (opencl argument to brm)

If you are using any other recently added feature not listed here that has greatly helped your work or research, feel free to share that too.

Profiling helped me work out that models are faster when structured to use columns_dot_product rather than rows_dot_product (n.b., follow-up explorations of the odd variability in the error bars are over on Slack).

@jades, can you comment on how much reduce_sum sped up sampling of your hierarchical psychophysical model?

Edit: oh! And back in the spring, profiling helped me work out a faster structure for using *_dot_product in hierarchical models.

Edit 2: and profiling helped me work out that Gaussian likelihoods (n.b. only with lots of IID observations) can be sped up via sufficient stats.

Edit 3: and I have a TODO to use profiling to isolate the benefits of my reduced-redundant-computation trick for hierarchical models; I may get to that as I work up the notebook for my upcoming StanConnect talk.
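
As a rough illustration of the sufficient-statistics idea from Edit 2, here is a Python sketch (not Stan code), assuming IID normal observations that share one mean μ and one standard deviation σ. Because Σ(y_i − μ)² = Σ(y_i − ȳ)² + n(ȳ − μ)², the log likelihood needs only n, ȳ, and Σ(y_i − ȳ)², which can be computed once up front (in Stan that would typically live in transformed data). The function names are hypothetical.

```python
import math
import random

def normal_loglik_direct(ys, mu, sigma):
    """Sum of per-observation normal log densities: O(n) work per evaluation."""
    return sum(
        -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * ((y - mu) / sigma) ** 2
        for y in ys
    )

def sufficient_stats(ys):
    """Computed once: n, sample mean, and sum of squared deviations from the mean."""
    n = len(ys)
    ybar = sum(ys) / n
    ss = sum((y - ybar) ** 2 for y in ys)
    return n, ybar, ss

def normal_loglik_suffstats(n, ybar, ss, mu, sigma):
    """Same log likelihood from (n, ybar, ss) only: O(1) work per evaluation."""
    return (
        -0.5 * n * math.log(2 * math.pi)
        - n * math.log(sigma)
        - (ss + n * (ybar - mu) ** 2) / (2 * sigma ** 2)
    )

random.seed(1)
ys = [random.gauss(2.0, 1.5) for _ in range(10_000)]
stats = sufficient_stats(ys)
direct = normal_loglik_direct(ys, mu=1.8, sigma=1.4)
fast = normal_loglik_suffstats(*stats, mu=1.8, sigma=1.4)
print(direct, fast)
```

The two values agree up to floating-point rounding; the gain comes from the per-evaluation cost no longer growing with the number of observations.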

I’ve done some benchmarking with the three-level logistic regression model for repeated cross-sectional survey data (340,000 observations across 65 countries and 256 country-surveys) that I’m working with right now. Running six chains in parallel on the CPU (Ryzen 9 5900X), the full model specification takes about four days to fit.

GPU-based computation via OpenCL with an Nvidia RTX 3080 Ti cuts that down to just under 25 hours. That is better but still a long time, especially considering that I ended up having to write a lengthy function to override the brms-generated model code with a manually vectorized version, fit it via cmdstanr, and then save the result to a brmsfit object so I could use the package’s various convenience functions (because I’m lazy).

Within-chain threading via brms using cmdstanr as a backend seems to be the winning ticket right now in terms of performance and ease of use. That approach fits the full model in just under seven hours and doesn’t require any of the additional post-processing work that OpenCL via cmdstanr does.
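
For a rough sense of scale, here is the arithmetic on those timings as a small Python sketch. It takes "about four days" as 96 hours and the "just under" figures at face value, so the resulting speedups are approximate.

```python
# Wall-clock times reported above, in hours (approximate).
baseline_hours = 4 * 24  # six chains in parallel on the CPU, no within-chain parallelism
timings = {
    "OpenCL (RTX 3080 Ti)": 25.0,
    "brms threading via cmdstanr": 7.0,
}

for label, hours in timings.items():
    speedup = baseline_hours / hours
    print(f"{label}: about {speedup:.1f}x faster than the {baseline_hours} h baseline")
```

That works out to roughly 3.8x for OpenCL and roughly 13.7x for within-chain threading on this model.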

One issue that I think could use some more discussion is tuning the parallelization in reduce_sum. More and more users are exploiting cloud resources, and the common intuition based on all threads being on one machine, or at least on one machine with few contested resources, often doesn’t translate particularly well. While there won’t be a general solution by any means, I think it would help to set better expectations by discussing the circumstances in which reduce_sum might actually offer practical speedups.

I agree; I think more examples are needed for reduce_sum, both on how to parallelize and on how to tune. We should also temper expectations about using 16 or 32 cores on gradients that take only a few milliseconds.

A while back, @Adam_Haber and I ran some benchmarks on a few different types of models and EC2 instances, with some surprising results. It may finally be time to dust them off and post them somewhere.

Some general discussion of the possible overheads that obstruct linear scaling, along with recommended steps to diagnose when particular overheads are dominating, would also be welcome. That would include, for example, verifying that the reduce_sum calculation actually matches the serial calculation (this one bites a lot of people) and comparing empirical performance against the number or size of the shards to see whether there’s any positive, or even negative, scaling.
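
To make the "verify reduce_sum against the serial calculation" step concrete, here is a Python sketch (a stand-in, not Stan itself) that mimics how reduce_sum calls a partial-sum function with 1-based, inclusive start and end indices into the full sliced array. The model, data, and function names are all hypothetical. The buggy variant shows one common mistake: indexing a non-sliced argument by position within the slice instead of by the global start..end range, which silently changes the total once the data is split into more than one chunk.

```python
import math

def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def log_lik_term(y, x, beta):
    """Hypothetical per-observation Bernoulli-logit log likelihood, y in {0, 1}."""
    eta = beta * x
    return y * eta - softplus(eta)

def partial_sum(y_slice, start, end, x, beta):
    """Mimics a Stan partial-sum function: start and end are 1-based, inclusive
    indices into the FULL data, so non-sliced arguments like x use start..end."""
    assert len(y_slice) == end - start + 1
    return sum(log_lik_term(y, x[start - 1 + i], beta) for i, y in enumerate(y_slice))

def buggy_partial_sum(y_slice, start, end, x, beta):
    """Common mistake: indexes x by position within the slice, so every chunk
    after the first pairs its y values with the wrong x values."""
    return sum(log_lik_term(y, x[i], beta) for i, y in enumerate(y_slice))

def reduce_sum_sketch(f, ys, grainsize, *args):
    """Serial stand-in for reduce_sum: split ys into chunks of `grainsize`
    and add up the partial sums."""
    total = 0.0
    for lo in range(0, len(ys), grainsize):
        chunk = ys[lo:lo + grainsize]
        total += f(chunk, lo + 1, lo + len(chunk), *args)
    return total

ys = [0, 1, 1, 0, 1, 0, 1, 1, 0, 1]
xs = [-1.2, 0.3, 2.1, -0.7, 1.5, -2.0, 0.9, 1.1, -0.4, 0.6]
beta = 0.8

# The serial version: a single "chunk" covering all observations.
serial = partial_sum(ys, 1, len(ys), xs, beta)

# The chunked result should match the serial one up to floating-point noise
# for every grainsize.
for g in (1, 3, 4, 10):
    chunked = reduce_sum_sketch(partial_sum, ys, g, xs, beta)
    assert math.isclose(serial, chunked, rel_tol=1e-12), (g, serial, chunked)

buggy = reduce_sum_sketch(buggy_partial_sum, ys, 3, xs, beta)
print(f"serial = {serial:.6f}, buggy with grainsize 3 = {buggy:.6f}")
```

Running the same check at several grainsizes is cheap, and a mismatch beyond floating-point noise points to an indexing bug rather than a performance problem.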

Recently I had a client trying to parallelize over just a few cores, and they still have trouble achieving any speedup despite an accommodating model structure. They had nowhere to turn when their reduce_sum code was significantly slower than the serial code.

My coworker and I have had success using reduce_sum to speed up models that require multiple ODE solves.

The benchmarking has also been useful.

The brms vignette on reduce_sum should be a good starting point for evaluating reduce_sum scaling performance. I don’t see any other way than just trying it out and measuring empirically. This approach makes a lot of sense for brms models, where it’s just a matter of flipping a switch in the brm call.

If you are worried about reproducibility, use the static version (reduce_sum_static in Stan, or static = TRUE in the threading() call in brms).

https://cran.microsoft.com/web/packages/brms/vignettes/brms_threading.html
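
On why the static version matters for reproducibility: floating-point addition is not associative, so the total depends on how terms are grouped into partial sums. As I understand it, the dynamic reduce_sum lets the scheduler choose the partition, which can vary between runs, while reduce_sum_static fixes it for a given grainsize. The Python sketch below is a stand-in for the reduction (chunked_sum is a hypothetical helper) and uses deliberately extreme values so the effect is visible.

```python
import math

def chunked_sum(xs, grainsize):
    """Sum each chunk of `grainsize` terms first, then add up the partial sums,
    the way a parallel reduction combines per-thread results."""
    partials = [sum(xs[i:i + grainsize]) for i in range(0, len(xs), grainsize)]
    return sum(partials)

# Deliberately extreme values so the rounding is visible; with real
# log-likelihood terms the differences are usually only in the last few bits.
xs = [1e100, 1.0, -1e100, 1.0]

for g in (1, 2, 3, 4):
    print(f"grainsize {g}: {chunked_sum(xs, g)}")
print(f"correctly rounded sum (math.fsum): {math.fsum(xs)}")
```

Here grainsizes 1, 3, and 4 give 1.0, grainsize 2 gives 0.0, and the correctly rounded sum is 2.0: the same inputs, different answers, purely from the grouping.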

At my company we have relied heavily on within-chain parallelization for scaling to larger datasets. I’ve been impressed with how easy this was to implement, especially in brms where it “just works” even with custom likelihoods.

Profiling has proved useful for complex models. I’ve used it to choose between different parameterizations/formulations, and it provides the information that I’ve needed to prioritize which parts of a model to focus on for computational efficiency. It’s convenient that the data can be coerced into a data frame and munged without too much trouble.

I’ve experimented with OpenCL but found that, for the models I’m using, it did not increase performance. I was able to get it to work with R + Docker, but that was a bit of a process.

We rely heavily on reduce_sum. Before reduce_sum, our standard models took about 4-5 days, which was too long for our modeling production cycle, so we kept using custom Gibbs samplers instead (1-2 days). Now, with reduce_sum running 8 threads per chain, we see about a 5x speedup over the non-reduce_sum version. Our sampling is now slightly faster than our Gibbs samplers, and Stan converges much better. We get slightly faster estimation with 12 threads, and for larger problems with 16-20, but rarely much beyond that. And of course, too many threads eventually makes things slower. We have also tuned the grainsize to run faster than Stan’s automatic default. For our standard problems, each thread processes about 4 chunks per iteration (8 threads = 32 data chunks).
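
The "8 threads = 32 data chunks" rule of thumb translates into a grainsize with simple arithmetic. A minimal Python sketch, assuming a hypothetical data set of 100,000 observations and a target of about 4 chunks per thread; grainsize_for is a hypothetical helper. Note that, as I understand it, the dynamic reduce_sum treats grainsize as a suggestion rather than an exact chunk size, so the split at run time may differ.

```python
import math

def grainsize_for(n_obs, threads, chunks_per_thread):
    """Grainsize that splits n_obs observations into about
    threads * chunks_per_thread slices."""
    target_chunks = threads * chunks_per_thread
    return math.ceil(n_obs / target_chunks)

n_obs = 100_000  # hypothetical number of observations
g = grainsize_for(n_obs, threads=8, chunks_per_thread=4)
n_chunks = math.ceil(n_obs / g)
print(f"grainsize = {g}, chunks = {n_chunks}")
```

This prints a grainsize of 3125 and 32 chunks; treat it as a starting point for the empirical tuning described above.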

All of this is based on desktops with an AMD Threadripper 3970X (under WSL) and AWS compute-optimized instances (c5.4xlarge) running cmdstanr. I can’t thank the Stan developers enough for reduce_sum. It took Stan from a nice idea that wasn’t production-ready to a practical tool.
https://discourse.mc-stan.org/t/cmdstan-2-23-release-candidate-is-available/14301

Great to hear!

If you are using AWS EC2 instances, I would suggest trying out the c6g instances, which use ARM CPUs. We have seen the ARM instances run Stan 2-3x faster than similarly priced x86 instances. As of 2.26.1, we release arm64 CmdStan tarballs, and cmdstanr should install the correct tarball as well, so it should be simple to try out as far as Stan goes (obviously I don’t know what other infrastructure you have around it, so it might not be that simple).

Might be worth exploring.

Very cool. We are currently running CmdStan 2.27.0 on AWS. Does that mean it should run with the ARM processor as is? Or is there something special in the compile statement? Currently we just have cpp_options = list(stan_threads = TRUE).

Sadly, on SageMaker Studio (with JupyterLab), where I usually work, there are no c6g instances (I just checked). But we also run Stan on EC2 with a limited R Shiny interface. I’ll see if I can test it there.

Within-chain parallelization in both Stan and brms has been super useful to me. It has greatly helped model development and made possible a paper I’m working on with a larger dataset. Thanks!

On the client whose reduce_sum code was significantly slower than the serial code: I have a simple recommendation for this. We could compute the time spent in reduce_sum versus the time spent elsewhere and use Amdahl’s law to get cheap and easy estimates of the achievable speedup. I previously did something similar to estimate the theoretical maximum speedup for some StanCon models.
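
A minimal Python sketch of that Amdahl’s law estimate: if a fraction p of the run time is spent inside reduce_sum (for example, measured with profile blocks in a serial version of the model) and that part parallelizes perfectly over n threads, the speedup is at most 1 / ((1 − p) + p / n). The 90% figure below is hypothetical, and the bound ignores reduce_sum’s own overhead, so real speedups will be lower.

```python
def amdahl_speedup(parallel_fraction, threads):
    """Amdahl's law: best-case speedup when only `parallel_fraction` of the
    run time parallelizes (perfectly) over `threads` threads."""
    if not 0.0 <= parallel_fraction <= 1.0:
        raise ValueError("parallel_fraction must be in [0, 1]")
    serial_fraction = 1.0 - parallel_fraction
    return 1.0 / (serial_fraction + parallel_fraction / threads)

# Hypothetical profile: 90% of the time inside reduce_sum, 10% elsewhere.
p = 0.90
for n in (2, 4, 8, 16, 32):
    print(f"{n:>2} threads: at most {amdahl_speedup(p, n):.2f}x")
print(f"ceiling as threads grow without bound: {1 / (1 - p):.1f}x")
```

With a 10% serial share, 8 threads can give at most about 4.7x and no thread count can exceed 10x, which is a quick way to check whether a disappointing result is a tuning problem or a hard limit.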

Just seeing this. Lemme run a sample again and post back.