Yo @paul.buerkner, I was doing some simple regressions with brms (more or less duplicating the regressions here: https://arxiv.org/pdf/1802.00842.pdf):
votes | trials(N) ~ male +
(1 + male | state) + (1 + male | race) + (1 + male | educ) + (1 + male | age) + (1 + male | marstat) +
(1 | race:educ) + (1 | race:age) + (1 | race:marstat) +
(1 | educ:age) + (1 | educ:marstat) +
(1 | age:marstat)
male is a two-valued covariate for sex; everything else is a hierarchical grouping term. So this is for an MRP thing where we think of people as landing in different bins, one bin for each combination of (male, race, educ, age group, marital status, state). In this regression, every row of the dataset corresponds to the voting outcomes in one of these bins.
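For reference, the fit is just that formula dropped into a brm() call, roughly like this sketch (the binomial family follows from trials(), but the data frame name cells and the default priors here are stand-ins for whatever I actually used):

library(brms)
# sketch of the call; `cells` stands in for the binned data frame
fit <- brm(
  votes | trials(N) ~ male +
    (1 + male | state) + (1 + male | race) + (1 + male | educ) +
    (1 + male | age) + (1 + male | marstat) +
    (1 | race:educ) + (1 | race:age) + (1 | race:marstat) +
    (1 | educ:age) + (1 | educ:marstat) +
    (1 | age:marstat),
  family = binomial(),
  data = cells
)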
Since most of the model is hierarchical, a lot of the time is spent in the brms model evaluating the hierarchical bit, which in this case looks like:
for (n in 1:N) {
  mu[n] += r_1_1[J_1[n]] * Z_1_1[n] + r_1_2[J_1[n]] * Z_1_2[n] +
           r_2_1[J_2[n]] * Z_2_1[n] +
           r_3_1[J_3[n]] * Z_3_1[n] + r_3_2[J_3[n]] * Z_3_2[n] +
           r_4_1[J_4[n]] * Z_4_1[n] +
           r_5_1[J_5[n]] * Z_5_1[n] +
           r_6_1[J_6[n]] * Z_6_1[n] + r_6_2[J_6[n]] * Z_6_2[n] +
           r_7_1[J_7[n]] * Z_7_1[n] + r_7_2[J_7[n]] * Z_7_2[n] +
           r_8_1[J_8[n]] * Z_8_1[n] +
           r_9_1[J_9[n]] * Z_9_1[n] +
           r_10_1[J_10[n]] * Z_10_1[n] +
           r_11_1[J_11[n]] * Z_11_1[n] + r_11_2[J_11[n]] * Z_11_2[n];
}
There are a lot of rows between which only one group changes, especially if we order the rows carefully. In those cases mu[n] would not need to be recomputed from scratch between rows, only updated.
There's a couple of ways to do this, but if we sort the rows carefully, the rule for deciding whether a row needs a full recomputation can be based on just one of the input groups. Then we can pass in another data vector called 'recompute' that says, for each row, whether the cached grouping terms need to be rebuilt entirely or whether we can just adjust the contribution of that last group.
In this case I used state, and the new code looks like:
real base = 0.0;
for (n in 1:N) {
  if (recompute[n] == 1) {
    base = r_1_1[J_1[n]] * Z_1_1[n] + r_1_2[J_1[n]] * Z_1_2[n] +
           r_2_1[J_2[n]] * Z_2_1[n] +
           r_3_1[J_3[n]] * Z_3_1[n] + r_3_2[J_3[n]] * Z_3_2[n] +
           r_4_1[J_4[n]] * Z_4_1[n] +
           r_5_1[J_5[n]] * Z_5_1[n] +
           r_6_1[J_6[n]] * Z_6_1[n] + r_6_2[J_6[n]] * Z_6_2[n] +
           r_7_1[J_7[n]] * Z_7_1[n] + r_7_2[J_7[n]] * Z_7_2[n] +
           r_8_1[J_8[n]] * Z_8_1[n] +
           r_9_1[J_9[n]] * Z_9_1[n] +
           r_10_1[J_10[n]] * Z_10_1[n];
    mu[n] += base +
             r_11_1[J_11[n]] * Z_11_1[n] + r_11_2[J_11[n]] * Z_11_2[n];
  } else {
    mu[n] += base +
             r_11_1[J_11[n]] * Z_11_1[n] + r_11_2[J_11[n]] * Z_11_2[n];
  }
}
And you can compute the recompute variable beforehand by just figuring out between which rows the only index that changes is J_11 (the state index).
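In R that preprocessing is just a sort plus a comparison against the previous row, something like this (a sketch, assuming the binned data frame is called d with columns named as in the formula; recompute then gets passed in as data alongside everything else, in the sorted row order):

# sort so that state is the fastest-changing group, then flag rows where
# any group other than state differs from the previous row
other <- c("male", "race", "educ", "age", "marstat")
d <- d[do.call(order, c(d[other], list(d$state))), ]
changed <- function(x) c(TRUE, x[-1] != x[-length(x)])  # TRUE where the value differs from the row above
d$recompute <- as.integer(Reduce(`|`, lapply(d[other], changed)))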
Doing that, the model went from taking about 200 seconds for 100 draws (including warmup) to about 80 seconds for the same. I checked the gradients and such in R and everything looked the same (I assume it's off in the last digits though).
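(For that kind of check, something along these lines works; a sketch assuming fit_base and fit_state are rstan fits of the original and modified models, not exactly what I ran:)

theta <- rnorm(rstan::get_num_upars(fit_base))  # some arbitrary unconstrained point
max(abs(rstan::log_prob(fit_base, theta) - rstan::log_prob(fit_state, theta)))
max(abs(rstan::grad_log_prob(fit_base, theta) - rstan::grad_log_prob(fit_state, theta)))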
It's easy enough to do a little calculation to see hypothetically how many full recomputes we'd need with different orderings of the input (a sketch of that calculation is below the table):
  adjustments recomputes last_variable
        <int>      <int> <chr>
1        3272       2935 age
2        3686       2521 educ
3        2024       4183 male
4        2642       3565 marstat
5        5796        411 state
6        2613       3594 race
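That's the same lag comparison as the recompute flag, done once per candidate "last" variable; a sketch (again assuming the data frame d and column names from above, with recomputes counting the rows that need a full rebuild and adjustments counting the rest):

groups <- c("male", "race", "educ", "age", "marstat", "state")
changed <- function(x) c(TRUE, x[-1] != x[-length(x)])
recomputes <- sapply(groups, function(last) {
  other <- setdiff(groups, last)
  s <- d[do.call(order, c(d[other], list(d[[last]]))), ]  # candidate variable changes fastest
  sum(Reduce(`|`, lapply(s[other], changed)))             # rows needing a full recompute
})
adjustments <- nrow(d) - recomputes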
So state required the fewest full recomputes and benefited the most from this speedup. I also hacked up a version keyed on education, and the inference took about 140 seconds for the same 100 draws. That code looked a bit different:
real base = 0.0;
for (n in 1:N) {
  if (recompute[n] == 1) {
    base = r_1_1[J_1[n]] * Z_1_1[n] + r_1_2[J_1[n]] * Z_1_2[n] +
           r_2_1[J_2[n]] * Z_2_1[n] +
           r_6_1[J_6[n]] * Z_6_1[n] + r_6_2[J_6[n]] * Z_6_2[n] +
           r_7_1[J_7[n]] * Z_7_1[n] + r_7_2[J_7[n]] * Z_7_2[n] +
           r_8_1[J_8[n]] * Z_8_1[n] +
           r_10_1[J_10[n]] * Z_10_1[n] +
           r_11_1[J_11[n]] * Z_11_1[n] + r_11_2[J_11[n]] * Z_11_2[n];
    mu[n] += base +
             r_3_1[J_3[n]] * Z_3_1[n] + r_3_2[J_3[n]] * Z_3_2[n] +
             r_4_1[J_4[n]] * Z_4_1[n] +
             r_5_1[J_5[n]] * Z_5_1[n] +
             r_9_1[J_9[n]] * Z_9_1[n];
  } else {
    mu[n] += base +
             r_3_1[J_3[n]] * Z_3_1[n] + r_3_2[J_3[n]] * Z_3_2[n] +
             r_4_1[J_4[n]] * Z_4_1[n] +
             r_5_1[J_5[n]] * Z_5_1[n] +
             r_9_1[J_9[n]] * Z_9_1[n];
  }
}
This transformation wasn't easy enough for me to automate for my own models, so I guess I'll just skip it. It looks like something that could conceivably be automated since you're doing code generation though, so I figured I'd write it up. I don't know how useful this sort of transformation would be in general. It's kinda ideal for my case since I just have tons and tons of groupings.
I'm already doing the thing where you group a bunch of bernoullis into binomials, so I guess this is just the next step haha.
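(That step is just collapsing individual 0/1 outcomes into successes and trials per cell; a rough sketch, where the individual-level data frame ind and its 0/1 column voted are stand-in names:)

library(dplyr)
cells <- ind %>%
  group_by(male, race, educ, age, marstat, state) %>%  # one row per bin
  summarise(votes = sum(voted), N = n(), .groups = "drop")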
Here are the models. Hope the data is right (my scripts were pretty sketchy):
base.stan (8.4 KB) base.data.R (682.7 KB)
state.stan (10.2 KB) state.data.R (700.9 KB)
educ.stan (10.2 KB) educ.data.R (700.9 KB)