Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement overrides of NumPy's public API on JAX arrays #611

Closed
wants to merge 37 commits into from

Conversation

shoyer
Copy link
Collaborator

@shoyer shoyer commented Apr 13, 2019

__array_ufunc__ allows for writing NumPy's ufuncs, e.g., onp.sin().

__array_function__ is a new, experimental override for most other functions in NumPy public API, e.g., onp.concatenate(). It will be enabled by default in NumPy 1.17, but is also available in NumPy 1.16 if you set the environment variable NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1 before importing NumPy.

Together, these should allow users to stick with import numpy as np for use with JAX, instead of requiring import jax.numpy as np. I expect this will be particularly useful for projects that want to remain implementation agnostic, e.g., so they can write functions that will run without changes on JAX, CuPy and Dask arrays.

Note: if you want to test this out in Colab, I think you need to install the development version of NumPy (e.g., pip install -U git+https://github.com/numpy/numpy.git). As far as I can tell, it isn't possible to set an environment variable from Colab before importing NumPy.

`__array_ufunc__` allows for writing NumPy's ufuncs, e.g., `onp.sin()`.

`__array_function__` is a new, experimental override for most other functions
in NumPy public API, e.g., `onp.concatenate()`. It will be enabled by
default in NumPy 1.17, but is also available in NumPy 1.16 if you set the
environment variable `NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1` before importing
NumPy.

Together, these should allow users to stick with `import numpy as np` for use
with JAX, instead of requiring `import jax.numpy as np`. I expect this will
be particularly useful for projects that want to remain implementation
agnostic, e.g., so they can write functions that will run without changes on
JAX, CuPy and Dask arrays.

Note: if you want to test this out in Colab, I think you need to install
the development version of NumPy (e.g.,
`pip install -U git+https://github.com/numpy/numpy.git`). As far as I can
tell, it isn't possible to set an environment variable from Colab before
importing NumPy.
@shoyer
Copy link
Collaborator Author

shoyer commented Apr 13, 2019

I added __array_ufunc__ and __array_function__ to Tracer UnshapedArray, but this could still use an integration test. Any suggestions on where I should put that?

Copy link
Collaborator Author

@shoyer shoyer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is still a bit short on test coverage. But the neural net training example from the README now works with import numpy as np!

jax/core.py Show resolved Hide resolved
@mattjj mattjj mentioned this pull request Jun 19, 2019
@shoyer
Copy link
Collaborator Author

shoyer commented Aug 2, 2019

NumPy 1.17 is out, so these overrides will work by default now.

It would be nice merge this soonish, if only so I don't need to continue to rebase :). Also I can't wait to stop writing import jax.numpy as np!

@mattjj
Copy link
Collaborator

mattjj commented Aug 2, 2019

Thanks, Stephan!

Would we still need to import jax.numpy as np to use op-by-op JAX execution on standard numpy ndarrays? In that case I guess users will need to use more explicit jax.device_put calls (if they don’t import jax.numpy as before.

More importantly, this would change the behavior for anyone using import numpy as onp: where that could previously be used to convert to host-side ndarrays, now it might not, meaning users would need more explicit jax.device_get calls.

Those might be good things, but I just want to make sure I’m understanding. It would be a nontrivial api change.

@shoyer
Copy link
Collaborator Author

shoyer commented Aug 2, 2019

Would we still need to import jax.numpy as np to use op-by-op JAX execution on standard numpy ndarrays? In that case I guess users will need to use more explicit jax.device_put calls (if they don’t import jax.numpy as before.

Yes, that would work. More practically, the easy way to ensure functions get executed with JAX is to use jit.

More importantly, this would change the behavior for anyone using import numpy as onp: where that could previously be used to convert to host-side ndarrays, now it might not, meaning users would need more explicit jax.device_get calls.

That's right, this would be a breaking change for such code. Instead, users will need to write onp.array() explicitly to convert into NumPy arrays. It might be worth testing this change on our internal codebase to gauge its impact.

For the most part I expect this will should be fine -- JAX has mostly equivalent implementations of most commonly used NumPy functions.

One noteworthy case are functions like np.unique() that exist in NumPy but not JAX. In the current version of this PR, we raise an exception when you call these functions on JAX arrays. Instead, it would probably be more user-friendly to call jax.device_get() to load them into memory and issue a warning about functions not implemented by JAX.

@shoyer
Copy link
Collaborator Author

shoyer commented Oct 24, 2019

@mattjj and I decided not to merge this until/unless we find someone who really needs it. Please speak up in #1565 if that's you.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants