From the website of Pyro (the PyTorch-based probabilistic programming library):
We’re excited to announce the release of NumPyro, a NumPy-backed Pyro using JAX for automatic differentiation and JIT compilation, with over 100x speedup for HMC and NUTS!
My questions:
- Where exactly does NumPyro's performance gain over Pyro (sometimes 340x, sometimes only 2x) come from?
- More importantly, why (or rather, where) would I continue to use Pyro?
Extra:
- How do NumPyro's performance and features compare to those of TensorFlow Probability, and how should that guide which one I use for a given task?