[Case-study preview] Speeding up Stan by reducing redundant computation

Hey folks, I mentioned during the general meeting last week that I was working on a case study walking through an idea I had for speeding up Stan models whose contrast matrices have a lot of redundancy. I’m still working on the prose part, but I thought I’d post a link to the code in case anyone wants to take a peek. I’ve included comments in both the R and Stan files to try to describe what’s going on.

The data scenario I use as an example is actually the same as in the “Multivariate Priors for Hierarchical Models” section of the Stan User’s Guide, with the tweak that I use terminology from my neck of the woods (Psychology). So the data come from multiple human “subjects”, each of whom is randomly assigned to one of two groups, and then all subjects are observed multiple times in each of two conditions.

When there are 100 participants and 100 observations per condition per participant, the “_slower” model (pretty much just the final optimized model from the manual) takes about 25 minutes to draw 4e3 samples, while the “_faster” model takes about 4 minutes.

The geometry of the parameter space should be identical between these models; all I’m doing is identifying computations that are redundant, doing them once, and indexing into the results rather than computing them fresh every time. So there shouldn’t ever be any cost to using this trick, and I think it should be possible to automate it in packages like rstanarm and brms.
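
To give a flavour of the trick without the hierarchical bits: the core move is to compute each distinct linear predictor once and then index into it. This is just a stripped-down sketch with made-up names, not the code in the repo:

```stan
data {
  int<lower=1> N;                          // total observations
  int<lower=1> R;                          // distinct design-matrix rows, R << N
  int<lower=1> K;                          // predictors
  matrix[R, K] X_unique;                   // only the R distinct rows
  array[N] int<lower=1, upper=R> row_id;   // which distinct row each observation uses
  vector[N] y;
}
parameters {
  vector[K] beta;
  real<lower=0> sigma;
}
model {
  // compute each distinct linear predictor exactly once ...
  vector[R] mu_unique = X_unique * beta;
  // ... then index into it instead of recomputing it per observation
  y ~ normal(mu_unique[row_id], sigma);
  beta ~ normal(0, 1);
  sigma ~ normal(0, 1);
}
```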

21 Likes

Just had a thought while procrastinating on the prose: design matrices tend to have high regularity within columns (e.g. an intercept column, or effect columns whose entries are simply -1 times each other), so I wonder whether that regularity could be used to further reduce computation, or whether dot_product() is already so optimized that breaking it apart and doing the sums explicitly would just slow things down.
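
To make that concrete, here’s the kind of thing I mean by breaking it apart, for a design matrix whose columns are an intercept and a single +1/-1 contrast. Toy sketch with made-up names; whether it actually beats the vectorized multiply is exactly the question:

```stan
data {
  int<lower=1> N;
  array[N] int<lower=1, upper=2> cond;  // 1 where the contrast column is +1, 2 where it is -1
  vector[N] y;
}
parameters {
  vector[2] beta;                        // beta[1] = intercept, beta[2] = condition effect
  real<lower=0> sigma;
}
model {
  // the design matrix [1, +1; 1, -1] only ever produces two linear predictors,
  // so form them with one add and one subtract instead of N dot products
  vector[2] mu;
  mu[1] = beta[1] + beta[2];
  mu[2] = beta[1] - beta[2];
  y ~ normal(mu[cond], sigma);
  beta ~ normal(0, 1);
  sigma ~ normal(0, 1);
}
```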

@mike-lawrence The trick here is that there are multiple observations with sort of the same right hand side, right?

This sounds similar to the thing where you replace a bunch of bernoullis that all have the same right hand side with a binomial.
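
For anyone who hasn’t seen that trick, the collapsed version looks roughly like this (sketch with made-up names, success counts assumed precomputed):

```stan
data {
  int<lower=1> G;            // groups whose observations share a right hand side
  array[G] int<lower=0> n;   // number of Bernoulli observations in each group
  array[G] int<lower=0> k;   // how many of them were 1s
}
parameters {
  vector[G] eta;             // one linear predictor per group
}
model {
  // one binomial term per group replaces n[g] identical bernoulli_logit(eta[g])
  // terms; the posterior is the same up to an additive constant
  k ~ binomial_logit(n, eta);
  eta ~ normal(0, 1);
}
```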

I did some thinking about a similar thing at some point. I decided that for intercept-only models (or at least for the intercept part of a hierarchical regression) we could get somewhere, but I never wrestled it into a form where I could write a clean function we might put in Math (Group orderings in regression).

@andrewgelman is often forwarding e-mails from people who have large hierarchical regressions and they’re happy with the regression but it’s just too slow cause they have an absurd number of data points.

Since you already started a case study, do you think you could write and show examples of these other things too? Like expand the contents to “Last Mile Optimizations for Hierarchical Regressions” maybe? The idea being, after everything else is done, these are things you can do to make your model faster. They don’t work for everyone, so it’s hard to make them general, but there’d be a checklist of things people can look through.

In that theme, I think it would make sense to add a section on using the GLMs (https://mc-stan.org/docs/2_22/functions-reference/bernoulli-logit-glm.html) and reduce_sum. So that’s 5 things, which is a pretty sizeable case study.
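
Something like this sketch is what I’d want that section to build up to (untested, made-up data names): the GLM function handles the design-matrix multiply internally with analytic gradients, and reduce_sum splits the log likelihood over slices of the data:

```stan
functions {
  // log likelihood of one slice of the data, for reduce_sum
  real partial_log_lik(array[] int y_slice, int start, int end,
                       matrix x, real alpha, vector beta) {
    return bernoulli_logit_glm_lpmf(y_slice | x[start:end], alpha, beta);
  }
}
data {
  int<lower=1> N;
  int<lower=1> K;
  matrix[N, K] x;
  array[N] int<lower=0, upper=1> y;
  int<lower=1> grainsize;  // e.g. 1 lets the scheduler pick slice sizes
}
parameters {
  real alpha;
  vector[K] beta;
}
model {
  alpha ~ normal(0, 2);
  beta ~ normal(0, 1);
  // parallelizes the likelihood over slices of y
  target += reduce_sum(partial_log_lik, y, grainsize, x, alpha, beta);
}
```

(To actually get a speedup from reduce_sum you also need to compile with threading enabled, e.g. cpp_options = list(stan_threads = TRUE) in cmdstanr, and set threads_per_chain when sampling.)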

And this is purposefully separate from the model-approximation things you might try (or the switch to lme4 that ends up happening a lot).

2 Likes

I hadn’t thought of it precisely like that but yes that’s correct and a good way to describe it succinctly/generalizably.

I’d already been planning to add a second model with largely the same structure but with a binomial outcome, to add the sufficient-statistics trick as an additional speed-up in that situation. Are there any other optimizations you were thinking of?

1 Like

Ooooh right, that’s the name for this stuff. I’d forgotten.

So multivariate normal + bernoulli/binomial

And the other thing I was talking about is there’s redundant calculation on intercepts a lot.

So for instance, if you have 4 age groups, 5 income groups, and 50 states and you index age with i, income with j, and states with k, then you might have a bunch of intercepts:

\mu_i + \mu_j + \mu_k

And then there are 4 * 5 * 50 = 1000 possible right hand sides there, but your survey could easily have more than 1000 people! So you could get savings in kinda the same way.

But that doesn’t always work, because in a lot of hierarchical models you have more group combinations than people. If we have N < 1000 rows of a dataframe and our model includes these terms:

\mu_i + \mu_j + \mu_k

That will take 3 * N loads and 2 * N adds. But! What if we break this into two pieces:

\mu_{ij} = \mu_i + \mu_j

\mu_{ij} + \mu_k

Now the first piece takes 2 * 20 loads and 20 adds, and the second takes 2 * N loads and N adds, so roughly 2N loads and N adds in total instead of 3N and 2N, which is less work.

You could apply this recursively to reduce the amount of copying you gotta do. These hierarchical models are memory bound so this is significant. You’re moving around 2/3 of the memory so you’d expect a comparable speedup!
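
In Stan the two-stage version of that age/income/state example might look something like this (made-up names, and a real model would put hierarchical priors on these intercepts):

```stan
data {
  int<lower=1> N;
  int<lower=1> I;                               // age groups, e.g. 4
  int<lower=1> J;                               // income groups, e.g. 5
  int<lower=1> K;                               // states, e.g. 50
  array[N] int<lower=1, upper=I> age;
  array[N] int<lower=1, upper=J> income;
  array[N] int<lower=1, upper=K> state;
  vector[N] y;
}
transformed data {
  // flatten (age, income) into one index so the combined intercept can be reused
  array[N] int<lower=1, upper=I * J> age_income;
  for (n in 1:N) {
    age_income[n] = (age[n] - 1) * J + income[n];
  }
}
parameters {
  vector[I] mu_age;
  vector[J] mu_income;
  vector[K] mu_state;
  real<lower=0> sigma;
}
model {
  vector[I * J] mu_ij;
  // build the I * J (here 20) combined intercepts once ...
  for (i in 1:I) {
    for (j in 1:J) {
      mu_ij[(i - 1) * J + j] = mu_age[i] + mu_income[j];
    }
  }
  // ... so each of the N rows needs only one more load-and-add
  y ~ normal(mu_ij[age_income] + mu_state[state], sigma);
  mu_age ~ normal(0, 1);
  mu_income ~ normal(0, 1);
  mu_state ~ normal(0, 1);
  sigma ~ normal(0, 1);
}
```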

And then also I think always feeding whatever you do into the appropriate glm and maybe using reduce_sum makes sense. I believe the glms work via some sufficient statistics stuff too.

Yes, I’m totally into this hypothetical document explaining how to fit big hierarchical regressions without having to run the computer overnight.

1 Like

@andrewgelman I’ve clearly failed to make time to write this up. Maybe a good project (parsing what I’ve done, adding prose) for one of your many aspiring students? I think it’s upper-year undergraduate or maybe early-year graduate level. And possibly extra credit for adding it to rstanarm/brms as something that’s done for the user automatically.

3 Likes

All volunteer work is successful as long as you enjoyed doing it! Finishing projects is for work, yuck.

4 Likes

AKA i.i.d.? (Two TLAs in a single A-only sentence.) I discuss this in the efficiency chapter of the User’s Guide, in the section “Exploiting sufficient statistics”, which is another way to think about it. The sum of Bernoulli outcomes is sufficient if they are i.i.d.

If it can be done in transformed data, it only happens once and there’s almost no relative overhead compared to model fitting.
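
For the Bernoulli example above, the whole collapse fits in transformed data, so the per-iteration work in the model block is G terms instead of N. A sketch, where group stands in for whatever defines “same right hand side”:

```stan
data {
  int<lower=1> N;                           // raw Bernoulli observations
  int<lower=1> G;                           // groups sharing a linear predictor
  array[N] int<lower=1, upper=G> group;
  array[N] int<lower=0, upper=1> y;
}
transformed data {
  // aggregate to sufficient statistics once, before sampling starts
  array[G] int trials = rep_array(0, G);
  array[G] int successes = rep_array(0, G);
  for (n in 1:N) {
    trials[group[n]] += 1;
    successes[group[n]] += y[n];
  }
}
parameters {
  vector[G] eta;   // stand-in for whatever hierarchical structure produces the predictor
}
model {
  successes ~ binomial_logit(trials, eta);  // G terms per iteration instead of N
  eta ~ normal(0, 1);
}
```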