Overrides of NumPy functions on JAX arrays #1565

Open
shoyer opened this issue Oct 24, 2019 · 12 comments
Labels
enhancement New feature or request

Comments

@shoyer
Collaborator

shoyer commented Oct 24, 2019

NumPy has protocols, based on the __array_ufunc__ and __array_function__ methods, that allow for overriding what NumPy functions like np.sin() and np.concatenate() do when called on other array types.

In practice, this means users can write import numpy as np to get NumPy functions that work on JAX arrays instead of needing to write import jax.numpy as np.

It might make sense to implement these methods on JAX's array objects. A working prototype of this can be found in #611.
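
As a rough illustration of how the two protocols are wired up, the sketch below defines a hypothetical `Wrapped` class (not part of NumPy or JAX) that implements both methods by unwrapping its arguments, calling the ordinary NumPy implementation, and re-wrapping the result. A real implementation on JAX arrays would presumably dispatch to `jax.numpy` instead.

```python
import numpy as np


class Wrapped:
    """Hypothetical duck array: holds a NumPy array and re-wraps every result."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Called for ufuncs such as np.sin or np.add.
        if method != "__call__":
            return NotImplemented  # reductions etc. are left out of this sketch
        return Wrapped(ufunc(*_unwrap(inputs), **kwargs))

    def __array_function__(self, func, types, args, kwargs):
        # Called for non-ufunc functions such as np.concatenate.
        return Wrapped(func(*_unwrap(args), **_unwrap(kwargs)))


def _unwrap(obj):
    """Recursively replace Wrapped instances with their underlying arrays."""
    if isinstance(obj, Wrapped):
        return obj.data
    if isinstance(obj, (list, tuple)):
        return type(obj)(_unwrap(item) for item in obj)
    if isinstance(obj, dict):
        return {key: _unwrap(value) for key, value in obj.items()}
    return obj


x = Wrapped([0.0, 1.0])
y = np.sin(x)               # dispatches to Wrapped.__array_ufunc__
z = np.concatenate([x, y])  # dispatches to Wrapped.__array_function__
print(type(y).__name__, type(z).__name__, z.data.shape)
```

Both calls use plain `import numpy as np`, yet the results stay in the wrapper type, which is the behaviour the issue proposes for JAX arrays.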

Reasons to do this:

  • This would make it possible to write generic code that works with NumPy/JAX/Dask/sparse/whatever, at least in simple cases: you can just use import numpy as np and it will probably work. This is particularly advantageous for third-party libraries (e.g., for projects like opt-einsum or xarray) that want to support multiple backends in a clean, composable way.
  • By opting into NumPy's API, JAX gets an override API "for free". This could be useful even if all you care about is supporting operations on JAX arrays. For example: you could write a library that wraps JAX and adds PyTorch 1.3 style named tensors.
  • JAX's JIT compilation allows for powerful "zero-cost abstraction" like C++ but in Python. There are projects like xarray that could potentially make use of this in a really compelling way, e.g., you could write a simulation with labeled multi-dimensional arrays with unit checking, without any extra performance cost!
  • More generally: it's a nice integration point with the third-party SciPy/PyData ecosystem. There's assuredly loads of other cool stuff you could do with it.

Reasons not to do this:

  • This breaks existing code that relies upon NumPy functions coercing arguments to NumPy arrays. Large projects using JAX will probably need to add some explicit calls to onp.asarray(). #611 (Implement overrides of NumPy's public API on JAX arrays) includes a handful of examples of this internally in JAX.
  • The implementation is rather complex and a little fragile, especially if we want to accommodate a flag that allows for switching it on and off. This imposes an additional maintenance burden on the JAX team.
  • We don't yet have any concrete examples of end-user use-cases for this functionality. It would let us easily wrap JAX with xarray, but what would that be good for?
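
The first point above can be reproduced without JAX. In the sketch below, `Coerced` and `Overriding` are hypothetical stand-ins for a JAX array before and after the change: both convert to NumPy through `__array__`, but only the second implements `__array_function__`. Code that assumed the result of a NumPy function is an `ndarray` then needs an explicit `np.asarray()` (written `onp.asarray()` inside JAX).

```python
import numpy as np


class Coerced:
    """Stand-in for today's behaviour: NumPy functions coerce it to ndarray."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array__(self, dtype=None, copy=None):
        return np.asarray(self.data, dtype=dtype)


class Overriding(Coerced):
    """Stand-in for the proposed behaviour: NumPy functions return the duck type."""

    def __array_function__(self, func, types, args, kwargs):
        # Minimal handler: only list/tuple arguments are unwrapped.
        plain = [
            [np.asarray(a) for a in arg] if isinstance(arg, (list, tuple)) else arg
            for arg in args
        ]
        return Overriding(func(*plain, **kwargs))


old = np.concatenate([Coerced([1, 2]), Coerced([3])])
new = np.concatenate([Overriding([1, 2]), Overriding([3])])

print(type(old).__name__)  # ndarray
print(type(new).__name__)  # Overriding

# Downstream code that needs a real ndarray must now coerce explicitly.
fixed = np.asarray(new)
print(fixed.tolist())
```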

Decision by @mattjj and myself: We're not going to merge this yet, because it's not clear that anyone would even use it and it imposes a maintenance burden.

If you have compelling use-cases, please speak up. We could relatively easily make this happen, but would need someone who could commit to being a passionate user first.

@lukasheinrich

lukasheinrich commented Mar 6, 2020

Adherence to NEP 13 and NEP 18 would make it easier to integrate jax into projects that rely on them for portability. Specifically, we're looking to integrate jax with scale-out systems like dask and particle physics libraries like https://github.com/scikit-hep/awkward-array. @jpivarski can probably comment better on the technical details, but we'd very much be a passionate user :)

@Hoeze

Hoeze commented Apr 14, 2020

I love the idea of xarray with jax as the backend... Would be so awesome!
Also, it's quite unfortunate that Tensorflow/Jax/... all have different APIs compared to numpy.

@sursu
Contributor

sursu commented Apr 21, 2020

An example:

import numpy as np
from scipy import stats

N = lambda x: stats.norm.cdf(x)

def test(a, b):
    return N((b-a)/np.sqrt(a))

Jake's function (in the issue mentioned above), being meant only for illustrative purposes, allows me to @jaxify only the test function. That function calls N, which does not use jax.scipy.stats, and therefore I get an error if I try to compute the grad.

Would it be possible to override all the numpy and scipy instances from within the function I want to differentiate and all other methods being called from within this main function?

@cranmer

cranmer commented May 28, 2020

In the context of a large software effort for the LHC (http://iris-hep.org) we are discussing this as @lukasheinrich mentioned above. We have jagged arrays and we have been able to override ufunc to allow numpy to run over our data structures. We would like to be able to do this with Jax.

@lukasheinrich

as a minimal example this should work

pip install jax jaxlib numpy awkward
python
>>> import awkward1
>>> import numpy as np
>>> import jax.numpy as jnp
>>> a = awkward1.from_iter([[1,2,3],[],[4,5]])
>>> np.power(a,2)
<Array [[1, 4, 9], [], [16, 25]] type='3 * var * int64'>
>>> jnp.power(a,2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/lukasheinrich/Code/analysis/rpv1l-downstram/venv/lib/python3.7/site-packages/jax/numpy/lax_numpy.py", line 532, in power
    return lax.integer_pow(x1, x2)
  File "/Users/lukasheinrich/Code/analysis/rpv1l-downstram/venv/lib/python3.7/site-packages/jax/lax/lax.py", line 265, in integer_pow
    return integer_pow_p.bind(x, y=y)
  File "/Users/lukasheinrich/Code/analysis/rpv1l-downstram/venv/lib/python3.7/site-packages/jax/core.py", line 211, in bind
    return self.impl(*args, **kwargs)
  File "/Users/lukasheinrich/Code/analysis/rpv1l-downstram/venv/lib/python3.7/site-packages/jax/interpreters/xla.py", line 217, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *map(arg_spec, args), **params)
  File "/Users/lukasheinrich/Code/analysis/rpv1l-downstram/venv/lib/python3.7/site-packages/jax/interpreters/xla.py", line 209, in arg_spec
    aval = abstractify(x)
  File "/Users/lukasheinrich/Code/analysis/rpv1l-downstram/venv/lib/python3.7/site-packages/jax/interpreters/xla.py", line 159, in abstractify
    raise TypeError(f"No abstraction handler for type: {type(x)}")
TypeError: No abstraction handler for type: <class 'awkward1.highlevel.Array'>

the error message suggests that there are pluggable "abstraction handlers". If there is a well-defined protocol we could maybe implement one for awkward1.highlevel.Array arrays

shoyer added a commit to shoyer/jax that referenced this issue Aug 15, 2020
xref jax-ml#1565

`__array_module__` (see [NEP 37](https://numpy.org/neps/nep-0037-array-module.html))
is an experimental alternative to `__array_function__` and `__array_ufunc__`
for "duck array" compatibility with NumPy that promises to be much less
invasive.

Example usage:

```python
import numpy as np

def duckarray_stack(arrays):
    """This "stack" function should work with any array library, including JAX."""
    npx = np.get_array_module(*arrays)
    arrays = [npx.asarray(arr) for arr in arrays]
    shapes = {arr.shape for arr in arrays}
    if len(shapes) != 1:
        raise ValueError('all input arrays must have the same shape')
    expanded_arrays = [arr[npx.newaxis, ...] for arr in arrays]
    return npx.concatenate(expanded_arrays, axis=0)
```

Support for this protocol has *not* yet been implemented in NumPy, but it can
be tested with https://github.com/seberg/numpy-dispatch.

My reasoning for merging it into JAX (on an experimental basis with no
guarantees, of course) is that:

1. It's not invasive -- the implementation is small and self-contained.
2. No backwards compatibility issues. Unlike `__array_function__` and
   `__array_ufunc__`, `__array_module__` will always require an explicit
   opt-in by libraries that use it by calling `get_array_module()`.
3. Other NumPy developers
   [want evidence](numpy/numpy#16935 (comment))
   that this is actually feasible.
4. Scikit-Learn developers like @thomasjpfan are interested in exploring
   supporting scikit-learn on top of NumPy-like libraries like JAX, and
   experimental support for this protocol will make that easier.

Note: this PR does add `numpy-dispatch` as an optional testing requirement in
order to verify that this works. If desired, we could remove this from CI, but
installing numpy-dispatch (and its build requirement Cython) appears to only
add a few seconds of build time.
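
For readers without `numpy-dispatch` installed, the dispatch logic can be approximated in a few lines. The sketch below is a simplified reading of the NEP 37 lookup, not the code in this PR: `get_array_module` here is a local function, and `MyArray` and `my_module` are hypothetical names standing in for a duck-array library.

```python
from types import SimpleNamespace

import numpy as np


class MyArray:
    """Hypothetical duck array wrapping a NumPy array."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_module__(self, types):
        # Claim the call only if every participating type is understood.
        if all(issubclass(t, (MyArray, np.ndarray)) for t in types):
            return my_module
        return NotImplemented


def _asarray(x):
    return x if isinstance(x, MyArray) else MyArray(x)


def _concatenate(arrays, axis=0):
    return MyArray(np.concatenate([_asarray(a).data for a in arrays], axis=axis))


# Stand-in for a NumPy-like namespace such as jax.numpy.
my_module = SimpleNamespace(asarray=_asarray, concatenate=_concatenate)


def get_array_module(*arrays, default=np):
    """Simplified local version of the lookup described in NEP 37."""
    implementing = [a for a in arrays if hasattr(type(a), "__array_module__")]
    if not implementing:
        return default  # plain NumPy arrays, lists, scalars, ...
    types = frozenset(
        type(a) for a in arrays
        if hasattr(type(a), "__array_module__") or isinstance(a, np.ndarray)
    )
    for array in implementing:
        module = array.__array_module__(types)
        if module is not NotImplemented:
            return module
    raise TypeError("no common array module found")


npx = get_array_module(MyArray([1, 2]), np.array([3]))
out = npx.concatenate([MyArray([1, 2]), np.array([3])])
print(npx is my_module, out.data.tolist())
```

The opt-in is visible here: nothing changes for code that calls `np.concatenate` directly, and only callers that ask for the module through `get_array_module` get the duck-array implementation.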
jakevdp pushed a commit that referenced this issue Aug 18, 2020
* Add experimental __array_module__ method

xref #1565

* don't explicitly list cython

* remove UnshapedArray from _JAX_ARRAY_TYPES

* Remove incorrect note about metaclasses

* remove unnecessary numpy_dispatch.ensure_dispatching()
@mhlr

mhlr commented Sep 7, 2020

If this automated, or at least simplified, porting scikit to JAX, this would be huge!

@peterdsharpe

peterdsharpe commented May 5, 2022

Edit: nevermind this comment! I updated JAX to find that __array_module__ has been implemented. Thank you!

@shoyer
Collaborator Author

shoyer commented May 11, 2022

> Edit: nevermind this comment! I updated JAX to find that __array_module__ has been implemented. Thank you!

JAX has __array_module__, but I don't think NEP 37 is ever going to be accepted. NEP 47 (__array_namespace__ / array API standard) has much more momentum behind it, e.g., a PyTorch implementation.
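
For comparison, consuming the array API standard looks roughly like this. The sketch assumes only that an array exposes `__array_namespace__()`; the `namespace_of` helper and its NumPy fallback are illustrative additions, since older NumPy arrays do not define the method.

```python
import numpy as np


def namespace_of(x):
    """Return the array API namespace for x (hypothetical helper)."""
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    return np  # assumption: fall back to NumPy for objects without the protocol


def standardize(x):
    """Shift and scale x to zero mean and unit standard deviation."""
    xp = namespace_of(x)
    x = xp.asarray(x)
    return (x - xp.mean(x)) / xp.std(x)


result = standardize(np.array([1.0, 2.0, 3.0, 4.0]))
print(result)
```

Any array type that implements the protocol could be passed to `standardize` unchanged, which is the kind of portability the earlier comments in this thread ask for.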

@raj-magesh

I'm curious if NEP 47 is supported (or planned) for JAX. It would be nice to transparently use xarray over Jax primitives.

@NeilGirdhar
Contributor

@raj-magesh #18353

@raj-magesh

That's excellent, thank you! Looks like it's shaping up brilliantly. I'm especially happy that the linear algebra primitives are almost all done!

@NeilGirdhar
Contributor

@raj-magesh I'm excited too! The Jax team are finishing it so fast.
