It seems at least two groups are implementing NUTS (the No-U-Turn Sampler) using JAX. (JAX seems to be the new favorite among Python automatic differentiation libraries.)
- mcx, https://github.com/rlouf/mcx
- numpyro, https://github.com/pyro-ppl/numpyro (a rewrite of Pyro to use JAX, I think)
In case anyone is keeping track, these should be added to the list of Python implementations. PyMC3 is already on the list, and there are probably a few more too.