
Support multiple tensor backends via NEP-47 (Python array API standard) #1083

Open
learning-chip opened this issue Jul 4, 2021 · 6 comments


The recently proposed NEP-47 attempts to unify the APIs of various tensor frameworks (NumPy, Tensorflow, PyTorch, Dask, JAX, CuPy, MXNet, etc.), via the Python array API standard.

It is a much more compact version of the original NumPy API, removing functions that are unfriendly to heterogeneous hardware such as GPUs.


Since NumPyro uses JAX as its backend, and JAX's API closely matches NumPy's, shouldn't it be quite doable to adopt NEP-47 for multi-backend support?
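For illustration, here is a minimal sketch of my own (not NumPyro code) of what backend-agnostic library code looks like under the standard: the array itself exposes its namespace via __array_namespace__, so the library never imports a specific framework.

def standardize(x):
    # Retrieve the array's namespace, as specified by the array API standard.
    # Any conforming backend (NumPy, CuPy, JAX, ...) can flow through here.
    xp = x.__array_namespace__()
    return (x - xp.mean(x)) / xp.std(x)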

Related discussion:


learning-chip commented Jul 4, 2021

I quickly browsed through the NumPyro source code to locate the JAX-heavy parts. For example, the most frequently used APIs include:

  • numpyro.sample
  • numpyro.infer.{MCMC, NUTS}
  • numpyro.distributions.{Normal, Exponential}
  1. sample is defined in numpyro/primitives.py, relying on jax.random.{randint, choice, PRNGKey} and other small ops.

def sample(
    name, fn, obs=None, rng_key=None, sample_shape=(), infer=None, obs_mask=None
):
    """
    Returns a random sample from the stochastic function `fn`. This can have
    additional side effects when wrapped inside effect handlers like
    :class:`~numpyro.handlers.substitute`.
    """
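For context, a minimal usage sketch (assuming the current NumPyro/JAX APIs): a sample site needs a PRNG key, which is typically supplied by an effect handler such as seed rather than passed in directly.

import numpyro
import numpyro.distributions as dist
from numpyro.handlers import seed

def model():
    # each sample site draws from a distribution object
    return numpyro.sample("x", dist.Normal(0.0, 1.0))

# seed supplies the jax.random.PRNGKey that sample relies on internally
seeded_model = seed(model, rng_seed=0)
print(seeded_model())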

  2. infer.{MCMC, NUTS} are defined in numpyro/infer/{mcmc.py, hmc.py}, relying on JAX-specific utils like jit, vmap, pmap, tree_util.tree_map, and probably grad.

class MCMC(object):
    """
    Provides access to Markov Chain Monte Carlo inference algorithms in NumPyro.
    """

class NUTS(HMC):
    """
    Hamiltonian Monte Carlo inference, using the No U-Turn Sampler (NUTS)
    with adaptive path length and mass matrix adaptation.
    """
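A typical entry point, for reference (a sketch of my own reusing the model above; the JAX machinery is invoked internally):

from jax import random
from numpyro.infer import MCMC, NUTS

# NUTS differentiates the model's log-density via grad; MCMC uses
# jit (and vmap/pmap across multiple chains) under the hood.
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0))
mcmc.print_summary()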

  3. distributions.{Normal, Exponential} are defined in numpyro/distributions/continuous.py, relying on JAX's linear solvers and special functions:
from jax.scipy.linalg import cho_solve, solve_triangular
from jax.scipy.special import betainc, expit, gammaln, logit, multigammaln, ndtr, ndtri
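For instance (a small check of my own, assuming the current NumPyro API), the standard normal CDF matches jax.scipy.special.ndtr:

import jax.numpy as jnp
from jax.scipy.special import ndtr
import numpyro.distributions as dist

x = jnp.linspace(-2.0, 2.0, 5)
d = dist.Normal(0.0, 1.0)
# the standard normal CDF coincides with the special function ndtr
print(jnp.allclose(d.cdf(x), ndtr(x)))  # True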

All of the above should be relatively easy to implement using the NEP-47 core APIs plus a few framework-specific utilities (for MXNet, CuPy, ChainerX, etc.). The NumPy frontend can then be lowered to a variety of IRs like tvm.relay and MLIR, which support a diverse set of hardware targets.

eb8680 (Member) commented Jul 4, 2021

Hi @learning-chip, thanks for the suggestion. This seems like an issue best resolved at the level of PyTorch, JAX, and TensorFlow, none of which are fully compatible with NEP-47 yet (to the best of my knowledge, and as far as I can tell from the discussion in data-apis/array-api#1). An additional complication is that only these three frameworks (again, to the best of my knowledge) have both full-featured automatic differentiation and a distributions library with full broadcasting semantics and reparameterized samplers. There are also already large bodies of functionally similar idiomatic PyTorch and TensorFlow code in Pyro and TFP respectively, so it's not clear how much users might benefit from any attempts to support other backends in NumPyro.

There was some discussion about __array_function__ in JAX in jax-ml/jax#1565. You may also be interested in our less comprehensive attempts at sharing code/interfaces across backends in Funsor (mostly for inference code) and pyro-api (mostly for model code).

learning-chip (Author) commented Jul 5, 2021

none of which are fully compatible with NEP-47 yet

This is correct at the moment. My question is: once those frameworks become fully compatible with NEP-47, how much effort would it take to add them as new backends for NumPyro, in terms of missing features to support, lines of code, or person-months, for example? This will affect whether a framework's developer team should design its own Bayesian learning library or simply reuse (Num)Pyro.

More and more HPC/AI frameworks seem to be providing a NumPy-like API. Nvidia has recently open-sourced Legate NumPy, and DaCe is another framework that excels at optimizing HPC kernels and runs on FPGAs.

only these three frameworks (again, to the best of my knowledge) have both full-featured automatic differentiation

I recall that both MXNet and ONNX are interested in the Python Data API. Maybe @rgommers and @szha are the right people to ask.

full broadcasting semantics and reparameterized samplers

Broadcasting is a key feature of the Python array API standard. As for "reparameterized samplers", could you elaborate on the exact functionality?

how much users might benefit from any attempts to support other backends in NumPyro.

Coming from a hardware-systems background, I think a big difference between those AI frameworks is their compiler and hardware support. Their software functionalities are indeed getting similar: all provide autodiff, AI model zoos, distributed training, etc. But the compilation chains are quite different: TensorFlow -> XLA -> TPU and ONNX -> MicroTVM -> IoT/edge devices are two distinctive examples. If Bayesian AI models become popular in autonomous vehicles (I think they already are), then you might need MicroTVM for edge deployment, a case that NumPyro+JAX does not currently support.

szha commented Jul 5, 2021

@learning-chip thanks for the ping and for bringing up the array API standard. Work on adopting the standard is in progress for both MXNet and ONNX. Once array libraries finish implementing compatibility with the standard, it should indeed be straightforward to switch array backends, as long as the implementation relies only on the operations defined in the standard.

eb8680 (Member) commented Jul 6, 2021

@learning-chip to be clear, Pyro and NumPyro rely crucially on advanced features of PyTorch and JAX (especially higher-order forward- and reverse-mode AD and a large library of probability distributions with TensorFlow-style shape semantics and reparameterized samplers) that are outside the scope of NEP-47.
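To make those two requirements concrete, here is a small illustrative sketch of my own (assuming NumPyro's current distributions API), neither of which is covered by NEP-47:

import jax
import jax.numpy as jnp
import numpyro.distributions as dist

# TensorFlow-style shape semantics: a batch of 3 bivariate normals,
# with batch_shape and event_shape tracked separately
d = dist.MultivariateNormal(jnp.zeros((3, 2)), jnp.eye(2))
print(d.batch_shape, d.event_shape)  # (3,) (2,)

# Reparameterized sampler: the draw is a differentiable function of the
# parameters, so gradients flow through sampling (d(sample)/d(loc) = 1
# for a location family)
def draw(loc):
    return dist.Normal(loc, 1.0).sample(jax.random.PRNGKey(0))

print(jax.grad(draw)(0.0))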

As far as I know, only PyTorch, JAX and TensorFlow implement all of the relevant features, and there are many small differences across each framework's implementation and API that would make unifying them difficult and unpleasant. Unless the maintainers of these frameworks plan on standardizing all of these features as well, it's unlikely that we will be able to make our existing codebases in Pyro and NumPyro backend-independent via NEP-47.

Like most open source projects, we are a small team and do not have the developer bandwidth to reimplement and standardize these features ourselves or refactor thousands of lines of backend-specific test code in Pyro and NumPyro, although we certainly support the high-level goal of a backend-independent array programming ecosystem. Barring significant new external contributions, our efforts at backend-independence will probably remain restricted for the time being to the pyro-api and funsor libraries I pointed out above, which are also better targets for future NEP-47 support.

I think a big difference between those AI frameworks, is their compiler & hardware support ... you might need MicroTVM for edge deployment, a case that NumPyro+JAX does not currently support.

I would bet on XLA and TVM adding more backend support, and even achieving some degree of interoperability, as a near-term path to this sort of thing before the higher-level software ecosystem adopts NEP-47 en masse. But if you or anyone reading this thread knows of users who want to deploy existing Pyro/NumPyro models but can't because of missing hardware backend support, please tell them to contact us!

learning-chip (Author) commented

especially higher-order forward- and reverse-mode AD and a large library of probability distributions with Tensorflow-style shape semantics and reparameterized samplers

Thanks, this is very useful information. I will read that paper carefully, including this doc:
https://www.tensorflow.org/probability/examples/TensorFlow_Distributions_Tutorial

Unless the maintainers of these frameworks plan on standardizing all of these features as well, it's unlikely that we will be able to make our existing codebases in Pyro and NumPyro backend-independent via NEP-47.

I totally understand and agree -- I just need a list of the "missing features" (beyond NEP-47) that other frameworks would have to support in order to be plugged in as NumPyro backends.

Barring significant new external contributions, our efforts at backend-independence will probably remain restricted for the time being to the pyro-api and funsor libraries I pointed out above, which are also better targets for future NEP-47 support.

This sounds like a very reasonable target to me.

I would bet on XLA and TVM adding more backend support and even achieving some degree of interoperability as a near-term path to this sort of thing before than the higher-level software ecosystem adopts NEP-47 en masse

Indeed, XLA and TVM could be a thinner layer than NEP-47 and the Python array API. The path to interoperability is unclear to me, though; see the Relay MLIR Frontend discussions.

XLA is moving towards MLIR-HLO, which is a significant & long-term change in my opinion (MLIR is a huge beast). Thus I would not expect near-term improvements 😂
