-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement overrides of NumPy's public API on JAX arrays #611
Conversation
`__array_ufunc__` allows for writing NumPy's ufuncs, e.g., `onp.sin()`. `__array_function__` is a new, experimental override for most other functions in NumPy public API, e.g., `onp.concatenate()`. It will be enabled by default in NumPy 1.17, but is also available in NumPy 1.16 if you set the environment variable `NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1` before importing NumPy. Together, these should allow users to stick with `import numpy as np` for use with JAX, instead of requiring `import jax.numpy as np`. I expect this will be particularly useful for projects that want to remain implementation agnostic, e.g., so they can write functions that will run without changes on JAX, CuPy and Dask arrays. Note: if you want to test this out in Colab, I think you need to install the development version of NumPy (e.g., `pip install -U git+https://github.com/numpy/numpy.git`). As far as I can tell, it isn't possible to set an environment variable from Colab before importing NumPy.
I added |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is still a bit short on test coverage. But the neural net training example from the README now works with import numpy as np
!
NumPy 1.17 is out, so these overrides will work by default now. It would be nice merge this soonish, if only so I don't need to continue to rebase :). Also I can't wait to stop writing |
Thanks, Stephan! Would we still need to More importantly, this would change the behavior for anyone using Those might be good things, but I just want to make sure I’m understanding. It would be a nontrivial api change. |
Yes, that would work. More practically, the easy way to ensure functions get executed with JAX is to use
That's right, this would be a breaking change for such code. Instead, users will need to write For the most part I expect this will should be fine -- JAX has mostly equivalent implementations of most commonly used NumPy functions. One noteworthy case are functions like |
__array_ufunc__
allows for writing NumPy's ufuncs, e.g.,onp.sin()
.__array_function__
is a new, experimental override for most other functions in NumPy public API, e.g.,onp.concatenate()
. It will be enabled by default in NumPy 1.17, but is also available in NumPy 1.16 if you set the environment variableNUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1
before importing NumPy.Together, these should allow users to stick with
import numpy as np
for use with JAX, instead of requiringimport jax.numpy as np
. I expect this will be particularly useful for projects that want to remain implementation agnostic, e.g., so they can write functions that will run without changes on JAX, CuPy and Dask arrays.Note: if you want to test this out in Colab, I think you need to install the development version of NumPy (e.g.,
pip install -U git+https://github.com/numpy/numpy.git
). As far as I can tell, it isn't possible to set an environment variable from Colab before importing NumPy.