-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add basic JAX support #84
Changes from 12 commits
2ee6902
12b5294
6bf5dad
583f6bb
9c8bed6
6d59ae8
ce07cd9
ddb313e
6004b97
701a5ef
049d557
db667ea
aafbbaa
919ec41
fa758f7
6c338ca
244462f
bff9bf2
264e6c3
e7aff0f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,8 +9,9 @@ | |
|
||
import sys | ||
import math | ||
import inspect | ||
|
||
def _is_numpy_array(x): | ||
def is_numpy_array(x): | ||
# Avoid importing NumPy if it isn't already | ||
if 'numpy' not in sys.modules: | ||
return False | ||
|
@@ -20,7 +21,7 @@ def _is_numpy_array(x): | |
# TODO: Should we reject ndarray subclasses? | ||
return isinstance(x, (np.ndarray, np.generic)) | ||
|
||
def _is_cupy_array(x): | ||
def is_cupy_array(x): | ||
# Avoid importing NumPy if it isn't already | ||
if 'cupy' not in sys.modules: | ||
return False | ||
|
@@ -30,7 +31,7 @@ def _is_cupy_array(x): | |
# TODO: Should we reject ndarray subclasses? | ||
return isinstance(x, (cp.ndarray, cp.generic)) | ||
|
||
def _is_torch_array(x): | ||
def is_torch_array(x): | ||
# Avoid importing torch if it isn't already | ||
if 'torch' not in sys.modules: | ||
return False | ||
|
@@ -40,7 +41,7 @@ def _is_torch_array(x): | |
# TODO: Should we reject ndarray subclasses? | ||
return isinstance(x, torch.Tensor) | ||
|
||
def _is_dask_array(x): | ||
def is_dask_array(x): | ||
# Avoid importing dask if it isn't already | ||
if 'dask.array' not in sys.modules: | ||
return False | ||
|
@@ -49,14 +50,24 @@ def _is_dask_array(x): | |
|
||
return isinstance(x, dask.array.Array) | ||
|
||
def is_jax_array(x):
    """
    Return True if `x` is a JAX array, False otherwise.

    JAX is never imported by this check: unless a module named ``jax`` is
    already present in ``sys.modules``, False is returned immediately.
    """
    # Avoid importing jax if it isn't already
    if 'jax' not in sys.modules:
        return False

    import jax

    # sys.modules only proves that *something* named "jax" has been imported.
    # If that module does not expose ``Array`` (for example an unrelated
    # module shadowing the name), x cannot be a JAX array, so answer False
    # rather than raising AttributeError. Only this one attribute lookup is
    # guarded; any other error still propagates.
    array_type = getattr(jax, 'Array', None)
    if array_type is None:
        return False

    return isinstance(x, array_type)
