Using JAX as an AD backend

I’m looking through the JAX docs on the foreign function interface (FFI), in particular the section on CUDA GPU support: Foreign function interface (FFI) — JAX documentation.

It appears that once the function is defined in C++, it’s mostly a matter of template specializations and registration boilerplate to get the function callable, and differentiable, on either CPU or GPU.
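
For concreteness, here’s a rough sketch of what the Python side might look like, following the FFI tutorial in the JAX docs. The `jax.ffi.register_ffi_target` and `jax.ffi.ffi_call` functions are the real JAX API; the C++ extension module (`stan_math_ffi`), its capsule function, and the target name are hypothetical placeholders for whatever we would actually build around Stan-math.

```python
import jax
import jax.numpy as jnp

import stan_math_ffi  # hypothetical pybind11/nanobind module exposing Stan-math kernels

# Register the compiled handler with XLA; "CUDA" would be the GPU build,
# "cpu" the plain C++ one.
jax.ffi.register_ffi_target(
    "stan_lgamma", stan_math_ffi.lgamma_capsule(), platform="CUDA"
)

def lgamma(x):
    # ffi_call needs the output shape/dtype up front; this op is elementwise,
    # so the output matches the input.
    return jax.ffi.ffi_call(
        "stan_lgamma",
        jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x)
```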

I would love it if we had enough devs and resources to expand Stan’s GPU support ourselves, but piggybacking on something like JAX is more realistic.

If I squint hard enough, it seems like we could have our cake and eat it too. Meaning, we keep all the custom derivatives we have and use JAX for parallelized GPU autodiff. If that worked, we would have a number of options for our code base. We could have a version of Stan that uses backend="jax". This version would remove all the current GPU code, provide all the JAX templates (maybe there’s some magic we can do to auto-template everything), and allow using native JAX numpy (jnp) functions.
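
To illustrate the “keep our custom derivatives” part: JAX’s `jax.custom_vjp` lets us attach a hand-written gradient (the one Stan-math already implements) to an FFI call instead of letting JAX trace through it. The FFI target names below are placeholders, assuming registrations like the one above; the point is just the shape of the plumbing.

```python
import jax
import jax.numpy as jnp

def _ffi(name, x):
    # Thin helper around an (assumed already registered) elementwise FFI target.
    return jax.ffi.ffi_call(name, jax.ShapeDtypeStruct(x.shape, x.dtype))(x)

@jax.custom_vjp
def stan_lgamma(x):
    return _ffi("stan_lgamma", x)

def _fwd(x):
    # Save the input so the backward pass can reuse Stan-math's derivative.
    return stan_lgamma(x), x

def _bwd(x, g):
    # d/dx lgamma(x) = digamma(x); this is exactly the custom derivative
    # Stan-math already implements, so JAX never has to re-derive it.
    return (g * _ffi("stan_digamma", x),)

stan_lgamma.defvjp(_fwd, _bwd)

# After this, jax.grad / jax.jit of anything built from stan_lgamma routes
# the reverse pass through Stan-math's gradient kernel.
```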

It seems like this would be a spinoff akin to melding bridgestan, Stan-math, stanc3 (to allow jnp functions in Stan code), and the JAX FFI.

@WardBrian @stevebronder @Bob_Carpenter @Stan_Development_Team

I’m not sure I follow what you’re hoping for here. Would Stan still be running its algorithms but calling JAX? Or are you basically hoping we can transpile a Stan model to something JAX can use?

I’m very interested in helping run some experimental trials. I found some discussions that might be relevant:
How do we do move parts of the autodiff stack to the GPU? · Issue #1639 · stan-dev/math
Parallel autodiff v2 - Developers - The Stan Forums
Parallel autodiff v3 - Developers - The Stan Forums


I think it would make sense to contact some JAX developers. Do we know any?

I’m jumping on the idea of Stan being modular, or just a piece of a larger ecosystem. I love the Stan language and there’s tons of great code in Stan-math, so I’m thinking about how we can reuse as much of it as possible while latching onto larger projects like JAX to leverage what they’re good at.

To be more specific, my thoughts were agnostic toward the inference algorithm. It could be the one in Stan or something else, just as bridgestan allows. What I’m interested in is reusing our math library, i.e. plugging the external C++ framework into JAX, which gives us GPU and parallelized AD while also getting access to everything else in JAX. The inference could be offloaded to JAX or stay in Stan.
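
As a cheap first experiment along those lines, we could already wrap a BridgeStan model’s log density and gradient into JAX with `jax.pure_callback` plus `jax.custom_vjp`, without writing any new C++. A rough sketch (the file paths are placeholders, and the callback runs on the host rather than the GPU, so this only demonstrates the plumbing, not the performance win):

```python
import bridgestan as bs
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # match Stan's double precision

# Placeholders for a compiled BridgeStan model and its data
model = bs.StanModel("./bernoulli_model.so", "./bernoulli.data.json")
D = model.param_unc_num()

_out_spec = (
    jax.ShapeDtypeStruct((), jnp.float64),    # log density
    jax.ShapeDtypeStruct((D,), jnp.float64),  # gradient
)

def _lp_and_grad(theta):
    lp, grad = model.log_density_gradient(np.asarray(theta, dtype=np.float64))
    return np.asarray(lp, dtype=np.float64), grad

@jax.custom_vjp
def stan_log_density(theta):
    lp, _ = jax.pure_callback(_lp_and_grad, _out_spec, theta)
    return lp

def _fwd(theta):
    lp, grad = jax.pure_callback(_lp_and_grad, _out_spec, theta)
    return lp, grad

def _bwd(grad, g):
    # Reuse the gradient Stan-math computed instead of differentiating the callback.
    return (g * grad,)

stan_log_density.defvjp(_fwd, _bwd)

# Any JAX-based sampler or optimizer can now consume the Stan model:
value_and_grad = jax.value_and_grad(stan_log_density)
```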


That’s a good idea, but I don’t know any.

Like Brian, I’m unclear on what you’re suggesting. Are you suggesting that we have JAX call our GPU code? That’s what it looks like their FFI does. I would think we would want to do the opposite, which is transpile Stan into JAX so we can use their autodiff on GPU.
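
To make the transpilation idea concrete, here’s roughly the target output I’d imagine for a trivial model with y ~ normal(mu, sigma) and a lower-bounded sigma. This is my hand-written guess at what generated code could look like, not anything stanc3 produces today.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def log_density(params_unc, y):
    # Unconstrained parameterization: mu is unbounded, sigma is exp-transformed,
    # so we add the log-Jacobian of the transform just like Stan does.
    mu, log_sigma = params_unc
    sigma = jnp.exp(log_sigma)
    lp = jnp.sum(norm.logpdf(y, mu, sigma))
    return lp + log_sigma  # log |d sigma / d log_sigma|

# JAX then gives us the gradient on CPU or GPU for free.
grad_log_density = jax.jit(jax.grad(log_density))
```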

The JAX devs are an email or issue post away. We’ve been in contact over expanding Inference Gym. One of them, Dan Foreman-Mackey (also the person behind the emcee package in Python that everyone uses in astro), was at Flatiron until recently. And we know lots of adjacent folks on the TensorFlow Probability team.
