Support multiple tensor backends via NEP-47 (Python array API standard) #1083
Comments
I quickly browsed through the NumPyro source code to locate JAX-heavy code. For example, the most frequently used APIs include the following (embedded snippets at commit f8f482a: lines 104 to 110, lines 198 to 200, and lines 764 to 767):
from jax.scipy.linalg import cho_solve, solve_triangular
from jax.scipy.special import betainc, expit, gammaln, logit, multigammaln, ndtr, ndtri
All of the above should be relatively easy to implement, using NEP-47 core APIs plus a few framework-specific utilities (MXNet, CuPy, ChainerX, etc.). The NumPy frontend can be lowered to a variety of IRs like …
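For illustration, here is a minimal sketch (not actual NumPyro code; the helper name is made up) of what "NEP-47 core APIs plus a few framework-specific utilities" could look like: special functions such as gammaln sit outside the array API standard, so each backend would need a small shim, while the rest could go through standard operations.

```python
# Illustrative sketch only: gammaln is not part of the array API standard,
# so a backend-specific import is needed; the arithmetic below uses nothing
# beyond ordinary operators.
try:
    from jax.scipy.special import gammaln  # JAX backend
except ImportError:
    from scipy.special import gammaln      # NumPy/SciPy fallback

def log_beta(a, b):
    # log B(a, b), built from the one backend-specific primitive above.
    return gammaln(a) + gammaln(b) - gammaln(a + b)
```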
Hi @learning-chip, thanks for the suggestion. This seems like an issue best resolved at the level of PyTorch, JAX, and TensorFlow, none of which are fully compatible with NEP-47 yet (to the best of my knowledge, and as far as I can tell from the discussion in data-apis/array-api#1). An additional complication is that only these three frameworks (again, to the best of my knowledge) have both full-featured automatic differentiation and a distributions library with full broadcasting semantics and reparameterized samplers, and there are already large bodies of functionally similar idiomatic PyTorch and TensorFlow code in Pyro and TFP respectively, so it's not clear how much users might benefit from any attempts to support other backends in NumPyro. There was some discussion about …
This is correct, at this moment. My question is: once those frameworks become fully compatible with NEP-47, what would be the amount of effort to add them as new backends for NumPyro? In terms of … There seem to be more and more HPC/AI frameworks providing NumPy-like APIs. Nvidia has recently open-sourced Legate NumPy, and DaCe is another framework that excels at optimizing HPC kernels and runs on FPGAs.
I recall that both MXNet and ONNX are interested in the Python Data API. Maybe @rgommers and @szha are the right people to ask.
Broadcasting is a key feature in the Python array API standard. For "reparameterized samplers", could you elaborate more on the exact functionality required?
From my hardware systems background, I think a big difference between those AI frameworks is their compiler & hardware support. Their software functionalities are indeed getting similar -- all providing autodiff, AI model zoos, distributed training, etc. But the compile chains are quite different: TensorFlow -> XLA -> TPU and ONNX -> MicroTVM -> IoT/edge devices are some unique examples. If Bayesian AI models get popular in autonomous vehicles (I think they already are), then you might need MicroTVM for edge deployment, a case that NumPyro+JAX does not currently support.
@learning-chip thanks for the ping and for bringing up the array API standard. The work on adopting the standard is in progress for MXNet and ONNX. Once array libraries finish implementing compatibility with the standard, it should indeed be straightforward to switch array backends, as long as the implementation only relies on the operations defined in the standard.
@learning-chip to be clear, Pyro and NumPyro rely crucially on advanced features of PyTorch and JAX (especially higher-order forward- and reverse-mode AD and a large library of probability distributions with TensorFlow-style shape semantics and reparameterized samplers) that are outside the scope of NEP-47. As far as I know, only PyTorch, JAX and TensorFlow implement all of the relevant features, and there are many small differences across each framework's implementation and API that would make unifying them difficult and unpleasant. Unless the maintainers of these frameworks plan on standardizing all of these features as well, it's unlikely that we will be able to make our existing codebases in Pyro and NumPyro backend-independent via NEP-47. Like most open source projects, we are a small team and do not have the developer bandwidth to reimplement and standardize these features ourselves, or to refactor thousands of lines of backend-specific test code in Pyro and NumPyro, although we certainly support the high-level goal of a backend-independent array programming ecosystem. Barring significant new external contributions, our efforts at backend independence will probably remain restricted for the time being to the …
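To make the "reparameterized samplers" point concrete, here is a minimal sketch (assuming JAX; this is not NumPyro's actual implementation): the sample is written as a deterministic, differentiable function of the distribution parameters and an independent noise source, which depends on jax.random and jax.grad, neither of which the array API standard covers.

```python
import jax
import jax.numpy as jnp

def sample_normal(key, loc, scale):
    # eps ~ N(0, 1) does not depend on loc/scale, so the sample is a
    # differentiable function of the parameters (reparameterization trick).
    eps = jax.random.normal(key, jnp.shape(loc))
    return loc + scale * eps

key = jax.random.PRNGKey(0)
# Differentiate through sampling -- this needs an RNG and autodiff,
# both outside the scope of the array API standard.
grad_wrt_loc = jax.grad(lambda loc: sample_normal(key, loc, 1.0) ** 2)(0.5)
```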
I would bet on XLA and TVM adding more backend support, and even achieving some degree of interoperability, as a near-term path to this sort of thing before the higher-level software ecosystem adopts NEP-47 en masse. But if you or anyone reading this thread know of users who want to deploy existing Pyro/NumPyro models but can't because of missing hardware backend support, please tell them to contact us!
Thanks, this is very useful information. I will read that paper carefully, including this doc: …
I totally understand and agree -- I just need a list of "missing features" (beyond NEP-47) that other frameworks would have to support if they want to be plugged in as NumPyro backends.
This sounds like a very reasonable target to me.
XLA and TVM could be a thinner layer than NEP-47 & the Python array API, indeed. The path of interoperability is unclear to me, though. See the Relay MLIR Frontend discussions. XLA is moving towards MLIR-HLO, which is a significant & long-term change in my opinion (MLIR is a huge beast). Thus I would not expect near-term improvements 😂
The recently proposed NEP-47 attempts to unify the APIs of various tensor frameworks (NumPy, TensorFlow, PyTorch, Dask, JAX, CuPy, MXNet, etc.) via the Python array API standard.
It is a much more compact version of the original NumPy API, removing unnecessary functions that are not friendly to heterogeneous hardware like GPUs.
Since NumPyro uses JAX as its backend, whose APIs closely match NumPy's, shouldn't it be quite doable to adopt NEP-47 for multi-backend support?
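As a rough sketch of the mechanism NEP-47 standardizes (assuming the input array comes from a conforming library, e.g. the experimental numpy.array_api module shipped with some NumPy releases), code retrieves the namespace from the array itself and uses only functions defined by the standard:

```python
def logsumexp(x):
    # Hypothetical backend-agnostic helper: only standard functions are used,
    # so any conforming array library could run it unchanged.
    xp = x.__array_namespace__()  # defined by the array API standard
    m = xp.max(x, axis=-1, keepdims=True)
    return xp.squeeze(m, axis=-1) + xp.log(xp.sum(xp.exp(x - m), axis=-1))
```

Any backend implementing the standard could run such a function as-is; the difficulty discussed above is everything NumPyro needs beyond this core (autodiff, RNG, distributions).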
Related discussion: