FYI: mcx, numpyro: new NUTS implementations in Python using jax

Seems like at least two groups are implementing NUTS using JAX. (JAX seems to be the new favorite among Python automatic differentiation libraries.)

In case anyone is keeping track, these should be added to the list of Python implementations. PyMC3 is already on the list, and there are probably a few more.

I missed Oryx. (It’s buried inside TensorFlow Probability.)

As far as I can see it has no NUTS implementation, but it does have HMC using JAX.

MCX author here. The PyMC devs, the NumPyro devs, and I are currently discussing pulling the HMC/NUTS implementations out of our respective libraries and moving them to a separate, modeling-language-agnostic repository.
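
To make "modeling-language agnostic" concrete, here is a minimal sketch of a plain HMC kernel in JAX whose only contact with the model is a log-density function. The names `hmc_step` and `sample`, the fixed step size and the fixed number of leapfrog steps are my own illustrative choices, not the API of MCX, NumPyro, PyMC, or any planned shared library. NUTS would replace the fixed trajectory length with an adaptively chosen one.

```python
import jax
import jax.numpy as jnp

def hmc_step(key, position, logdensity_fn, step_size, num_steps):
    """One HMC transition (illustrative). The kernel only ever calls `logdensity_fn`."""
    grad_fn = jax.grad(logdensity_fn)
    key_momentum, key_accept = jax.random.split(key)
    momentum = jax.random.normal(key_momentum, position.shape)

    def leapfrog(_, state):
        # Half momentum step, full position step, half momentum step.
        q, p = state
        p = p + 0.5 * step_size * grad_fn(q)
        q = q + step_size * p
        p = p + 0.5 * step_size * grad_fn(q)
        return q, p

    q_new, p_new = jax.lax.fori_loop(0, num_steps, leapfrog, (position, momentum))

    # Metropolis correction with H(q, p) = -log pi(q) + |p|^2 / 2.
    h_old = -logdensity_fn(position) + 0.5 * jnp.sum(momentum ** 2)
    h_new = -logdensity_fn(q_new) + 0.5 * jnp.sum(p_new ** 2)
    accept = jnp.log(jax.random.uniform(key_accept)) < h_old - h_new
    return jnp.where(accept, q_new, position)

def sample(key, init, logdensity_fn, num_samples, step_size=0.2, num_steps=10):
    """Run a chain of `num_samples` HMC transitions with `jax.lax.scan`."""
    keys = jax.random.split(key, num_samples)

    def body(position, k):
        position = hmc_step(k, position, logdensity_fn, step_size, num_steps)
        return position, position

    _, draws = jax.lax.scan(body, init, keys)
    return draws

# Usage: a 2-D standard normal target, written as an unnormalized log density.
logdensity = lambda x: -0.5 * jnp.sum(x ** 2)
draws = sample(jax.random.PRNGKey(0), jnp.zeros(2), logdensity, num_samples=2000)
print(draws.mean(axis=0), draws.std(axis=0))
```

Because the kernel takes nothing but a callable returning a log density, any front end (MCX, NumPyro, PyMC) could in principle compile its model to such a function and hand it over.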

that syntax tho

sigma <~ dist.Exponential(lmbda)

reminds me of the spaceship operator <=> in C++.

Hi, I am looking for a JAX package that implements NUTS.

Between mcx and numpyro, do you know which one is more mature for use right now?

If using a JAX-based PPL is important to you, I would choose NumPyro.