Should this be more guarded? e.g. what if someone has a module named […]
I don't know. That sort of thing tends to just break everything anyway. I've never really felt that libraries should protect against that sort of thing. Anyway, the whole point of this function is to be guarded. It won't import […]
More precisely, it won't import […] A similar issue exists for every other package name referenced in this module. My feeling is: it costs virtually nothing to wrap this all in an appropriate […]
I'm still not so sure this is a good idea. Usually if you have that sort of thing it will be an error for a lot of things, not just array_api_compat. My worry here is that guarding isn't as straightforward as it might seem. Wrapping everything in try/except could mean we end up silencing legitimate errors.
OK, sounds good. |
||
|
||
def is_array_api_obj(x): | ||
""" | ||
Check if x is an array API compatible array object. | ||
""" | ||
return _is_numpy_array(x) \ | ||
or _is_cupy_array(x) \ | ||
or _is_torch_array(x) \ | ||
or _is_dask_array(x) \ | ||
return is_numpy_array(x) \ | ||
or is_cupy_array(x) \ | ||
or is_torch_array(x) \ | ||
or is_dask_array(x) \ | ||
or is_jax_array(x) \ | ||
or hasattr(x, '__array_namespace__') | ||
|
||
def _check_api_version(api_version): | ||
|
@@ -81,37 +92,43 @@ def your_function(x, y): | |
""" | ||
namespaces = set() | ||
for x in xs: | ||
if _is_numpy_array(x): | ||
if is_numpy_array(x): | ||
_check_api_version(api_version) | ||
if _use_compat: | ||
from .. import numpy as numpy_namespace | ||
namespaces.add(numpy_namespace) | ||
else: | ||
import numpy as np | ||
namespaces.add(np) | ||
elif _is_cupy_array(x): | ||
elif is_cupy_array(x): | ||
_check_api_version(api_version) | ||
if _use_compat: | ||
from .. import cupy as cupy_namespace | ||
namespaces.add(cupy_namespace) | ||
else: | ||
import cupy as cp | ||
namespaces.add(cp) | ||
elif _is_torch_array(x): | ||
elif is_torch_array(x): | ||
_check_api_version(api_version) | ||
if _use_compat: | ||
from .. import torch as torch_namespace | ||
namespaces.add(torch_namespace) | ||
else: | ||
import torch | ||
namespaces.add(torch) | ||
elif _is_dask_array(x): | ||
elif is_dask_array(x): | ||
_check_api_version(api_version) | ||
if _use_compat: | ||
from ..dask import array as dask_namespace | ||
namespaces.add(dask_namespace) | ||
else: | ||
raise TypeError("_use_compat cannot be False if input array is a dask array!") | ||
elif is_jax_array(x): | ||
_check_api_version(api_version) | ||
# jax.experimental.array_api is already an array namespace. We do | ||
# not have a wrapper submodule for it. | ||
import jax.experimental.array_api as jnp | ||
Will this import go away at some point? Should we guard against that?
It's not true that […]
OK, I fixed the comment. The question still remains, though: should I add a guard here like

try:
    import jax.experimental.array_api as jnp
except ImportError:
    import jax.numpy as jnp

for a future JAX version when […]
Yes, I think that's probably a reasonable way to future-proof this.
The downside of that particular line of code is that it will also pass through jax.numpy for older JAX versions that don't have jax.experimental.array_api. I like to avoid explicit version checks in this library if I can, but maybe that's the best thing to do here. Or maybe we can just change the logic to this once JAX starts to remove (deprecate?) the experimental import. |
||
namespaces.add(jnp) | ||
elif hasattr(x, '__array_namespace__'): | ||
namespaces.add(x.__array_namespace__(api_version=api_version)) | ||
else: | ||
|
@@ -136,7 +153,7 @@ def _check_device(xp, device): | |
if device not in ["cpu", None]: | ||
raise ValueError(f"Unsupported device for NumPy: {device!r}") | ||
|
||
# device() is not on numpy.ndarray and and to_device() is not on numpy.ndarray | ||
# device() is not on numpy.ndarray and to_device() is not on numpy.ndarray | ||
# or cupy.ndarray. They are not included in array objects of this library | ||
# because this library just reuses the respective ndarray classes without | ||
# wrapping or subclassing them. These helper functions can be used instead of | ||
|
@@ -156,8 +173,17 @@ def device(x: "Array", /) -> "Device": | |
out: device | ||
a ``device`` object (see the "Device Support" section of the array API specification). | ||
""" | ||
if _is_numpy_array(x): | ||
if is_numpy_array(x): | ||
return "cpu" | ||
if is_jax_array(x): | ||
# JAX has .device() as a method, but it is being deprecated so that it | ||
# can become a property, in accordance with the standard. In order for | ||
# this function to not break when JAX makes the flip, we check for | ||
# both here. | ||
Does this logic seem OK?
This looks OK, but it will not work with the […]
The other problem here is that, in general, JAX arrays can live on multiple devices (in which case […])
I don't know the answer to that. I would bring it up on the array API repo: https://github.com/data-apis/array-api/. As far as I know it hasn't really been discussed.
It looks like both forms work:
>>> import jax.experimental.array_api as xp
>>> x = xp.asarray([1, 2, 3])
>>> x.to_device(x.device)
Array([1, 2, 3], dtype=int32)
>>> x.to_device(x.device())
Array([1, 2, 3], dtype=int32) |
||
if inspect.ismethod(x.device): | ||
return x.device() | ||
else: | ||
return x.device | ||
return x.device | ||
|
||
# Based on cupy.array_api.Array.to_device | ||
|
@@ -225,24 +251,30 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A | |
.. note:: | ||
If ``stream`` is given, the copy operation should be enqueued on the provided ``stream``; otherwise, the copy operation should be enqueued on the default stream/queue. Whether the copy is performed synchronously or asynchronously is implementation-dependent. Accordingly, if synchronization is required to guarantee data safety, this must be clearly explained in a conforming library's documentation. | ||
""" | ||
if _is_numpy_array(x): | ||
if is_numpy_array(x): | ||
if stream is not None: | ||
raise ValueError("The stream argument to to_device() is not supported") | ||
if device == 'cpu': | ||
return x | ||
raise ValueError(f"Unsupported device {device!r}") | ||
elif _is_cupy_array(x): | ||
elif is_cupy_array(x): | ||
# cupy does not yet have to_device | ||
return _cupy_to_device(x, device, stream=stream) | ||
elif _is_torch_array(x): | ||
elif is_torch_array(x): | ||
return _torch_to_device(x, device, stream=stream) | ||
elif _is_dask_array(x): | ||
elif is_dask_array(x): | ||
if stream is not None: | ||
raise ValueError("The stream argument to to_device() is not supported") | ||
# TODO: What if our array is on the GPU already? | ||
if device == 'cpu': | ||
return x | ||
raise ValueError(f"Unsupported device {device!r}") | ||
elif is_jax_array(x): | ||
# This import adds to_device to x | ||
import jax.experimental.array_api | ||
if device == 'cpu': | ||
device = jax.devices('cpu')[0] | ||
return x.to_device(device, stream=stream) | ||
return x.to_device(device, stream=stream) | ||
|
||
def size(x): | ||
|
@@ -253,4 +285,6 @@ def size(x): | |
return None | ||
return math.prod(x.shape) | ||
|
||
__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device', 'to_device', 'size'] | ||
__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device', | ||
'to_device', 'size', 'is_numpy_array', 'is_cupy_array', | ||
'is_torch_array', 'is_dask_array', 'is_jax_array'] |
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, | ||
is_dask_array, is_jax_array, is_array_api_obj, | ||
device, to_device) | ||
|
||
from ._helpers import import_ | ||
|
||
import pytest | ||
import numpy as np | ||
from numpy.testing import assert_allclose | ||
|
||
# Importable name of each array library, paired with the name of the
# ``is_*_array`` helper that is expected to recognise that library's arrays.
# Both the keys and the values are used to parametrize the tests below.
is_functions = dict([
    ('numpy', 'is_numpy_array'),
    ('cupy', 'is_cupy_array'),
    ('torch', 'is_torch_array'),
    ('dask.array', 'is_dask_array'),
    ('jax.numpy', 'is_jax_array'),
])
|
||
@pytest.mark.parametrize('library', is_functions.keys())
@pytest.mark.parametrize('func', is_functions.values())
def test_is_xp_array(library, func):
    """
    Check every ``is_*_array`` helper against an array from every library.

    A helper must return True exactly when it is the one registered for
    `library` in ``is_functions``, and the array must always be accepted by
    ``is_array_api_obj``.
    """
    xp = import_(library)
    # The helper is looked up by name among the functions imported above.
    checker = globals()[func]

    arr = xp.asarray([1, 2, 3])

    helper_matches_library = is_functions[library] == func
    assert checker(arr) == helper_matches_library

    assert is_array_api_obj(arr)
|
||
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
def test_device(library):
    """
    Round-trip an array through ``to_device(x, device(x))``.

    The result must report the same device as the input.
    """
    if library == "dask.array":
        pytest.xfail("device() needs to be fixed for dask")

    # JAX arrays are created from jax.experimental.array_api; every other
    # library is exercised through its array_api_compat.<library> namespace.
    namespace_name = (
        'jax.experimental.array_api' if library == "jax.numpy"
        else 'array_api_compat.' + library
    )
    xp = import_(namespace_name)

    # We can't test much for device() and to_device() other than that
    # x.to_device(x.device) works.

    x = xp.asarray([1, 2, 3])
    dev = device(x)

    x2 = to_device(x, dev)
    assert device(x) == device(x2)
|
||
|
||
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
def test_to_device_host(library):
    """
    Check that ``to_device(x, "cpu")`` yields data comparable with NumPy.
    """
    # Test that "cpu" device works. Note: this isn't actually supported by the
    # standard yet. See https://github.com/data-apis/array-api/issues/626.

    # Libraries differ in their semantics for device-to-host transfers; make
    # sure we support a portable shim for the common array libraries.
    # see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919
    namespace_name = (
        'jax.experimental.array_api' if library == "jax.numpy"
        else 'array_api_compat.' + library
    )
    xp = import_(namespace_name)

    expected = np.array([1, 2, 3])
    x = to_device(xp.asarray([1, 2, 3]), "cpu")
    # torch will return a genuine Device object, but the other libs will do
    # something different with a `device(x)` query; however, what's really
    # important here is that we can test portably after calling
    # to_device(x, "cpu") to return to host
    assert_allclose(x, expected)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jakevdp, I could mostly use your review for the changes in this file.