FYI: mcx, numpyro: new NUTS implementations in Python using jax

Seems like at least two groups are implementing NUTS using JAX. (JAX seems to be the new favorite among Python automatic differentiation libraries.)

In case anyone is keeping track, these should be added to the list of Python implementations. pymc3 is already on the list, and there are probably a few more.


I missed oryx. (It’s buried inside TensorFlow Probability.)

No NUTS implementation as far as I can see, but it does have HMC using JAX.
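
For anyone who hasn't looked at what these libraries actually implement, here is a rough sketch of plain HMC (fixed trajectory length, no NUTS-style adaptation) on a 1-D standard normal. It is dependency-free Python with a hand-written gradient, not oryx's or anyone else's API; the JAX-based libraries would obtain the gradient by automatic differentiation instead. All function names and the tuning values (`step_size`, `num_steps`) are made up for illustration.

```python
import math
import random


def neg_log_density(q):
    # Target: standard normal, up to an additive constant.
    return 0.5 * q * q


def grad_neg_log_density(q):
    # Hand-written gradient; an autodiff library would derive this.
    return q


def hmc_step(q, rng, step_size=0.2, num_steps=10):
    """One HMC transition: sample momentum, leapfrog, accept or reject."""
    p = rng.gauss(0.0, 1.0)
    q_new, p_new = q, p

    # Leapfrog integration: half momentum step, alternating full steps,
    # then a final half momentum step.
    p_new -= 0.5 * step_size * grad_neg_log_density(q_new)
    for i in range(num_steps):
        q_new += step_size * p_new
        if i < num_steps - 1:
            p_new -= step_size * grad_neg_log_density(q_new)
    p_new -= 0.5 * step_size * grad_neg_log_density(q_new)

    # Metropolis correction based on the change in total energy.
    h_old = neg_log_density(q) + 0.5 * p * p
    h_new = neg_log_density(q_new) + 0.5 * p_new * p_new
    if rng.random() < math.exp(min(0.0, h_old - h_new)):
        return q_new
    return q


def run_chain(num_samples=2000, seed=0):
    rng = random.Random(seed)
    q = 0.0
    samples = []
    for _ in range(num_samples):
        q = hmc_step(q, rng)
        samples.append(q)
    return samples


samples = run_chain()
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(mean, var)  # should land near 0 and 1 for a standard normal
```

NUTS replaces the fixed `num_steps` with a trajectory length chosen on the fly, which is the part the libraries above are each reimplementing.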


MCX author here. The PyMC devs, the numpyro devs, and I are currently discussing moving the HMC/NUTS implementations out of our respective libraries and into a separate, modeling-language-agnostic repository.


that syntax tho

sigma <~ dist.Exponential(lmbda)

reminds me of the spaceship operator in C++.
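
Side note on that line: unlike `<=>` in C++, `<~` is not a single operator as far as the standard Python parser is concerned. A quick check with the standard-library `ast` module shows it is read as a `<` comparison whose right-hand side has a unary `~` applied. This only shows what stock Python sees; how MCX interprets the line is not covered here.

```python
import ast

# Parse the quoted line as an ordinary Python expression.
tree = ast.parse("sigma <~ dist.Exponential(lmbda)", mode="eval")
node = tree.body

# `<~` is tokenized as `<` followed by unary `~`.
print(type(node).__name__)                     # Compare
print(type(node.ops[0]).__name__)              # Lt
print(type(node.comparators[0]).__name__)      # UnaryOp
print(type(node.comparators[0].op).__name__)   # Invert
```

So the line is syntactically valid Python, which is presumably what makes it usable as modeling syntax without a custom parser.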
