Comments on this thread got me thinking about speeding up likelihood evaluation for hierarchical models: https://github.com/stan-dev/stan/issues/2912
The calculations are memory bound right now (see 1, 2).
Excluding multiple node parallelism, if we want to speed up a memory bound calculation we can:
- Use the GPU for more memory bandwidth
- Change our calculations to use less memory
- Change our calculations to not be memory bound
I think there are ways to attack each of these points. I’ll list them here.
A basic hierarchical likelihood is something like:

y_n \sim \text{normal}(x_n \beta + \alpha_{i[n]} + \ldots, \sigma)
Where:
- y is a length N array of observations
- X is an N x K matrix of covariates
- \beta is a length K vector of non-hierarchical parameters
- \alpha is a length I vector of hierarchical parameters
- i is a length N array of indices into \alpha in the range 1 to I
The \ldots is there to indicate that there can be many hierarchical terms.
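As a reference point, here is a minimal sketch (hypothetical, but using the names above) of how this likelihood gets written in Stan today, with the hierarchical term accumulated in a for loop:

data {
  int N;
  int K;
  int I;
  matrix[N, K] X;
  int<lower = 1, upper = I> i[N];
  vector[N] y;
}
parameters {
  vector[K] beta;
  vector[I] alpha;
  real<lower = 0> sigma;
}
model {
  vector[N] mu = X * beta;
  for (n in 1:N) {
    mu[n] += alpha[i[n]];  // hierarchical term, indexed per observation
  }
  y ~ normal(mu, sigma);
}

The for loop over the random effects is exactly the part that is awkward to move to the GPU, which is what the interface below is trying to address.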
The hierarchical terms commonly get more complicated in two ways. First, the term above corresponds to the lme4 syntax:
(1 | group)
We might also have:
(z | group)
where z is some covariate. A model with such a term expands to:

y_n \sim \text{normal}(x_n \beta + z_n \alpha_{i[n]} + \ldots, \sigma)

z here is a length N vector that elementwise scales the hierarchical term.
Another way hierarchical models get more complicated are interactions in the groups, i.e.
(1 | group1:group2)
In this case, each \alpha might have multiple indices:

y_n \sim \text{normal}(x_n \beta + \alpha_{i[n], j[n]} + \ldots, \sigma)

This could be rewritten in the single-index form above, but this notation will be more efficient (no need to create multi-indexes).
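In loop form (continuing the sketch above, and treating alpha2 and J as hypothetical names), these two extensions just change how each observation's contribution is accumulated inside the for (n in 1:N) loop:

// (z | group): the covariate z elementwise scales the hierarchical term
mu[n] += z[n] * alpha[i[n]];

// (1 | group1:group2): a matrix of effects indexed by two group index arrays
// (alpha2 would be declared matrix[I, J], with j a second length N index array)
mu[n] += alpha2[i[n], j[n]];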
The distribution itself will probably change as well. The normal distribution is just being used as a placeholder.
A GPU interface for hierarchical models
Alright, so if we want to run something like:

y_n \sim \text{normal}(x_n \beta + \alpha_{i[n]} + \ldots, \sigma)

on GPUs, that's fine. The bulk of the work is the matrix-vector multiply, and that is something that can be implemented cleanly on the GPU.
We already have the glm interfaces that work for likelihoods like:

y \sim \text{normal}(X \beta + \hat{\alpha}, \sigma)

where \hat{\alpha} is the indexed version of \alpha above (that is, all N of the terms \alpha_{i[n]} gathered into a single length N vector).
That interface is:
real normal_id_glm_lpdf(real y | matrix x,
                        real alpha,
                        vector beta,
                        real sigma)
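To connect this to the hierarchical likelihood above, the random effects can be gathered into \hat{\alpha} with multiple indexing and passed as the intercept. A hedged sketch, assuming the overload that accepts a vector of per-observation intercepts and reusing the names from the first example:

vector[N] alpha_hat = alpha[i];   // alpha_hat[n] = alpha[i[n]]
target += normal_id_glm_lpdf(y | X, alpha_hat, beta, sigma);

That works for a single intercept-only hierarchical term, but every extra term still needs its own loop or temporary vector.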
I think we can expand this interface to support all the hierarchical models above. The problem currently is that we write the hierarchical terms with for loops, and we can't put those on the GPU for free. The new thing would be an interface that sums up all the random effects, so that if the computation above were broken in two, there would be something that computes the \mu terms here:

y_n \sim \text{normal}(x_n \beta + \mu_n, \sigma), \quad \mu_n = \alpha_{i[n]} + \ldots
The Stan signature would be:
vector hierarchy(int[,] idxs, args...)
which probably isn’t too exciting, but what we can do is use function overloading to peel out groups of arguments that define each hierarchical term.
I’m switching to Stan-pseudocode since we don’t have function overloading or a way to handle variadic arguments other than special cases, but presumably this bit would be written in C++. If we have the overloads:
vector hierarchy(int[,] idxs,
                 vector alpha, int m,
                 args...) {
  return alpha[idxs[m]] + hierarchy(idxs, args...);
}
This corresponds to peeling off a term: \alpha_{i[n]}. There are a couple new assumptions we need to make:
- M is the number of hierarchical terms in the model
- idxs is a length M array of length N arrays of indices
- m is the index into idxs selecting which index array alpha should be indexed by
vector hierarchy(int[,] idxs,
                 vector z, vector alpha, int m,
                 args...) {
  return z .* alpha[idxs[m]] + hierarchy(idxs, args...);
}
And this corresponds to peeling off a term: z_n \alpha_{i[n]}
vector hierarchy(int[,] idxs,
                 matrix alpha, int m1, int m2,
                 args...) {
  vector[N] mu;
  for(n in 1:N) {
    mu[n] = alpha[idxs[m1, n], idxs[m2, n]];
  }
  return mu + hierarchy(idxs, args...);
}
And this corresponds to peeling off a term: \alpha_{i[n], j[n]}. In this case, alpha is a matrix and has two indices, so there are two m terms.
vector hierarchy(int[,] idxs,
                 vector z, matrix alpha, int m1, int m2,
                 args...) {
  vector[N] mu;
  for(n in 1:N) {
    mu[n] = z[n] * alpha[idxs[m1, n], idxs[m2, n]];
  }
  return mu + hierarchy(idxs, args...);
}
And this peels off a term: z_n \alpha_{i[n], j[n]}.
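For the recursion to bottom out, there would also need to be a terminating overload with no remaining terms, something like this (still Stan-pseudocode, with N assumed to be in scope as in the other overloads):

vector hierarchy(int[,] idxs) {
  return rep_vector(0, N);  // no terms left to peel off
}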
The implementation should be different from what is written there, but the point is it’s possible to write down an interface for the hierarchical terms above so that we can work on an efficient implementation for an expanded hierarchical GLM function which could look like:
real normal_id_glm_lpdf(real y | matrix x,
                        vector beta,
                        int[,] idxs,
                        args...,
                        real sigma)
which would handle all the different hierarchical likelihood variations described above.
Just as an example, we might have the regression:
y ~ x0 + x1 + (1 + z | g1) + (1 | g1:g2)
This could translate to the Stan code:
target += normal_id_glm_lpdf(y | X, beta,
                             { gidx1, gidx2 },
                             alpha1, 1,
                             z, alpha2, 1,
                             alpha3, 1, 2,
                             sigma)
where X holds the covariates x0 and x1, gidx1 and gidx2 are length N arrays of the group ids for each outcome, alpha1 holds the g1 hierarchical intercepts, alpha2 holds the g1 hierarchical slopes, and alpha3 holds the hierarchical intercepts for the g1:g2 interaction.
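For concreteness, the declarations sitting behind that call might look something like this (I1 and I2, the number of levels of g1 and g2, are assumed names):

int<lower = 1, upper = I1> gidx1[N];   // g1 id for each outcome
int<lower = 1, upper = I2> gidx2[N];   // g2 id for each outcome
vector[N] z;                           // covariate that scales the g1 slopes
vector[I1] alpha1;                     // g1 hierarchical intercepts
vector[I1] alpha2;                     // g1 hierarchical slopes
matrix[I1, I2] alpha3;                 // g1:g2 interaction intercepts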
Presumably with this interface defined we could think about how to implement it on the GPU. GPUs have something like 10x the memory bandwidth of CPUs.
Using Less Memory
If there are N observations, K non-hierarchical parameters, M indices, and Z hierarchical slopes, then the forward evaluation requires roughly:
N * K + N * M + N * Z reads (this assumes N is much larger than the number of random effects and so is dominating the memory requirements)
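For a sense of scale, in the megamodel example later in this post N = 4000 and K = 50, so reading X alone is roughly 4000 * 50 * 8 bytes ≈ 1.6 MB per matrix-vector multiply, already well beyond typical L1/L2 cache sizes, and the reverse-mode pass has to read it again.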
We might not be able to get away with using floats for all our calculations, but I don't see why we couldn't get away with making all our covariates (X and all the zs) floats. This would make our matrix-vector products roughly twice as fast, because we are reading half the memory, and it would also lower the memory needed for the hierarchical slope terms.
We could also use shorter integer types for data-block integers whose limits are statically defined. Most hierarchical models we will work with in the short term probably have fewer than 65536 groups. In that case, all the group indices could be uint16s, which would cut the memory required for reading the indices in half (or by four where uint8s suffice).
This is safe, because we always do bounds checks when the data is read in. So if the data block says:
data {
  int N;
  int<lower = 1, upper = 1000> gidx1[N];
  int<lower = 1, upper = 50> gidx2[N];
}

We could use uint16s for gidx1 and uint8s for gidx2.
Unfortunately we wouldn’t be able to speed up:
data {
  int N;
  int G;
  int<lower = 1, upper = G> gidx1[N];
}
Change our calculations to not be memory bound
Alright now this is the hard one, but also maybe the coolest.
So just the basic linear regression is a good enough example of a memory bound computation:

y \sim \text{normal}(X \beta, \sigma)

We see that here.
This is because a matrix-vector multiply is memory bound. The way to avoid being memory bound with matrix-vector multiplies is to rearrange your calculations into matrix-matrix multiplies.
Matrix-matrix multiplies are compute bound, and so it is possible to do them much more efficiently (and make use of all the cores on a computer, etc.).
Now, first of all, we do have multiple matrix-vector multiplies to do: one per chain.
If \beta_1, \beta_2, \beta_3, \ldots correspond to the \beta parameters on the different chains, then we can do them all together as the matrix-matrix multiply:

X [\beta_1\ \beta_2\ \beta_3\ \cdots] = [X\beta_1\ X\beta_2\ X\beta_3\ \cdots]
This would work perfectly fine with our autodiff. Numerically, none of the terms from the different chains interfere with each other, so we could break it all apart afterwards easily.
The problem is this is not how our models get compiled to C++. Our models get compiled as single chains. I think we should look at introducing a megamodel concept that computes likelihoods for multiple chains at a time and brings them together.
This would change how the samplers interface with the model class (multiple chains would need to be coordinated), and the model interfaces would need to be expanded from things that look like:
real log_prob(vector params);
To things that look like:
real[] log_prob(vector[] params);
And the compiler would need to be able to recognize situations where it could promote matrix-vector multiplies into matrix-matrix multiplies. Also, the hierarchical terms I defined above would need sparse matrix-matrix versions (they are currently basically sparse matrix-vector products), but that is a detail for now.
Anyhow, that's it for the description, but as an example of the possible utility of this I wrote a model that evaluates the likelihood of C independent chains together, and compared the efficiency of running that model on one core against running a C = 1 version of it with C chains on C cores.
The comparison isn't quite accurate, because in the megamodel with the C likelihoods the u-turn condition has to wait for all the independent chains to u-turn, which is a bit weird, but whatever.
So here’s the model:
data {
  int N; // N data points
  int K; // K parameters
  int C; // C chains
  matrix[N, K] X;
  real y[N];
}
parameters {
  matrix[K, C] beta;
}
model {
  matrix[N, C] mu = X * beta;
  for(c in 1:C) {
    y ~ normal(mu[, c], 1.0);
  }
}
Here’s code to run it:
library(tidyverse)
library(posterior)
library(cmdstanr)
library(rstan)

N = 4000
K = 50
C = 6

beta = rnorm(K)
X = matrix(rnorm(N * K), nrow = N)
y = rnorm(N, X %*% beta, 1.0)

m = cmdstan_model("megamodel.stan")

system.time({ f1 <- m$sample(list(N = N,
                                  K = K,
                                  C = 1,
                                  X = X,
                                  y = y),
                             num_cores = C, num_chains = C,
                             save_extra_diagnostics = TRUE) })

system.time({ f3 <- m$sample(list(N = N,
                                  K = K,
                                  C = C,
                                  X = X,
                                  y = y),
                             num_cores = 1, num_chains = 1,
                             save_extra_diagnostics = TRUE) })
Running 6 chains on 6 cores takes 27s on my computer. Running the 6 chain, 1 core megamodel takes 46s.
Which is cool. We did the same amount of work with 1/6 the cores in less than twice the wall time, so roughly a 3x gain in per-core efficiency.
Now, I'm running some big calculations on the other two cores of my computer that I didn't want to stop (they've been running for about 10 hours now), so this comparison is probably a bit off, but whatever.
Since the u-turn condition in the megamodel is weird (the megamodel did about twice as many leapfrog steps as the regular runs), we can also compare the ESS/s for each model.
We can do this with the posterior package (arranging to show the lowest bulk ESS first):
# Multiple chains, single model
f1$summary() %>%
  arrange(ess_bulk)

# Single chain megamodel
f3$draws() %>%
  as_draws_df() %>%
  pivot_longer(cols = starts_with('beta'),
               names_to = c("n", "c"),
               names_pattern = "beta\\[([0-9]+),([0-9]+)\\]",
               names_ptypes = list(integer(), integer()),
               values_to = "beta") %>%
  mutate(.chain = c) %>%
  select(-c, -.draw) %>%
  pivot_wider(names_prefix = "beta",
              names_from = n,
              values_from = beta) %>%
  summarise_draws() %>%
  arrange(ess_bulk)
For the 6 core 6 chain run we get:
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ -2044. -2044. 4.98 4.95 -2053. -2037. 1.00 2510. 3517.
2 beta[24,1] -1.68 -1.67 0.0161 0.0163 -1.70 -1.65 1.00 8005. 4342.
3 beta[28,1] 0.157 0.157 0.0155 0.0151 0.132 0.183 1.00 8681. 4619.
For the 6 chain 1 core megamodel we get:
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ -12266. -12267. 12.8 13.0 -12287. -12245. 1.00 1789. 2386.
2 beta47 -0.622 -0.623 0.0159 0.0157 -0.649 -0.596 1.00 14946. 3768.
3 beta4 0.448 0.448 0.0158 0.0156 0.422 0.474 1.00 15074. 3927.
So if we ignore lp__ and the ess_tails (because we're doing wishful thinking), we actually get roughly 300 ESS/s (minimum bulk ESS divided by wall time) for the 6 core 6 chain model and roughly 320 ESS/s for the 1 core 6 chain megamodel.
So we did better than 6 cores on 1 core. This is for the simplest model possible, so take it with a grain of salt, but that’s pretty cool.
To check the treedepths for the 6 core 6 chain model:
f1$sampler_diagnostics() %>% summarise_draws()
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
3 treedepth__ 3 3 0 0 3 3 NA NA NA
4 n_leapfrog__ 7.00 7 0.103 0 7 7 1 6024. 6024.
For the 6 chain 1 core megamodel:
f3$sampler_diagnostics() %>% summarise_draws()
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
3 treedepth__ 4 4 0 0 4 4 NA NA NA
4 n_leapfrog__ 15 15 0 0 15 15 NA NA NA
And we can benchmark the gradient calls themselves with Rstan:
# Check the performance of the gradients
m2 = stan_model("megamodel.stan")

# Make N really big so each gradient takes a while and the timings work
N2 = 400000
X2 = matrix(rnorm(N2 * K), nrow = N2)
y2 = rnorm(N2, X2 %*% beta, 1.0)

f4 = sampling(m2, iter = 1, data = list(N = N2,
                                        K = K,
                                        C = 1,
                                        X = X2,
                                        y = y2))

f5 = sampling(m2, iter = 1, data = list(N = N2,
                                        K = K,
                                        C = C,
                                        X = X2,
                                        y = y2))

beta = rnorm(get_num_upars(f4))
system.time({ grad_log_prob(f4, beta) })
system.time({ grad_log_prob(f5, rep(beta, C)) })
The single chain model takes about 80ms to evaluate a gradient. The megamodel takes about 160ms.
This would all be contingent on how possible it is to do static analysis and generate code for efficiently evaluating multiple chains together.
Anyway the overall efficiency here would go up as we added more chains, though we’d still be memory bound unless we started running hundreds of chains. In that case we’d probably be losing efficiency trying to piece things back together and we presumably wouldn’t have our within-chain diagnostics, etc.
To sum up, I think we can think about these things in terms of:
- GPU gives us 10x memory bandwidth
- We could maybe require half the memory for the calculations we’re currently doing
- We could code-gen megamodels to make the memory bottleneck less of an issue
Long post. Hope that stuff is right. @Bob_Carpenter, @rok_cesnovar, @stevebronder, @tadej, @avehtari, @wds15, @yizhang
Edit: Fixed m1/m2 indexing error pointed out by @stevebronder here