The lightweight probabilistic programming library NumPyro provides a NumPy backend for Pyro (ascl:2110.016). It relies on JAX for automatic differentiation and for JIT compilation to GPU or CPU. The code aims to provide a flexible substrate for users to build on, including Pyro primitives; inference algorithms, with a particular focus on MCMC methods such as Hamiltonian Monte Carlo; and distribution classes, constraints, and bijective transforms. NumPyro also provides effect handlers that can be extended to implement custom inference algorithms and inference utilities.
Please see citation information here: https://github.com/pyro-ppl/numpyro