Stan backend for NumPyro + performance comparison

Update: Another paper came out recently that does some pretty neat stuff with the Stan compiler so it can build NNs! (cc’ing @s.maskell, who sent it during the meeting today.) Table 3 has some benchmarks checking NumPyro against the Stan examples and, tl;dr, we really need to update the examples in posteriordb. I filed an issue about it here, but it mostly comes down to updating these models to use some of the newer, more performant code. Pretty much all the models where they best Stan do a bunch of real * vector multiplies that could pretty easily be made into matrix operations and use normal_id_glm(). Though I’d be curious whether they would still beat us on those, because I’m pretty sure NumPyro uses the new style of matrix we are currently building out in Stan Math.

Another + is I think fixing up some of those models to avoid so many loops would let them benchmark against a wider set of models. They seem to not be able to do multiple inner loops?

I think there is also a miscalculation in Table 3’s performance difference calculation, and it should be (stan_time/numpyro_time) - 1. For eight_schools_noncentered, Stan runs in 0.02 seconds and NumPyro runs in 0.07 seconds; they report a speedup of 0.29, but by that formula it’s actually -71%, i.e. a slowdown. Also, they have a few places with an accuracy mismatch, but sometimes they report the speedup and other times they don’t?

I’m not sure where they are getting the info for RQ2: Accuracy, where they check that the MCSE of the sample mean is within ±30% of the standard deviation?* For Stan Math we compare the mean down to 0.0001. I also think that when comparing two separate inference algorithms you want to actually check that things like ESS and tail ESS are coming out as you would expect. I would refer to Bob’s comment.

* Actually, looking at the cmdstan performance tests, we do check that MCSE is within ±30% of the gold tests and we check that the means differ by no more than 1e-4. Why did they only do the one? Also, at the meta level, who dives so deep into someone’s code to write benchmarks and doesn’t hit them up about it???

It’s not directly relevant to their paper, but I wish they had posted code, because it would actually be really neat to take the amount of time spent calculating gradients, subtract that from the overall runtime, and see how NumPyro’s NUTS implementation compares to Stan’s. They use a non-recursive version of NUTS and I’ve always been curious how that performs against a set of problems.

I also don’t really understand figure 10. Is this Stan with the compiler they wrote?

Overall, though, the paper is neat, and if they wanted to post a design doc to integrate the NumPyro stuff into Stan’s compiler I think that would be cool. I think it would be nice to break it up into two separate things: the NumPyro backend and then the NN-specific blocks.

Also when they are commenting about the example models

Since these are official and long-standing examples, we assume that they use the non-generative features on purpose. Comments in the source code further corroborate that the programmer knowingly used the features. While some features only occur in a minority of models, their prevalence is too high to ignore

Oooof. I’ll refer to @bob_carpenter’s comment

Those models are super old and are just made to show how you could do it. Do we need to put a disclaimer on those examples like, “Hi some of these examples are outdated but show what is possible in terms of modeling within the Stan language, for any questions please contact [SGB email to refer to someone]”? Has anyone on the Stan dev team been talking to these folks?


Wonder if gathering examples from users, optimizing them and adding them to posteriorDb would make for a good Google Summer of Code project?


Thank you for looking and commenting on our paper!

The original motivation of this work is not about speedups. We were interested in the relationship between Stan and generative probabilistic programming languages and also see how to import features from some generative PPLs to Stan by building a compiler. We decided to take Pyro as a backend to see how to import variational inference with explicit guides and NN.

The performance results are recent and came as a surprise. We have been collaborating with @s.maskell since the last StanCon, and when he saw the results he suggested that we update our arXiv paper to start a conversation. We are happy to have feedback on our results.

pretty much all the models they best Stan in do a bunch of real * vector multiplies that could pretty easily be made into matrices and use normal_id_glm().

We also noticed that you have some static checks which are limiting the use of tensor operations (stanc3/Stan_math_signatures.ml at master · stan-dev/stanc3 · GitHub).

I think there is also a miscalculation in Table 3’s performance difference calculation, and it should be (stan_time/numpyro_time) - 1. For eight_schools_noncentered, Stan runs in 0.02 seconds and NumPyro runs in 0.07 seconds; they report a speedup of 0.29, but by that formula it’s actually -71%, i.e. a slowdown. Also, they have a few places with an accuracy mismatch, but sometimes they report the speedup and other times they don’t?

You are right that Table 3 needs some cleanup. The speedup only compares Stan with the result of the “comprehensive” translation to NumPyro when the results match. In particular, we are missing the speedup for hmm_example. Presenting the speedup as stan_time/numpyro_time is, I think, pretty standard. A value less than 1 simply indicates that Stan is faster.
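To make the two conventions concrete, here is a small sketch using the eight_schools_noncentered timings quoted above (0.02 s for Stan, 0.07 s for NumPyro). The function names are illustrative only and are not taken from the paper's benchmark code.

```python
def speedup_ratio(stan_time, numpyro_time):
    """Ratio convention: a value below 1 means Stan is faster."""
    return stan_time / numpyro_time


def relative_change(stan_time, numpyro_time):
    """Relative convention: (stan_time / numpyro_time) - 1.

    A negative value means Stan is faster.
    """
    return stan_time / numpyro_time - 1.0


# Timings quoted in the thread for eight_schools_noncentered.
ratio = speedup_ratio(0.02, 0.07)
change = relative_change(0.02, 0.07)
print(f"ratio = {ratio:.2f}, relative change = {change:.0%}")
# prints: ratio = 0.29, relative change = -71%
```

Both numbers describe the same measurement; the disagreement is only about which of the two the table column is meant to show.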

Another + is I think fixing up some of those models to avoid so many loops would let them benchmark against a wider set of models. They seem to not be able to do multiple inner loops?

Yes, we can only support a subset of nested loops with the NumPyro backend. NumPyro has one level of loops. So for inner loops, we are using JAX loops, but then it does not work if one of these loops contains an observe.

  • Actually, looking at the cmdstan performance tests, we do check that MCSE is within ±30% of the gold tests and we check that the means differ by no more than 1e-4. Why did they only do the one?

The comparison that we are doing is comp = sm[(sm["err"] > 0.0001) & (sm["err"] / sg["std"] > 0.3)] and I think that you are doing if err > 0.0001 and (err / stdev) > 0.3:.
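As a rough sketch of that criterion written out per parameter (plain Python with hypothetical parameter names and numbers; this is neither the paper's benchmark code nor the cmdstan test code): a parameter is flagged only when its absolute error exceeds 1e-4 and also exceeds 30% of the reference standard deviation.

```python
def is_mismatch(err, stdev, abs_tol=0.0001, rel_tol=0.3):
    """Flag a parameter only if both the absolute and the relative check fail."""
    return err > abs_tol and (err / stdev) > rel_tol


def mismatched_parameters(errors, stdevs):
    """Return the names of parameters whose posterior mean fails the check.

    errors: dict mapping name -> |mean - reference mean|
    stdevs: dict mapping name -> reference posterior standard deviation
    """
    return [name for name, err in errors.items() if is_mismatch(err, stdevs[name])]


# Hypothetical values, for illustration only.
errors = {"mu": 0.00005, "tau": 0.2, "theta": 1.5}
stdevs = {"mu": 3.0, "tau": 3.0, "theta": 3.0}
flagged = mismatched_parameters(errors, stdevs)
print(flagged)  # prints: ['theta']
```

Here mu passes on the absolute tolerance, tau passes because its error is small relative to the standard deviation, and only theta fails both checks.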

It’s not directly relevant to their paper, but I wish they had posted code, because it would actually be really neat to take the amount of time spent calculating gradients, subtract that from the overall runtime, and see how NumPyro’s NUTS implementation compares to Stan’s. They use a non-recursive version of NUTS and I’ve always been curious how that performs against a set of problems.

The code of the compiler is already open source (GitHub - deepppl/stanc3: Rewriting the Stan compiler in OCaml). We are going to open-source the benchmark code too.

I also don’t really understand figure 10. Is this Stan with the compiler they wrote?

In this figure, we are showing some Stan code extended with two new blocks (guide parameters and guide) for variational inference (VI) with explicit guide. The graphs compare the results of Stan with NUTS, Pyro with NUTS, and Stan with ADVI (without guide) vs the extended Stan with explicit guide.

Overall, though, the paper is neat, and if they wanted to post a design doc to integrate the NumPyro stuff into Stan’s compiler I think that would be cool. I think it would be nice to break it up into two separate things: the NumPyro backend and then the NN-specific blocks.

We are definitely happy to contribute this code back to the community! I will contact you to understand the process better. If you already want to have a look, we have created a branch pyro in our fork with just the Pyro/NumPyro backends (GitHub - deepppl/stanc3 at pyro).


It feels like a bit of an aside, but (for the specific benefit of @louis-mandel and @stevebronder as well as everyone else generically) our interest is focused on the fact that because Louis’ team has facilitated a backend in NumPyro, running multiple chains on each of many GPUs becomes a lot easier than it would be otherwise (and doesn’t require any changes to the Stan files). That should, we hope (and believe), make it possible to implement large-scale (aka using multiple GPUs) SMC samplers on GPUs. Our hope (and anticipation) is that we can then replace the MCMC in Stan with an SMC sampler and thereby offer substantial speed-ups relative to Stan. We’re excited about this prospect and see the work that Louis’ team has done as an important stepping stone towards an exciting future!


Just to clarify, so that someone reading this doesn’t get the wrong idea.

Running “multiple chains” on many GPUs is doable in Stan as is. Or rather, you can compute the entire gradient evaluation on the GPU, or parts of it if some function/distribution is not supported yet (the list of supported functions/distributions is here). Even in the best case, where the entire AD evaluation is done on the GPU, we still need to copy the parameters back to run an iteration of the sampler and then copy them to the GPU again. To avoid that, we would have to rewrite the sampling algorithm to run on the GPU. We might try that out at some point, but it’s not in our immediate plans (at least not that I am aware of).

Even with the current state of GPU support in Stan, we have successfully run 100+ chains in the time of a single CPU chain. Most of this should be available in the next Stan release (but not 100% atm).

The other downside for multi chain approaches in Stan right now is that we create multiple copies of data, each chain gets its own instance. This is fine for 100MBs or GBs of data but is less great when you have 10GBs or TBs of data.

Multiple chains sharing the same instance of data would be good to have in any case. Cache and memory hierarchy friendliness is arguably an even more important reason for why we need this.


@rok_cesnovar: Fair point. Thanks for clarifying. As I hope you realised, I didn’t mean to imply that your work doesn’t exist, just that I think the work by @louis-mandel is (also) interesting.


Yes of course, definitely did not take it that way, all good.

The clarification was mostly intended for anyone skimming this at some later point in time.


I took the liberty to move this to a new topic as it is IMHO quite a separate discussion from the original post.

@louis-mandel: Really glad you joined the forums to give us clarifications! Thanks! I just want to reiterate the concern that it is unclear whether the wall-clock time comparison is sensible. The paper doesn’t state what exactly was compared: was it “both backends at their default settings” or “both backends to compute N iterations” or something else? As mentioned above, we currently believe a fair comparison should involve the effective sample size (ESS), e.g. by computing ESS/second. If the ESS/iteration ratio is not the same between the two backends, then one could be slower to compute the same number of iterations but faster to get a result of the same accuracy. I understand this is less of a concern if both compared algorithms are NUTS variants and thus the ESS/iteration is likely quite similar, but it would still be great to see that reported!
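To illustrate why ESS per second can rank two backends differently from raw wall-clock time, here is a sketch with made-up numbers. The backend names, ESS values, and timings are hypothetical and are not measurements from the paper or from Stan.

```python
def ess_per_second(ess, wall_time_seconds):
    """Effective samples produced per second of wall-clock time."""
    return ess / wall_time_seconds


# Hypothetical runs with the same number of iterations on each backend.
runs = {
    "backend_a": {"ess": 400.0, "seconds": 2.0},
    "backend_b": {"ess": 900.0, "seconds": 3.0},
}

rates = {name: ess_per_second(r["ess"], r["seconds"]) for name, r in runs.items()}
fastest_wall_clock = min(runs, key=lambda name: runs[name]["seconds"])
best_efficiency = max(rates, key=rates.get)
print(fastest_wall_clock, best_efficiency, rates)
```

In this made-up case backend_a finishes first, but backend_b delivers more effective samples per second, which is the situation described above.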

Thanks a lot for the work you have put into this.


Thanks a lot for your feedback!
Looking at the ESS is a very good suggestion.

We measure the inference time with the two backends using, for each benchmark, the configuration stored in posteriordb (num_chains, num_iterations, warmups, etc.).
We then check the results with the simple criterion on the mean and std of each parameter.


Welcome to the forums! I checked out some of your other really cool stuff, like Yaps, which lets you write a Stan model directly in Python code that transpiles to Stan, and the synchronous probabilistic language ProbZelus.

I’ve been following the Pyro project for some time and if I can stay in Stan syntax but get all the extra functionality out of NumPyro that is really exciting! Would it support the NumPyro effect handlers as well?


Thank you for also looking up our other projects! :-)

I have personally never used effect handlers, but it should work directly. My understanding is that they were designed to be orthogonal to the definition of the model.

Our compiler produces a simple Python file containing the Pyro code. The only issue might come from the names of the sampling sites. The compiler tries to keep sensible names (in particular, the initial sampling of a parameter has the name of the parameter), but the compiler also has to generate new names.