-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add basic JAX support #84
Changes from 4 commits
2ee6902
12b5294
6bf5dad
583f6bb
9c8bed6
6d59ae8
ce07cd9
ddb313e
6004b97
701a5ef
049d557
db667ea
aafbbaa
919ec41
fa758f7
6c338ca
244462f
bff9bf2
264e6c3
e7aff0f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,8 +9,9 @@ | |
|
||
import sys | ||
import math | ||
import inspect | ||
|
||
def _is_numpy_array(x): | ||
def is_numpy_array(x): | ||
# Avoid importing NumPy if it isn't already | ||
if 'numpy' not in sys.modules: | ||
return False | ||
|
@@ -20,7 +21,7 @@ def _is_numpy_array(x): | |
# TODO: Should we reject ndarray subclasses? | ||
return isinstance(x, (np.ndarray, np.generic)) | ||
|
||
def _is_cupy_array(x): | ||
def is_cupy_array(x): | ||
# Avoid importing NumPy if it isn't already | ||
if 'cupy' not in sys.modules: | ||
return False | ||
|
@@ -30,7 +31,7 @@ def _is_cupy_array(x): | |
# TODO: Should we reject ndarray subclasses? | ||
return isinstance(x, (cp.ndarray, cp.generic)) | ||
|
||
def _is_torch_array(x): | ||
def is_torch_array(x): | ||
# Avoid importing torch if it isn't already | ||
if 'torch' not in sys.modules: | ||
return False | ||
|
@@ -40,7 +41,7 @@ def _is_torch_array(x): | |
# TODO: Should we reject ndarray subclasses? | ||
return isinstance(x, torch.Tensor) | ||
|
||
def _is_dask_array(x): | ||
def is_dask_array(x): | ||
# Avoid importing dask if it isn't already | ||
if 'dask.array' not in sys.modules: | ||
return False | ||
|
@@ -49,14 +50,24 @@ def _is_dask_array(x): | |
|
||
return isinstance(x, dask.array.Array) | ||
|
||
def is_jax_array(x): | ||
# Avoid importing jax if it isn't already | ||
if 'jax' not in sys.modules: | ||
return False | ||
|
||
import jax.numpy | ||
|
||
return isinstance(x, jax.numpy.ndarray) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
def is_array_api_obj(x): | ||
""" | ||
Check if x is an array API compatible array object. | ||
""" | ||
return _is_numpy_array(x) \ | ||
or _is_cupy_array(x) \ | ||
or _is_torch_array(x) \ | ||
or _is_dask_array(x) \ | ||
return is_numpy_array(x) \ | ||
or is_cupy_array(x) \ | ||
or is_torch_array(x) \ | ||
or is_dask_array(x) \ | ||
or is_jax_array(x) \ | ||
or hasattr(x, '__array_namespace__') | ||
|
||
def _check_api_version(api_version): | ||
|
@@ -81,37 +92,44 @@ def your_function(x, y): | |
""" | ||
namespaces = set() | ||
for x in xs: | ||
if _is_numpy_array(x): | ||
if is_numpy_array(x): | ||
_check_api_version(api_version) | ||
if _use_compat: | ||
from .. import numpy as numpy_namespace | ||
namespaces.add(numpy_namespace) | ||
else: | ||
import numpy as np | ||
namespaces.add(np) | ||
elif _is_cupy_array(x): | ||
elif is_cupy_array(x): | ||
_check_api_version(api_version) | ||
if _use_compat: | ||
from .. import cupy as cupy_namespace | ||
namespaces.add(cupy_namespace) | ||
else: | ||
import cupy as cp | ||
namespaces.add(cp) | ||
elif _is_torch_array(x): | ||
elif is_torch_array(x): | ||
_check_api_version(api_version) | ||
if _use_compat: | ||
from .. import torch as torch_namespace | ||
namespaces.add(torch_namespace) | ||
else: | ||
import torch | ||
namespaces.add(torch) | ||
elif _is_dask_array(x): | ||
elif is_dask_array(x): | ||
_check_api_version(api_version) | ||
if _use_compat: | ||
from ..dask import array as dask_namespace | ||
namespaces.add(dask_namespace) | ||
else: | ||
raise TypeError("_use_compat cannot be False if input array is a dask array!") | ||
elif is_jax_array(x): | ||
_check_api_version(api_version) | ||
# jax.numpy is already an array namespace, but requires this | ||
# side-effecting import for __array_namespace__ and some other | ||
# things to be defined. | ||
import jax.experimental.array_api as jnp | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will this import go away at some point? Should we guard against that? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not true that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, I fixed the comment. The question still remains though? Should I add a guard here like try:
import jax.experimental.array_api as jnp
except ImportError:
import jax.numpy as jnp for a future JAX version when There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I think that's probably a reasonable way to future-proof this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The downside of that particular line of code is it will also pass through jax.numpy for older JAX versions that don't have jax.experimental.array_api. I like to avoid explicit version checks in this library if I can, but maybe that's the best thing to do here. Or maybe we can just change the logic to this once JAX starts to remove (deprecates?) the experimental import. |
||
namespaces.add(jnp) | ||
elif hasattr(x, '__array_namespace__'): | ||
namespaces.add(x.__array_namespace__(api_version=api_version)) | ||
else: | ||
|
@@ -156,8 +174,17 @@ def device(x: "Array", /) -> "Device": | |
out: device | ||
a ``device`` object (see the "Device Support" section of the array API specification). | ||
""" | ||
if _is_numpy_array(x): | ||
if is_numpy_array(x): | ||
return "cpu" | ||
if is_jax_array(x): | ||
# JAX has .device() as a method, but it is being deprecated so that it | ||
# can become a property, in accordance with the standard. In order for | ||
# this function to not break when JAX makes the flip, we check for | ||
# both here. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this logic seem OK? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks OK, but it will not work with the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The other problem here is that, in general, JAX arrays can live on multiple devices (in which case There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know the answer to that. I would bring it up on the array API repo. https://github.com/data-apis/array-api/. As far as I know it hasn't really been discussed. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like both forms work >>> import jax.experimental.array_api as xp
>>> x = xp.asarray([1, 2, 3])
>>> x.to_device(x.device)
Array([1, 2, 3], dtype=int32)
>>> x.to_device(x.device())
Array([1, 2, 3], dtype=int32) |
||
if inspect.ismethod(x.device): | ||
return x.device() | ||
else: | ||
return x.device | ||
return x.device | ||
|
||
# Based on cupy.array_api.Array.to_device | ||
|
@@ -204,6 +231,12 @@ def _torch_to_device(x, device, /, stream=None): | |
raise NotImplementedError | ||
return x.to(device) | ||
|
||
def _jax_to_device(x, device, /, stream=None): | ||
import jax | ||
if stream is not None: | ||
raise NotImplementedError | ||
return jax.device_put(x, device) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this helper function for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, |
||
|
||
def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, Any]]" = None) -> "Array": | ||
""" | ||
Copy the array from the device on which it currently resides to the specified ``device``. | ||
|
@@ -225,24 +258,26 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A | |
.. note:: | ||
If ``stream`` is given, the copy operation should be enqueued on the provided ``stream``; otherwise, the copy operation should be enqueued on the default stream/queue. Whether the copy is performed synchronously or asynchronously is implementation-dependent. Accordingly, if synchronization is required to guarantee data safety, this must be clearly explained in a conforming library's documentation. | ||
""" | ||
if _is_numpy_array(x): | ||
if is_numpy_array(x): | ||
if stream is not None: | ||
raise ValueError("The stream argument to to_device() is not supported") | ||
if device == 'cpu': | ||
return x | ||
raise ValueError(f"Unsupported device {device!r}") | ||
elif _is_cupy_array(x): | ||
elif is_cupy_array(x): | ||
# cupy does not yet have to_device | ||
return _cupy_to_device(x, device, stream=stream) | ||
elif _is_torch_array(x): | ||
elif is_torch_array(x): | ||
return _torch_to_device(x, device, stream=stream) | ||
elif _is_dask_array(x): | ||
elif is_dask_array(x): | ||
if stream is not None: | ||
raise ValueError("The stream argument to to_device() is not supported") | ||
# TODO: What if our array is on the GPU already? | ||
if device == 'cpu': | ||
return x | ||
raise ValueError(f"Unsupported device {device!r}") | ||
elif is_jax_array(x): | ||
return _jax_to_device(x, device, stream=stream) | ||
return x.to_device(device, stream=stream) | ||
|
||
def size(x): | ||
|
@@ -253,4 +288,6 @@ def size(x): | |
return None | ||
return math.prod(x.shape) | ||
|
||
__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device', 'to_device', 'size'] | ||
__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device', | ||
'to_device', 'size', 'is_numpy_array', 'is_cupy_array', | ||
'is_torch_array', 'is_dask_array', 'is_jax_array'] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, | ||
is_dask_array, is_jax_array, is_array_api_obj) | ||
|
||
from ._helpers import import_ | ||
|
||
import pytest | ||
|
||
is_functions = { | ||
'numpy': 'is_numpy_array', | ||
'cupy': 'is_cupy_array', | ||
'torch': 'is_torch_array', | ||
'dask.array': 'is_dask_array', | ||
'jax.numpy': 'is_jax_array', | ||
} | ||
|
||
@pytest.mark.parametrize('library', is_functions.keys()) | ||
@pytest.mark.parametrize('func', is_functions.values()) | ||
def test_is_xp_array(library, func): | ||
lib = import_(library) | ||
is_func = globals()[func] | ||
|
||
x = lib.asarray([1, 2, 3]) | ||
|
||
assert is_func(x) == (func == is_functions[library]) | ||
|
||
assert is_array_api_obj(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jakevdp, I could mostly use your review for the changes in this file.