Using JAX as an AD backend

I’m looking through the JAX (I’m going to go lowercase jax for the rest because I’m on mobile) docs on the foreign function interface (FFI), and I see this section on CUDA in the “Foreign function interface (FFI)” page of the JAX documentation.

It appears that once a function is defined in C++, it’s just a matter of template specifications to get it autodiffed on either CPU or GPU.

I would love it if we had enough devs and resources to expand Stan ourselves, but realistically, piggybacking on something like jax for GPU support is the more practical option.

If I squint hard enough, it seems like we could have our cake and eat it too. Meaning, we could keep all the custom derivatives we have and use jax for parallelized GPU autodiff. If that worked, we would have a number of options for our code base. We could have a version of Stan that uses backend=“jax”. This version would remove all the current GPU code, have all the jax templates (maybe there’s some magic we can do to auto-generate the templates for everything), and allow using native jax numpy functions.
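A minimal sketch of the “keep our custom derivatives” part, under some assumptions: jax.custom_vjp lets you attach a hand-written reverse-mode rule to a function that JAX cannot trace into. Here jax.pure_callback stands in for an FFI call so the example runs without compiling any C++; a real binding would go through `ffi_call` in JAX’s FFI module and keep the work on the device, whereas pure_callback runs on the host. `external_log1p_exp` and `external_log1p_exp_grad` are hypothetical NumPy placeholders for a Stan Math function and its analytic derivative.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical stand-ins for a Stan Math function and its hand-written
# derivative. In a real setup these would be C++ kernels exposed via the FFI.
def external_log1p_exp(x):
    return np.logaddexp(0.0, x).astype(x.dtype)


def external_log1p_exp_grad(x):
    # d/dx log(1 + exp(x)) = inv_logit(x)
    return (1.0 / (1.0 + np.exp(-x))).astype(x.dtype)


def _call_external(fn, x):
    # pure_callback runs fn on the host; JAX only needs the output shape/dtype.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(fn, out_type, x)


@jax.custom_vjp
def log1p_exp(x):
    return _call_external(external_log1p_exp, x)


def log1p_exp_fwd(x):
    # Forward pass: the value plus the residual the backward pass needs.
    return log1p_exp(x), x


def log1p_exp_bwd(x, g):
    # Backward pass: incoming cotangent times the hand-written derivative.
    return (g * _call_external(external_log1p_exp_grad, x),)


log1p_exp.defvjp(log1p_exp_fwd, log1p_exp_bwd)
```

With this in place, `jax.grad(lambda x: log1p_exp(x).sum())` uses the hand-written derivative instead of trying to differentiate the external code, which is the same division of labour as Stan Math’s custom reverse-mode specializations.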

It seems like this would be a spinoff akin to melding BridgeStan, Stan Math, stanc3 (to allow jnp functions in Stan code), and the jax FFI.

@WardBrian @stevebronder @Bob_Carpenter @Stan_Development_Team

I’m not sure I follow what you’re hoping for here. Would Stan still be running its algorithms but calling JAX? Or are you basically hoping we can transpile a Stan model to something JAX can use?

I’m very interested in helping to implement some experimental trials. I found some discussions that might be relevant:
How do we do move parts of the autodiff stack to the GPU? · Issue #1639 · stan-dev/math
Parallel autodiff v2 - Developers - The Stan Forums
Parallel autodiff v3 - Developers - The Stan Forums

I think it would make sense to contact some jax developers. Do we know any?

I’m jumping on the idea of Stan being modular, or just one piece of a larger ecosystem. I love the Stan language and there’s tons of great code in Stan Math, so I’m thinking about how we can reuse it as much as possible while latching onto larger projects like JAX to leverage what they’re good at.

To be more specific, my thoughts were agnostic toward the inference algorithm. It could be the one in Stan or something else, just as BridgeStan allows. What I’m interested in is reusing our math library through JAX’s external C++ (FFI) framework to plug into JAX, which gives us GPU and parallelized AD while also giving us access to everything else in JAX. The inference could then either be offloaded to JAX or stay in Stan.
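For the “parallelized AD” part, the usual JAX pattern is to write the gradient for a single parameter vector and let jax.vmap batch it, for example across many chains; under jax.jit the whole batch is compiled as one XLA computation and runs on a GPU if JAX finds one. This is only a sketch: the model (independent standard normals) is a hypothetical stand-in chosen so the result is easy to check.

```python
import jax
import jax.numpy as jnp


def log_density(theta):
    # Hypothetical model: independent standard normals over the parameters,
    # standing in for a real model's log density.
    return -0.5 * jnp.sum(theta ** 2)


# jax.grad differentiates one parameter vector; jax.vmap batches that over a
# leading axis (e.g. one row per chain); jax.jit compiles the whole batch.
batched_grad = jax.jit(jax.vmap(jax.grad(log_density)))

key = jax.random.PRNGKey(0)
thetas = jax.random.normal(key, (1024, 10))  # 1024 chains x 10 parameters
grads = batched_grad(thetas)                 # shape (1024, 10)
```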

That’s a good idea, but I don’t know any.

Like Brian, I’m unclear on what you’re suggesting. Are you suggesting that we have JAX call our GPU code? That’s what their FFI appears to do. I would think we’d want to do the opposite, which is to transpile Stan into JAX so we can use their autodiff on GPU.
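To make the transpilation option concrete, here is a hand-written JAX version of a small, hypothetical Stan program (shown in the comments). It is a sketch of what generated code might look like, not the output of any existing tool: the `<lower=0>` constraint becomes an exp transform plus its log-Jacobian, and full log densities are used where Stan’s `~` would drop constants, which shifts the log density but not its gradient.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical Stan program being "transpiled":
#   data { int N; vector[N] y; }
#   parameters { real mu; real<lower=0> sigma; }
#   model { mu ~ normal(0, 10); sigma ~ normal(0, 5); y ~ normal(mu, sigma); }


def log_density(theta, y):
    """Log density on the unconstrained scale, including the Jacobian term."""
    mu, log_sigma = theta[0], theta[1]
    sigma = jnp.exp(log_sigma)                  # <lower=0> transform
    lp = norm.logpdf(mu, 0.0, 10.0)
    lp += norm.logpdf(sigma, 0.0, 5.0)          # half-normal, up to a constant
    lp += jnp.sum(norm.logpdf(y, mu, sigma))
    lp += log_sigma                             # log |d sigma / d log_sigma|
    return lp


# Value and gradient w.r.t. theta via JAX's autodiff, compiled with XLA
# (runs on a GPU if JAX finds one).
lp_and_grad = jax.jit(jax.value_and_grad(log_density))
```

For example, `lp_and_grad(jnp.array([0.5, 0.0]), jnp.array([1.0, 2.0, 3.0]))` returns the log density and a length-2 gradient, which is what a sampler would call at each leapfrog step.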

The JAX devs are an email or issue post away. We’ve been in contact about expanding Inference Gym. One of them, Dan Foreman-Mackey, was at Flatiron until recently (he’s also the person behind emcee, the Python package everyone in astro uses). And we know lots of adjacent folks on the TensorFlow Probability team.

Congrats @Bob_Carpenter on your wonderful post!

JAX à la Stan

We have been running very expensive models in Stan for some time now, and we would love to use JAX acceleration within the Stan ecosystem.

On a personal note, I increasingly think of probabilistic modelling as a specialised, expensive operation that, in practice (for applications that matter for research), is closer to neural networks than to an lm() in its computing demands. In my field (comp bio), if we propose expensive, reusable models that can do much more than frequentist ones, the only way for them to become mainstream is to move to GPU.

tagging @Yingnan_Gao

densejax comes a bit out of the blue in the article, but maybe we’ll hear more about it soon.

+1 on putting this together. A companion topic I’m interested in is JAX’s autodiff capabilities within a model. Before I go trying to derive the “co-area Jacobian” or other functions with complicated derivatives (specifically adjoint derivatives), is JAX’s autodiff system reasonably performant and easy to use? Similarly, does JAX make it faster and easier to access derivative information within a model, including second derivatives or the full Hessian? It seems like it should.
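On the second-derivative question, JAX composes its transforms: jax.hessian gives the dense Hessian (forward-over-reverse), and a Hessian-vector product is a forward-mode jvp of the gradient, which avoids forming the matrix at all. A minimal sketch on a toy Gaussian log density; `Q` is a made-up precision matrix chosen so the exact Hessian, `-Q`, is known. For hand-derived adjoints, jax.custom_vjp is the hook for supplying your own reverse-mode rule.

```python
import jax
import jax.numpy as jnp

# Toy log density: zero-mean Gaussian with a fixed (hypothetical) precision
# matrix Q, so the exact gradient is -Q x and the exact Hessian is -Q.
Q = jnp.array([[2.0, 0.5],
               [0.5, 1.0]])


def log_density(x):
    return -0.5 * x @ Q @ x


grad_fn = jax.grad(log_density)     # first derivatives (reverse mode)
hess_fn = jax.hessian(log_density)  # dense Hessian (forward-over-reverse)


def hvp(x, v):
    # Hessian-vector product without building the Hessian:
    # push the tangent v forward through the gradient function.
    return jax.jvp(grad_fn, (x,), (v,))[1]
```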

Yes, looking forward to trying it out. We have some beefy GPUs that I’ve always wanted to use more fully.