Using the Apple M1 GPUs: question from a noob

Short summary of the problem

I’m trying to run a suite of spline regressions using BRMS on some very large datasets. It’s painfully slow: a few hours per model if I pre-average across items (totaling about 2.5 days once all models are run), around 10 hours per model if I do an actual mixed effects model.

I’ve got an Apple M1 (and also access to a cluster), so I’m trying to figure out whether I can speed things up – particularly whether I can make use of the GPUs. I have seen this post and this post, but they mostly talk about whether it would be possible, not how to do it.

Following the BRMS documentation (here), I tried:

fit.congruent <- brm(accuracy ~ s(age, by=congruent) + congruent + (1|id),
                     family="bernoulli",
                     prior = c(set_prior("normal(0, 1)", class="b")),
                     control = list(max_treedepth = 15, adapt_delta = .95),
                     data=temp, iter=1000, chains=4, cores=4, opencl = opencl(c(0, 0)))

The output was … well, there actually isn’t any output. If I try to inspect the output, here’s what I see:

> str(fit.congruent)
>

which is … odd?

From poking around, I suspect there’s more that I need to do to get this thing to work, but I’m having trouble tracking down documentation. I mostly see somewhat oblique discussions in this forum (oblique to me – no doubt they make sense to power stan users). And some of the discussion seems to indicate that probably using the GPU wouldn’t do any good anyway … maybe?

Is this something I want to pursue? If I want to pursue it, how do I pursue it? Is there a point-by-point tutorial somewhere? Could somebody please make one?

  • Operating System: Big Sur
  • brms Version: 2.15.0

Perhaps this issue: Using Stan's OpenCL support in brms · Issue #1166 · paul-buerkner/brms · GitHub could help. Bottom line, GPU support in brms is implemented in the current github version but not yet on CRAN.

Before looking into clusters and GPUs it is good to rule out that the problem actually lies with your model (models that have issues are usually very slow). Since you have large max_treedepth and adapt_delta I presume you get divergent transitions without those. It is quite likely that figuring out how to fit the model reliably with less aggressive settings will speed up your computation substantially.

I see two potential problems with your model:

  1. the s(age, by = congruent) already lets the response differ between the values of congruent, so the congruent term itself is likely redundant, causing non-identifiability. Does the model run fast with just one of those terms included? What does the pairs (or bayesplot::mcmc_pairs) plot look like when you include the b_congruentXX, bs_sage:congruentXX and some of the s_sagecongruentXX_1 terms?
  2. There might be too few observations per id or too few id values to inform the (1|id) term (it is especially problematic if you have just one observation per id in a bernoulli model)

Best of luck with your model!

Check out the code here for a similar model coded directly in Stan and using performance tricks that I don’t think are in brms

That should be “trick” singular; as the update I just posted there notes, I just realized that really the only thing doing the “work” there is the sufficient stats trick. And since that’s such a well-known and easy-to-automate trick, I wouldn’t be surprised if @paul.buerkner indeed already had it by default in brms.

By the way, you wouldn’t happen to be doing analysis of a measure of executive function from something like a stroop or flanker task, eh? (Guess cued by the name of your “congruent” variable) If so, I was literally doing the same thing this week, but with simultaneous inference on accuracy and a location-scale model of RT, and even with >1k participants (doing a mega-analysis) with a few hundred trials each, I was getting quality samples within an hour or so.

Even accounting for the fact that you’re doing a spline on age, you really shouldn’t be experiencing such slowness. I concur with @martinmodrak that it seems likely that something is dramatically off in terms of the mesh between the data and the model (inc priors therein). With such a simple model, I’m having trouble guessing what could even go wrong though…

Somehow I missed these responses. Thank you everyone who chimed in!

@martinmodrak: I read somewhere that because s() centers the variables, adding a main effect of ‘congruent’ would be advised. I can’t remember where I saw this, though. Does that not sound right?

@mike-lawrence I do have a lot of subjects. Depending on analysis, around 15,000. I don’t have a huge number of trials per subject, so that could well be an issue. Depending on the specific analysis, it would be 10-20 per level of ‘congruent’. @martinmodrak: How many should I want?

@martinmodrak I didn’t get much from looking at ‘pairs’. Maybe I don’t know what I’m looking at. I understand divergent transitions are supposed to show up in the graph, but for me I don’t think they do.

I unfortunately don’t have the model I described above saved, and it simply takes too long to run. Here’s a slightly simpler one that still takes forever:

  fit.CSE <- suppressMessages(brm(CSE ~ s(age),
                    family="gaussian",
                    prior = c(set_prior("normal(0, 1)", class="b")),
                    control = list(max_treedepth = 15, adapt_delta = .95),
                    data=accByTypeByAge, iter=1000, chains=6))

I have been adjusting priors to get them into reasonable ranges for the data. That hasn’t obviously sped anything up. In terms of specifying a prior from a more convenient family … that’s well beyond my current knowledge.

I think that in theory the contributions of the congruent main effect and s(age, by = congruent) can be distinguished, but in practice I think the splines can absorb some of that variability leading to weak identifiability. My experience with this is limited, so it is possible I am wrong. If removing the term does not result in a speedup, it is unlikely this is the core of the issue.

That should be a pretty decent number. A decent speedup could also likely be gained by combining rows that have the same predictors and using a binomial instead of a bernoulli likelihood. (If the same id always means the same age, you could get a 10-20 times reduction in the number of rows, while the binomial likelihood is only a little more expensive per row than the bernoulli.)

When using mcmc_pairs from bayesplot you need to also pass np = nuts_params(fit) to see divergent transitions.

The key questions to move forward IMHO are:

  • Do you get divergent transitions with default adapt_delta and max_treedepth?
  • If so, do you get them also when running with a small subset of the subjects (and default control parameters)? (so that the model takes short time to run and you can iterate on the solution more easily)

If you have a large model with a relatively small number of predictor terms (where a random intercept counts as one term), you are likely to get very good inferences and huge speedups from the inlabru package. It uses an approximation, but the approximation tends to hold pretty well if you have a lot of data.

Here’s a version of the original model that I didn’t try to run on a GPU. You’ll see I cut down on iterations and cores to get it to run quickly, but it stuck on warmup and I got bored of running it and quit:

> fit.congruent <- brm(accuracy ~ s(age, by=congruent) + congruent + (1|id), 
+                      data = flanker,
+                      family="bernoulli", 
+                      prior <- c(set_prior("normal(0, 1)", class="b"),
+                                 set_prior("normal(0, 1)", class="sds"),
+                                 set_prior("normal(1, 1)", class="Intercept")),
+                      control = list(max_treedepth = 15, adapt_delta = .95),
+                      iter=100, chain=1, cores=1)
Compiling Stan program...
'config' variable 'CPP' is deprecated
clang -mmacosx-version-min=10.13 -E
recompiling to avoid crashing R session
Start sampling

SAMPLING FOR MODEL '8c555557f993c4f8f41c97044783741e' NOW (CHAIN 1).
Chain 1: 
Chain 1: Gradient evaluation took 0.328451 seconds
Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 3284.51 seconds.
Chain 1: Adjust your expectations accordingly!
Chain 1: 
Chain 1: 
Chain 1: WARNING: There aren't enough warmup iterations to fit the
Chain 1:          three stages of adaptation as currently configured.
Chain 1:          Reducing each adaptation stage to 15%/75%/10% of
Chain 1:          the given number of warmup iterations:
Chain 1:            init_buffer = 7
Chain 1:            adapt_window = 38
Chain 1:            term_buffer = 5
Chain 1: 
Chain 1: Iteration:  1 / 100 [  1%]  (Warmup)
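
A rough way to read the timing line in the log above: Stan multiplies the cost of one gradient evaluation by the number of transitions and by an assumed 10 leapfrog steps per transition. The sketch below redoes that arithmetic and, as an assumption about the worst case, uses the usual bound of 2^max_treedepth - 1 leapfrog steps per transition. The helper names are made up for illustration and are not part of Stan or brms.

```r
# Hypothetical helpers for reading Stan's timing estimate.
grad_sec <- 0.328451  # seconds per gradient evaluation, copied from the log above

# Total seconds if every transition used `steps` leapfrog steps
sampling_seconds <- function(grad_sec, n_transitions, steps) {
  grad_sec * n_transitions * steps
}

# Assumed upper bound on leapfrog steps per transition for a given max_treedepth
max_leapfrog_steps <- function(max_treedepth) 2^max_treedepth - 1

# Same arithmetic as the "1000 transitions using 10 leapfrog steps" line
stan_estimate <- sampling_seconds(grad_sec, 1000, 10)

# Ceilings for the 100 iterations requested above, in hours
worst_default <- sampling_seconds(grad_sec, 100, max_leapfrog_steps(10)) / 3600
worst_deep    <- sampling_seconds(grad_sec, 100, max_leapfrog_steps(15)) / 3600

cat(sprintf("Stan's estimate: %.2f s\n", stan_estimate))
cat(sprintf("Ceiling for 100 iterations: %.1f h (treedepth 10), %.1f h (treedepth 15)\n",
            worst_default, worst_deep))
```

Real runs rarely hit the bound on every transition, so these are ceilings rather than forecasts. They do show why a large max_treedepth can make a slow-gradient model look stuck in warmup.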

I tried removing my “aggressive” parameter settings. After 20 minutes, here’s where I am:

> fit.congruent <- brm(accuracy ~ s(age, by=congruent) + congruent + (1|id), 
+                      data = flanker,
+                      family="bernoulli", 
+                      prior <- c(set_prior("normal(0, 1)", class="b"),
+                                 set_prior("normal(0, 1)", class="sds"),
+                                 set_prior("normal(1, 1)", class="Intercept")),
+                      iter=500, chain=2, cores=3)
'config' variable 'CPP' is deprecated
clang -mmacosx-version-min=10.13 -E
starting worker pid=41209 on localhost:11782 at 12:18:02.776
starting worker pid=41225 on localhost:11782 at 12:18:03.190

SAMPLING FOR MODEL '8c555557f993c4f8f41c97044783741e' NOW (CHAIN 1).
Chain 1: 
Chain 1: Gradient evaluation took 0.524757 seconds
Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 5247.57 seconds.
Chain 1: Adjust your expectations accordingly!
Chain 1: 
Chain 1: 

SAMPLING FOR MODEL '8c555557f993c4f8f41c97044783741e' NOW (CHAIN 2).
Chain 2: 
Chain 2: Gradient evaluation took 0.475801 seconds
Chain 2: 1000 transitions using 10 leapfrog steps per transition would take 4758.01 seconds.
Chain 2: Adjust your expectations accordingly!
Chain 2: 
Chain 2: 
Chain 1: Iteration:   1 / 500 [  0%]  (Warmup)
Chain 2: Iteration:   1 / 500 [  0%]  (Warmup)

Sorry – can you please explain how to do that? And yes, the same id always has the same age.

The definition of the binomial is that it is a sum of independent bernoulli trials. For each participant, all the congruent and all the incongruent trials should be independent (the bernoulli model you have already assumes that), so if you condense your data to have, for each id and congruent, the total number of trials and the number of successes, you’ll get an exactly equivalent but faster model as:

brm( n_successes | trials(n_trials) ~ s(age, by=congruent) + congruent + (1|id), 
  family = "binomial", ...)

Does that make sense?
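
A minimal base-R sketch of the condensing step, on simulated data. The data frame, its column names and the effect sizes are made up for illustration. To keep it quick to run, it checks the equivalence with plain glm() and no (1|id) term, which is a simplification of the brms model discussed here.

```r
# Simulated trial-level data; every name and number here is illustrative only.
set.seed(1)
n_id <- 40
trials <- data.frame(
  id        = factor(rep(seq_len(n_id), each = 30)),
  congruent = factor(rep(rep(c("yes", "no"), each = 15), times = n_id))
)
trials$age <- rep(sample(10:80, n_id, replace = TRUE), each = 30)  # one age per id
p <- plogis(1 + 0.01 * trials$age + 0.5 * (trials$congruent == "yes"))
trials$accuracy <- rbinom(nrow(trials), size = 1, prob = p)

# Condense to one row per id x congruent cell: successes and number of trials.
trials$one <- 1
counts <- aggregate(cbind(accuracy, one) ~ id + age + congruent,
                    data = trials, FUN = sum)
names(counts)[names(counts) == "accuracy"] <- "n_successes"
names(counts)[names(counts) == "one"]      <- "n_trials"

# Same model fitted to both layouts: one row per trial vs. one row per cell.
fit_bern <- glm(accuracy ~ age + congruent, family = binomial, data = trials)
fit_bin  <- glm(cbind(n_successes, n_trials - n_successes) ~ age + congruent,
                family = binomial, data = counts)

print(nrow(trials))  # rows before condensing
print(nrow(counts))  # rows after condensing
print(all.equal(coef(fit_bern), coef(fit_bin), tolerance = 1e-4))
```

The two fits should agree on the coefficients because the binomial likelihood differs from the product of bernoulli likelihoods only by a constant that does not depend on the parameters.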

Maybe I gave the advice in bad order: the first thing to try is to reduce the dataset to say 20-50 subjects and see how the fit behaves - this will let you move much faster and most likely any issues should manifest also in this setting.
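
One way to cut the data down to a few dozen subjects, sketched in base R on a made-up stand-in for the trial-level data (the name flanker_demo and its columns are hypothetical). It draws subjects at random instead of taking the first ones, on the assumption that row order might not be arbitrary.

```r
set.seed(24)

# Hypothetical stand-in for the trial-level data: 200 subjects, 10 trials each.
flanker_demo <- data.frame(
  id       = factor(rep(sprintf("s%03d", 1:200), each = 10)),
  accuracy = rbinom(2000, size = 1, prob = 0.9)
)

# Draw 30 subjects at random and keep all of their trials.
keep_ids <- sample(levels(flanker_demo$id), 30)
small <- flanker_demo[flanker_demo$id %in% keep_ids, ]

# Drop the unused factor levels so the model only sees the kept subjects.
small <- droplevels(small)

print(nlevels(small$id))
print(nrow(small))
```

Fitting the model to `small` with default control settings should show quickly whether divergences appear at all.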

Thanks! I got the idea but wasn’t sure about the syntax.

I don’t know whether any of these error messages are relevant. But here’s what I get when I run with 500 subjects, 4 chains, & 2000 iterations:

library(brms)
library(dplyr)  # provides %>%, group_by() and summarise() used below

set.seed(24)

start <- Sys.time()
fit.congruent.bern <- brm(accuracy ~ s(age, by=congruent) + (1|id),
                     data = flanker[flanker$id %in% levels(flanker$id)[1:500], ],
                     family="bernoulli",
                     prior = c(set_prior("normal(0, 1)", class="b"),
                               set_prior("normal(0, 1)", class="sds"),
                               set_prior("normal(1, 1)", class="Intercept")),
                     iter=2000, chains=4, cores=5)
print(Sys.time() - start)  # elapsed time

set.seed(24)
start <- Sys.time()

# One row per age x id x congruent cell: correct trials and non-missing trials
mydata <- flanker[flanker$id %in% levels(flanker$id)[1:500], ] %>%
  group_by(age, id, congruent) %>%
  summarise(n_successes = sum(accuracy, na.rm=TRUE),
            n_trials = sum(!is.na(accuracy)),  # count only trials that have a response
            .groups = "drop")

fit.congruent.bin <- brm(n_successes | trials(n_trials) ~ s(age, by=congruent) + (1|id),
                     data = mydata,
                     family="binomial",
                     prior = c(set_prior("normal(0, 1)", class="b"),
                               set_prior("normal(0, 1)", class="sd"),
                               set_prior("normal(0, 1)", class="sds"),
                               set_prior("normal(1, 1)", class="Intercept")),
                     iter=2000, chains=4, cores=5)
print(Sys.time() - start)  # elapsed time

The Bernoulli model took 31 minutes and had 7 divergent samples.

The binomial model took <2 minutes and also had 7 divergent samples. So that’s an argument for the binomial.

As far as divergences, here’s the figure for

mcmc_pairs(fit.congruent.bin, np = nuts_params(fit.congruent.bin))

TBD ← EDIT: realized I ran the wrong thing, and it’s taking a while to run mcmc_pairs()
EDIT2: mcmc_pairs() has been running for 20 minutes. Should that be happening?

What do you recommend for dealing with them? Also, I am still going to need to go back up to a much larger number of subjects. Additionally, this is just analyzing the accuracy data. I also have reaction time data. Is there a similar trick I can use there to speed things up?

FWIW I did try out log-normal models at some point. The problem I ran into there is that s(age) becomes a multiplicative factor rather than additive, which obviously didn’t work out very well. There’s probably (?) a way around that, but I’ve been taking log(RT) and running a plain vanilla Gaussian.

mcmc_pairs() ran for 35 minutes without terminating. So I quit out. Any suggestions on how to get those plots?

BTW @paul.buerkner it would be nice if the warning that says you should use pairs() provided a bit more detail or a link to a description of how to use pairs(). I spent a few hours googling the other day and eventually gave up.

The plots can take ages to create (and are almost unreadable) when you have more than 10-12 parameters. So you need to restrict the parameters (you can use parnames to get the names of all parameters). Usually you want to plot the global params + 1-2 examples from each group parameter (the spline coefficients and varying intercepts). The rstan::pairs is a bit faster (but uglier and less customizable).
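
A small sketch of that selection rule, in base R. The parameter names below are made up in the brms naming style (b_, bs_, sds_, sd_ for global terms; s_ and r_ for spline coefficients and varying intercepts), and pick_pairs_pars is a hypothetical helper, not a brms or bayesplot function. With a real fit you would start from the names the fit reports.

```r
# Hypothetical parameter names, imitating what a fit like the one above might report.
par_names <- c(
  "b_Intercept", "bs_sage:congruentyes", "sds_sagecongruentyes_1", "sd_id__Intercept",
  paste0("s_sagecongruentyes_1[", 1:8, "]"),
  paste0("r_id[", 1:500, ",Intercept]"),
  "lp__"
)

# Keep all global parameters plus a few examples from each indexed group.
pick_pairs_pars <- function(par_names, n_examples = 2) {
  is_group <- grepl("^(r_|s_)", par_names)             # varying intercepts, spline coefs
  global   <- par_names[!is_group & par_names != "lp__"]
  grouped  <- par_names[is_group]
  key      <- sub("\\[.*$", "", grouped)               # group = name without its index
  examples <- unlist(lapply(split(grouped, key), head, n = n_examples),
                     use.names = FALSE)
  c(global, examples)
}

chosen <- pick_pairs_pars(par_names)
print(chosen)

# With a fitted model, the call would then look something like:
# mcmc_pairs(fit.congruent.bin, pars = chosen, np = nuts_params(fit.congruent.bin))
```

That brings the 513 made-up names down to 8, which is in the range where the pairs plot is both fast to draw and readable.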

That gets harder - Mike (links above) did some work on “sufficient statistics” to get somewhat similar improvements, but it is AFAIK not supported by brms. In any case, inlabru and/or ADVI in Stan might be the way to go for such a large dataset - it is still useful to verify that you get similar results as with MCMC for smaller datasets.

We are currently working on updated guidance for the warnings in Stan in general (hopefully it goes public with the next release).

I’m a little afraid of going down a rabbit hole on alternative packages like inlabru, since I’d have to learn that and it may still not work. So here’s a higher-level question. What about just using mgcv? I believe I can’t set priors anymore, but I wasn’t doing anything fancy with the priors anyway. I think mgcv isn’t nearly as flexible as brms: for instance, I think I can’t have correlated random effects … but I wasn’t using correlated random effects!

Would I be missing enough to justify the computational cost of brms or the learning cost of inlabru?

The caveat being that I guess I still have to figure out how to get predictions from mgcv ignoring the random effects, which I haven’t figured out just yet.
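
For what it is worth, one route that may cover this is the exclude argument of mgcv's predict method, which zeroes out named smooth terms. A minimal sketch on simulated data follows. The data and variable names are made up, and the random intercept is coded as s(id, bs = "re"), which is an assumption about how the model would be set up.

```r
library(mgcv)

# Simulated data; all names and numbers are illustrative only.
set.seed(1)
n_id <- 60
d <- data.frame(id = factor(rep(seq_len(n_id), each = 10)))
d$age <- rep(sample(10:80, n_id, replace = TRUE), each = 10)
subj  <- rnorm(n_id, sd = 0.2)  # subject-level offsets
d$logRT <- 6 + 0.3 * sin(d$age / 15) + subj[as.integer(d$id)] +
  rnorm(nrow(d), sd = 0.1)

fit <- gam(logRT ~ s(age) + s(id, bs = "re"), data = d, method = "REML")

# newdata still needs an id column with a valid level, even though the
# excluded term contributes nothing to the prediction.
newd <- data.frame(age = 10:80, id = d$id[1])

pop  <- predict(fit, newdata = newd, exclude = "s(id)")  # random effect set to zero
full <- predict(fit, newdata = newd)                     # includes subject 1's offset

print(head(cbind(age = newd$age, population = pop, subject1 = full)))
```

The label passed to exclude has to match the term name as it appears in summary(fit), which is "s(id)" for this formula.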

How many unique ages do you have?

About 70. I get peoples’ ages in whole years, and I’ve got data from roughly 10 to 80 years old, depending on the dataset.

Ok, so not out of the question to use a GP then (which I only mention bc I know GPs better than splines when it comes to coding them directly in Stan).

I’ve been modelling accuracy and location-scale log-RT recently so gimme a sec and I’ll post what I’ve been doing with an example. It doesn’t currently have a possibly-non-linear predictor like age, but I’d been planning to add that too anyway, so I’ll see if I can get that in quick.

In the interim I’d say that given the volume of data (and thereby the low influence weakly-informed priors will play), yes you could use just straight mgcv. Try it with no random effects at all first so you get a sense of the compute time, but if it’s as fast as I’m guessing, try adding by-subject intercepts, then uncorrelated by-subject congruency effects, and finally the correlation. Ultimately I take to heart Barr’s prescription to “keep it maximal”, so the full model should be considered the definitive inference output, but the incremental approach I suggest above should set your expectations around compute durations.

So even with 1,000 subjects and using GPs, it takes a couple minutes:

tempdat <- flanker[flanker$id %in% levels(flanker$id)[1:1000], ]

fit.bam <- bam(logRT ~ s(age, by=congruent, bs='gp') + s(id, bs='re'), data = tempdat, verbosePQL=TRUE)

How long it takes seems to grow quickly with amount of data, and I haven’t been able to track down a way of getting real-time output (so that I can tell whether it’s time to give up or worth waiting). Any ideas on that?

Awesome! I really appreciate that.