Group orderings in regression

Yo @paul.buerkner, I was doing some simple regressions with brms (more or less duplicating the ones here: https://arxiv.org/pdf/1802.00842.pdf).

votes | trials(N) ~ male +
        (1 + male | state) + (1 + male | race) + (1 + male | educ) + (1 + male | age) + (1 + male | marstat) +
        (1 | race:educ) + (1 | race:age) + (1 | race:marstat) +
        (1 | educ:age) + (1 | educ:marstat) +
        (1 | age:marstat)

male is a two-valued covariate for sex; everything else is a hierarchical grouping term. This is for an MRP thing where we think of people as landing in different bins, one bin for every combination of (male, race, educ, age group, marriage status, state). In this regression, every row of the dataset holds the voting outcomes for one of these bins.

Since most of the model is hierarchical, a lot of the run time of the brms-generated Stan model is spent evaluating the hierarchical part, which in this case looks like:

for (n in 1:N) {
  mu[n] += r_1_1[J_1[n]] * Z_1_1[n] + r_1_2[J_1[n]] * Z_1_2[n] +
    r_2_1[J_2[n]] * Z_2_1[n] +
    r_3_1[J_3[n]] * Z_3_1[n] + r_3_2[J_3[n]] * Z_3_2[n] +
    r_4_1[J_4[n]] * Z_4_1[n] +
    r_5_1[J_5[n]] * Z_5_1[n] +
    r_6_1[J_6[n]] * Z_6_1[n] + r_6_2[J_6[n]] * Z_6_2[n] +
    r_7_1[J_7[n]] * Z_7_1[n] + r_7_2[J_7[n]] * Z_7_2[n] +
    r_8_1[J_8[n]] * Z_8_1[n] +
    r_9_1[J_9[n]] * Z_9_1[n] +
    r_10_1[J_10[n]] * Z_10_1[n] +
    r_11_1[J_11[n]] * Z_11_1[n] + r_11_2[J_11[n]] * Z_11_2[n];
}

There are a lot of consecutive rows between which only one group changes, especially if we sort the rows appropriately. For those rows, mu doesn’t need to be recomputed from scratch, only updated.

There are a couple of ways to do this. If we sort the rows carefully, though, the decision about whether to recompute can be based on just one of the input groups. When two consecutive rows differ only in that group, the sum of all the other group terms can be reused. So we can pass in one more data variable, called ‘recompute’, that says whether we need to recompute the grouping temporaries altogether or can reuse the previous sum and just add the terms for that one group.

In this case I used state (J_11 in the generated code), and the new code looks like:

real base = 0.0;  // sum of every group term except the state ones (J_11)
for (n in 1:N) {
  // recompute[n] can be 0 only when row n matches row n - 1 in every
  // group except J_11; recompute[1] must be 1 so that base is initialized
  if (recompute[n] == 1) {
    base = r_1_1[J_1[n]] * Z_1_1[n] + r_1_2[J_1[n]] * Z_1_2[n] +
      r_2_1[J_2[n]] * Z_2_1[n] +
      r_3_1[J_3[n]] * Z_3_1[n] + r_3_2[J_3[n]] * Z_3_2[n] +
      r_4_1[J_4[n]] * Z_4_1[n] +
      r_5_1[J_5[n]] * Z_5_1[n] +
      r_6_1[J_6[n]] * Z_6_1[n] + r_6_2[J_6[n]] * Z_6_2[n] +
      r_7_1[J_7[n]] * Z_7_1[n] + r_7_2[J_7[n]] * Z_7_2[n] +
      r_8_1[J_8[n]] * Z_8_1[n] +
      r_9_1[J_9[n]] * Z_9_1[n] +
      r_10_1[J_10[n]] * Z_10_1[n];
  }
  // the state terms can change on every row, so they are always evaluated
  mu[n] += base +
    r_11_1[J_11[n]] * Z_11_1[n] + r_11_2[J_11[n]] * Z_11_2[n];
}

You can compute the recompute variable beforehand by finding the rows that differ from the previous row only in J_11.
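Here is a minimal Python sketch of that preprocessing step, under a few assumptions. Each row is a tuple of its group levels. That tuple includes male, because the Z_*_2 slope columns are the male indicator and must also stay fixed for base to be reusable. `fast` is the position of the group allowed to vary (state here). The helper name `order_and_flag` is hypothetical. It sorts the rows so that rows sharing all the other groups are adjacent, then flags a recompute whenever any of those groups changes.

```python
def order_and_flag(rows, fast):
    """Sort rows so only column `fast` varies within runs; flag recomputes.

    rows: list of tuples of group levels (one tuple per data row).
    fast: position of the group allowed to change without a recompute.
    Returns (order, recompute): order[i] is the original index of the i-th
    sorted row, and recompute[i] is 1 if any group other than `fast`
    differs from the previous sorted row (always 1 for the first row).
    """
    def others(row):
        # every group except the fast one
        return tuple(v for j, v in enumerate(row) if j != fast)

    order = sorted(range(len(rows)),
                   key=lambda i: (others(rows[i]), rows[i][fast]))
    recompute = []
    prev = None
    for i in order:
        key = others(rows[i])
        recompute.append(1 if key != prev else 0)
        prev = key
    return order, recompute
```

Every data vector passed to Stan (votes, N, the J_* and Z_* arrays) would then have to be permuted by `order` so that it lines up with `recompute`.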

With that change, the model went from about 200 seconds for 100 draws (including warmup) to about 80 seconds. I checked the gradients in R and everything matched (I assume they differ in the last few digits, though, since the additions now happen in a different order).

It’s easy enough to calculate how many recomputes we’d need for each choice of the group that is allowed to vary between consecutive rows (last_variable):

   adjustments recomputes last_variable
        <int>      <int> <chr>        
1        3272       2935 age          
2        3686       2521 educ         
3        2024       4183 male         
4        2642       3565 marstat      
5        5796        411 state        
6        2613       3594 race   
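Under that kind of sorting, each distinct combination of the other variables starts exactly one run of rows. So the number of full recomputes is simply the number of unique rows once last_variable is dropped, and the remaining rows are adjustments. Below is a hedged Python sketch of that count. The function name and the dict-per-row layout are illustrative assumptions, not the script used for the table.

```python
def recompute_counts(rows, variables):
    """For each candidate last variable, count recomputes and adjustments.

    rows: list of dicts mapping variable name -> level (one dict per bin).
    After sorting so that only `last` varies within runs, a full recompute
    happens once per distinct combination of the other variables.
    """
    counts = {}
    for last in variables:
        rest = [v for v in variables if v != last]
        unique = {tuple(row[v] for v in rest) for row in rows}
        recomputes = len(unique)
        counts[last] = {"adjustments": len(rows) - recomputes,
                        "recomputes": recomputes}
    return counts
```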

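State requires by far the fewest full recomputes (411 out of 6207 rows), so it benefits the most from this trick. I also hacked up a version keyed on education, and inference took about 140 seconds for the same 100 draws. That code looks a bit different, because every term involving educ (educ, race:educ, educ:age and educ:marstat) has to be evaluated on every row: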

real base = 0.0;  // sum of every group term that does not involve educ
for (n in 1:N) {
  // here recompute[n] can be 0 only when row n matches row n - 1 in
  // everything except educ; recompute[1] must be 1
  if (recompute[n] == 1) {
    base = r_1_1[J_1[n]] * Z_1_1[n] + r_1_2[J_1[n]] * Z_1_2[n] +
      r_2_1[J_2[n]] * Z_2_1[n] +
      r_6_1[J_6[n]] * Z_6_1[n] + r_6_2[J_6[n]] * Z_6_2[n] +
      r_7_1[J_7[n]] * Z_7_1[n] + r_7_2[J_7[n]] * Z_7_2[n] +
      r_8_1[J_8[n]] * Z_8_1[n] +
      r_10_1[J_10[n]] * Z_10_1[n] +
      r_11_1[J_11[n]] * Z_11_1[n] + r_11_2[J_11[n]] * Z_11_2[n];
  }
  // terms 3, 4, 5 and 9 involve educ, so they are evaluated on every row
  mu[n] += base +
    r_3_1[J_3[n]] * Z_3_1[n] + r_3_2[J_3[n]] * Z_3_2[n] +
    r_4_1[J_4[n]] * Z_4_1[n] +
    r_5_1[J_5[n]] * Z_5_1[n] +
    r_9_1[J_9[n]] * Z_9_1[n];
}

This transformation wasn’t easy enough for me to automate for my own models, so I guess I’ll just skip it. Since you’re doing code generation, though, it looks like something that could conceivably be automated, so I figured I’d write it up. I don’t know how useful this sorta transformation would be in general; it’s kinda ideal for my case since I have tons and tons of groupings.

I’m already doing the thing where you group a bunch of Bernoullis into binomials, so I guess this is just the next step haha.
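For reference, the Bernoulli-to-binomial collapse can be sketched in Python as below. The helper is hypothetical, and records are assumed to be (bin, outcome) pairs. It works because every individual in a bin has the same linear predictor. The product of their Bernoulli likelihoods therefore equals the binomial likelihood up to a constant that doesn’t depend on the parameters.

```python
from collections import defaultdict


def to_binomial(records):
    """Collapse individual 0/1 outcomes into one binomial row per bin.

    records: iterable of (bin_key, outcome) with outcome in {0, 1}, where
    bin_key is e.g. a (male, race, educ, age, marstat, state) tuple.
    Returns a dict bin_key -> (votes, N), i.e. successes and trials.
    """
    totals = defaultdict(lambda: [0, 0])
    for key, y in records:
        totals[key][0] += y   # successes
        totals[key][1] += 1   # trials
    return {key: (s, n) for key, (s, n) in totals.items()}
```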

Here are the models. Hope the data is right – my scripts were pretty sketchy:

base.stan (8.4 KB) base.data.R (682.7 KB)
state.stan (10.2 KB) state.data.R (700.9 KB)
educ.stan (10.2 KB) educ.data.R (700.9 KB)

Thinking more about this, I guess all I’m doing here is breaking the calculation into two stages. The first computes what the logit-scale group mean would be for the smaller set of bins:

(male, race, educ, age group, marriage status)

And then using that to compute the means for:

(male, race, educ, age group, marriage status, state)

Because there are a lot of states, this saves a lot of time in this specific case.
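The two-stage idea can be sketched in Python as below. It assumes each full bin is given as a (small_bin, state) pair, with small_bin = (male, race, educ, age, marstat) and male stored first. `two_stage_mu`, `offset_fn` and the state-effect dicts are hypothetical stand-ins for the generated Stan code. The point is only that stage 1 runs once per distinct small bin, while stage 2 runs once per full bin.

```python
def two_stage_mu(rows, offset_fn, state_intercept, state_slope):
    """Compute the linear predictor for full bins in two stages.

    rows: list of (small_bin, state); small_bin[0] is the male indicator.
    offset_fn: maps a small_bin to the sum of all non-state terms
      (the male fixed effect plus the race/educ/age/marstat terms).
    state_intercept, state_slope: dicts state -> varying intercept / slope.
    """
    # stage 1: one evaluation per distinct (male, race, educ, age, marstat)
    distinct = {b for b, _ in rows}
    offsets = {b: offset_fn(b) for b in distinct}
    # stage 2: add the state terms for each full bin
    mu = []
    for b, s in rows:
        male = b[0]
        mu.append(offsets[b] + state_intercept[s] + state_slope[s] * male)
    return mu
```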

Maybe it’d be possible to specify what’s going on manually with something like:

bf(votes | trials(N) ~ offsets + (1 + male | state),
    offsets ~ male +
      (1 + male | race) + (1 + male | educ) + (1 + male | age) + (1 + male | marstat) +
      (1 | race:educ) + (1 | race:age) + (1 | race:marstat) +
      (1 | educ:age) + (1 | educ:marstat) +
      (1 | age:marstat), nl = TRUE)

The point is that offsets doesn’t need to be as long as votes; it only needs one entry per (male, race, educ, age group, marriage status) bin.

I think there’s a way to code the contributions of discrete-level random effects efficiently with a single function (that could have custom autodiff and such).

The arguments would be:

int N; // Number of data points
int M; // Number of different factors
int L; // Length of vector which contains all random effects
       //   concatenated together
int<lower = 1, upper = L> membership[N, M];
   // For each row (data point)
   //   mu[n] = sum_over_m(parameters[membership[n, m]])
   //   where m goes from 1 to M
int<lower = 1, upper = N> base_value[N, M];
   // If we're careful about how membership is sorted, we
   //   can reuse a lot of previous partial sums -- this
   //   is a 2d array letting us know which partial sums we
   //   need
int<lower = 1, upper = N> number_unique_rows[M];
   // number_unique_rows[m] is the number of unique rows
   //   in the matrix membership[, 1:m]
   //
   // If membership is sorted so that
   //   membership[1:number_unique_rows[m], 1:m] contains all
   //   the unique rows for the first m columns, then the partial
   //   sums sum_over_l(parameters[membership[1:number_unique_rows[m], l]])
   //   where l goes from 1 to m can be reused in column m + 1
vector[L] parameters; // All random effects concatenated together
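Here is a rough Python sketch of what that function would compute. It is not the proposed Stan signature. Instead of precomputed base_value and number_unique_rows arrays, it just compares each row with the previous one. For lexicographically sorted membership, this performs sum(number_unique_rows) additions in total. Indices are 0-based, and the name `sorted_prefix_sums` is hypothetical.

```python
def sorted_prefix_sums(membership, parameters):
    """Compute mu[n] = sum of parameters[j] for j in membership[n].

    Hypothetical sketch, 0-based. Assumes membership is non-empty, its rows
    are sorted lexicographically, and each row has M >= 1 entries.
    partial[m] holds the running sum over columns 0..m of the previous row;
    for each new row only the columns from the first one that differs from
    the previous row onward are redone. Returns (mu, adds), where adds
    counts the additions performed.
    """
    M = len(membership[0])
    partial = [0.0] * M
    prev = None
    mu, adds = [], 0
    for row in membership:
        # k = first column where this row differs from the previous row
        k = 0
        if prev is not None:
            while k < M and row[k] == prev[k]:
                k += 1
        # reuse partial[0..k-1]; recompute the rest
        for m in range(k, M):
            before = partial[m - 1] if m > 0 else 0.0
            partial[m] = before + parameters[row[m]]
            adds += 1
        mu.append(partial[M - 1])
        prev = row
    return mu, adds
```

Take membership [[0, 2], [0, 3], [1, 2], [1, 3]]. There, number_unique_rows would be [2, 4], and the function does 2 + 4 = 6 additions instead of the 8 a naive row-by-row sum needs.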

It’s like a sparse matrix-vector multiply where we order things so that we reuse a lot of calculations. It could be totally eclipsed by a fast sparse matrix-vector product (which easily handles continuous covariates as well), I dunno.
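For comparison, here is the sparse matrix view: mu = Z * parameters, where Z is an N by L matrix with one nonzero per factor in each row. That entry is 1 for intercept terms and the covariate value (e.g. male) for slope terms. Below is a minimal pure-Python sketch using the standard compressed sparse row (CSR) layout: values w, column indices v and row offsets u. As far as I know, Stan’s built-in csr_matrix_times_vector takes the same three arrays, but 1-based. The Python helper names here are hypothetical.

```python
def csr_times_vector(w, v, u, b):
    """Multiply a CSR sparse matrix by a dense vector b.

    w: nonzero values, v: their column indices, u: row start offsets into
    w/v with len(u) == number_of_rows + 1 (all 0-based).
    """
    return [sum(w[k] * b[v[k]] for k in range(u[i], u[i + 1]))
            for i in range(len(u) - 1)]


def membership_to_csr(membership, values=None):
    """Build CSR arrays from a membership table (one column per factor).

    values[n][m] is the coefficient for membership[n][m]; it defaults to
    1.0 (a pure intercept term). Slope terms would put the covariate here.
    """
    w, v, u = [], [], [0]
    for n, row in enumerate(membership):
        for m, col in enumerate(row):
            w.append(1.0 if values is None else values[n][m])
            v.append(col)
        u.append(len(w))
    return w, v, u
```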

Haven’t exactly coded this up. Just leaving it half-baked. @seantalts

We do it that way in rstanarm but we use R functions in the lme4 package to build the sparse matrix.

I vaguely remember, from a long time ago, some weird performance issue with sparse matrices? Like they’re implemented in rstanarm in a way that wouldn’t be okay in Stan itself because some safety checks are turned off? Something like that?

I’m a little worried that a significant part of the performance gain here doesn’t come from the fewer adds themselves. Fewer adds also means fewer indexing operations, and so less index checking, or something weird like that.

We circumvented the checking by defining our own C++ function (and using scalar autodiff rather than analytical gradients):