NumPyro vs Pyro: Why is the former 100x faster, and when should I use the latter?

From the Pyro (PyTorch) website:

We’re excited to announce the release of NumPyro, a NumPy-backed Pyro using JAX for automatic differentiation and JIT compilation, with over 100x speedup for HMC and NUTS!

My questions:

  1. Where exactly does NumPyro's performance gain over Pyro (reported as anywhere from 2x to 340x) come from?
  2. And more importantly, why (or rather, where) would I continue to use Pyro?

Extra:

  1. How should I weigh the performance and features of NumPyro against TensorFlow Probability when deciding which to use where?
Stern answered 17/5, 2020 at 3:32 Comment(2)
It looks like some of the speed improvements might have come from fixing edge cases (for NUTS at least): PR #131 - Test some edge examples from Pyro, which originated from the Pyro Forum NUTS discussion. The details are most likely in their paper, Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro. - Albertoalberts
@MichaelJungo It could be cherry-picking, but the paper discusses an HMM where NumPyro is 340x faster (versus 2x on logistic regression). The main reason cited is its better JIT compilation, which is somewhat corroborated by this (referencing O(1) vs O(N) here). - Stern
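To make the JIT point above concrete, here is a minimal sketch in JAX of the kind of inner loop HMC relies on. It is not NumPyro's actual implementation: the names `potential` and `leapfrog`, the standard normal target, and the step size are all assumptions chosen for illustration. The idea is that `jax.jit` together with `lax.fori_loop` lets XLA compile the whole trajectory as one program, instead of dispatching every arithmetic operation from Python as an eager framework would.

```python
import jax
import jax.numpy as jnp
from jax import lax


def potential(q):
    # Hypothetical target: negative log-density of a standard normal
    # (up to an additive constant).
    return 0.5 * jnp.sum(q ** 2)


grad_potential = jax.grad(potential)


def leapfrog(q, p, step_size, num_steps):
    # One HMC trajectory of num_steps leapfrog updates, written with
    # lax.fori_loop so the loop itself is part of the compiled program.
    def body(_, state):
        q, p = state
        p = p - 0.5 * step_size * grad_potential(q)  # half step in momentum
        q = q + step_size * p                        # full step in position
        p = p - 0.5 * step_size * grad_potential(q)  # half step in momentum
        return q, p

    return lax.fori_loop(0, num_steps, body, (q, p))


# Compile the whole trajectory once; later calls with the same input
# shapes and dtypes reuse the compiled code.
leapfrog_jit = jax.jit(leapfrog)

q0 = jnp.ones(3)
p0 = jnp.array([0.5, -0.5, 0.25])
q1, p1 = leapfrog_jit(q0, p0, 0.1, 20)
```

How much this helps presumably depends on the model: when each step is cheap and there are many steps (as in NUTS on a small model), Python overhead per operation dominates and compiling the loop matters most, which would be consistent with the large spread between the HMM and logistic regression numbers quoted above.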

That's a good question. I just asked the same thing in Pyro's dedicated forum. Here's the answer from one of their core developers: "There are many cool stuffs in Pyro that do not appear in NumPyro, for example, see Contributed code section in Pyro docs. For me, while developing, it is much easier to debug PyTorch code than Jax code (though Jax team has put much effort to help debugging in recent releases). Hence to implement a new inference algorithm, it is easier for me to work in Pyro."
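As one illustration of the debugging point in that quote (my own sketch, not something from the forum answer), code under `jax.jit` is traced once and then replayed, so ordinary Python side effects such as `print` or appending to a list do not run on every call. The function name `scaled_sum` and the `trace_log` list are hypothetical; `jax.debug.print` is the tool JAX provides for seeing runtime values inside compiled code.

```python
import jax
import jax.numpy as jnp

trace_log = []


@jax.jit
def scaled_sum(x):
    # This Python side effect runs only while JAX traces the function,
    # not on every call to the compiled version.
    trace_log.append("traced")
    # jax.debug.print is staged into the compiled program, so it reports
    # the actual runtime value on each call.
    jax.debug.print("runtime sum = {s}", s=jnp.sum(x))
    return 2.0 * jnp.sum(x)


a = scaled_sum(jnp.arange(3.0))        # first call: traces and compiles
b = scaled_sum(jnp.arange(3.0) + 1.0)  # same shape and dtype: reuses compiled code
```

After both calls `trace_log` holds a single entry, which is the kind of behavior that can make stepping through JAX code feel less direct than stepping through eager PyTorch code.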

Propulsion answered 30/8, 2020 at 20:58 Comment(1)
It seems that JAX is the lightest automatic differentiation library out there, so most of these improvements might center on HMC/NUTS rather than variational inference. PyTorch, conversely, builds computational graphs that might be a little bloated for efficient HMC/NUTS but potentially offer better SVI functionality. So NumPyro might be ideal for traditional Bayesian statistics, whereas Pyro might be ideal for Bayesian ML, Bayesian NNs, etc. Plus, the researchers innovating on these models tend to have DL experience, so PyTorch isn't necessarily a hindrance to exploring the library. - Julianejuliann
