diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 71083fbc..2877bf06 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -15,7 +15,7 @@ jobs: - name: Install Dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest numpy torch dask[array] + python -m pip install pytest numpy torch dask[array] jax[cpu] - name: Run Tests run: | diff --git a/README.md b/README.md index 2c0ce59a..5be86271 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ This is a small wrapper around common array libraries that is compatible with the [Array API standard](https://data-apis.org/array-api/latest/). Currently, -NumPy, CuPy, and PyTorch are supported. If you want support for other array +NumPy, CuPy, PyTorch, Dask, and JAX are supported. If you want support for other array libraries, or if you encounter any issues, please [open an issue](https://github.com/data-apis/array-api-compat/issues). @@ -56,7 +56,17 @@ import array_api_compat.cupy as cp import array_api_compat.torch as torch ``` -Each will include all the functions from the normal NumPy/CuPy/PyTorch +```py +import array_api_compat.dask as da +``` + +> [!NOTE] +> There is no `array_api_compat.jax` submodule. JAX support is contained +> in JAX itself in the `jax.experimental.array_api` module. array-api-compat simply +> wraps that submodule. The main JAX support in this module consists of +> supporting it in the [helper functions](#helper-functions) defined below. + +Each will include all the functions from the normal NumPy/CuPy/PyTorch/dask.array namespace, except that functions that are part of the array API are wrapped so that they have the correct array API behavior. In each case, the array object used will be the same array object from the wrapped library. @@ -99,6 +109,11 @@ part of the specification but which are useful for using the array API: - `is_array_api_obj(x)`: Return `True` if `x` is an array API compatible array object. +- `is_numpy_array(x)`, `is_cupy_array(x)`, `is_torch_array(x)`, + `is_dask_array(x)`, `is_jax_array(x)`: return `True` if `x` is an array from + the corresponding library. These functions do not import the underlying + library if it has not already been imported, so they are cheap to use. + - `array_namespace(*xs)`: Get the corresponding array API namespace for the arrays `xs`. For example, if the arrays are NumPy arrays, the returned namespace will be `array_api_compat.numpy`. Note that this function will @@ -219,6 +234,12 @@ version. The minimum supported PyTorch version is 1.13. +### JAX + +Unlike the other libraries supported here, JAX array API support is contained +entirely in the JAX library. The JAX array API support is tracked at +https://github.com/google/jax/issues/18353. + ## Vendoring This library supports vendoring as an installation method. To vendor the diff --git a/array_api_compat/common/__init__.py b/array_api_compat/common/__init__.py index b941a31e..3317899b 100644 --- a/array_api_compat/common/__init__.py +++ b/array_api_compat/common/__init__.py @@ -3,6 +3,11 @@ device, get_namespace, is_array_api_obj, + is_cupy_array, + is_dask_array, + is_jax_array, + is_numpy_array, + is_torch_array, size, to_device, ) @@ -12,6 +17,11 @@ "device", "get_namespace", "is_array_api_obj", + "is_cupy_array", + "is_dask_array", + "is_jax_array", + "is_numpy_array", + "is_torch_array", "size", "to_device", ] diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index b58fb0ca..1eeb0594 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -14,7 +14,7 @@ from types import ModuleType import inspect -from ._helpers import _check_device, _is_numpy_array, array_namespace +from ._helpers import _check_device, is_numpy_array, array_namespace # These functions are modified from the NumPy versions. @@ -310,7 +310,7 @@ def _asarray( raise ValueError("Unrecognized namespace argument to asarray()") _check_device(xp, device) - if _is_numpy_array(obj): + if is_numpy_array(obj): import numpy as np if hasattr(np, '_CopyMode'): # Not present in older NumPys diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index ac866551..5e59c7ea 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -11,12 +11,13 @@ if TYPE_CHECKING: from typing import Optional, Union, Any - from ._typing import Array, Device + from ._typing import Array, Device import sys import math +import inspect -def _is_numpy_array(x): +def is_numpy_array(x): # Avoid importing NumPy if it isn't already if 'numpy' not in sys.modules: return False @@ -26,7 +27,7 @@ def _is_numpy_array(x): # TODO: Should we reject ndarray subclasses? return isinstance(x, (np.ndarray, np.generic)) -def _is_cupy_array(x): +def is_cupy_array(x): # Avoid importing NumPy if it isn't already if 'cupy' not in sys.modules: return False @@ -36,7 +37,7 @@ def _is_cupy_array(x): # TODO: Should we reject ndarray subclasses? return isinstance(x, (cp.ndarray, cp.generic)) -def _is_torch_array(x): +def is_torch_array(x): # Avoid importing torch if it isn't already if 'torch' not in sys.modules: return False @@ -46,7 +47,7 @@ def _is_torch_array(x): # TODO: Should we reject ndarray subclasses? return isinstance(x, torch.Tensor) -def _is_dask_array(x): +def is_dask_array(x): # Avoid importing dask if it isn't already if 'dask.array' not in sys.modules: return False @@ -55,14 +56,24 @@ def _is_dask_array(x): return isinstance(x, dask.array.Array) +def is_jax_array(x): + # Avoid importing jax if it isn't already + if 'jax' not in sys.modules: + return False + + import jax + + return isinstance(x, jax.Array) + def is_array_api_obj(x): """ Check if x is an array API compatible array object. """ - return _is_numpy_array(x) \ - or _is_cupy_array(x) \ - or _is_torch_array(x) \ - or _is_dask_array(x) \ + return is_numpy_array(x) \ + or is_cupy_array(x) \ + or is_torch_array(x) \ + or is_dask_array(x) \ + or is_jax_array(x) \ or hasattr(x, '__array_namespace__') def _check_api_version(api_version): @@ -87,7 +98,7 @@ def your_function(x, y): """ namespaces = set() for x in xs: - if _is_numpy_array(x): + if is_numpy_array(x): _check_api_version(api_version) if _use_compat: from .. import numpy as numpy_namespace @@ -95,7 +106,7 @@ def your_function(x, y): else: import numpy as np namespaces.add(np) - elif _is_cupy_array(x): + elif is_cupy_array(x): _check_api_version(api_version) if _use_compat: from .. import cupy as cupy_namespace @@ -103,7 +114,7 @@ def your_function(x, y): else: import cupy as cp namespaces.add(cp) - elif _is_torch_array(x): + elif is_torch_array(x): _check_api_version(api_version) if _use_compat: from .. import torch as torch_namespace @@ -111,13 +122,19 @@ def your_function(x, y): else: import torch namespaces.add(torch) - elif _is_dask_array(x): + elif is_dask_array(x): _check_api_version(api_version) if _use_compat: from ..dask import array as dask_namespace namespaces.add(dask_namespace) else: raise TypeError("_use_compat cannot be False if input array is a dask array!") + elif is_jax_array(x): + _check_api_version(api_version) + # jax.experimental.array_api is already an array namespace. We do + # not have a wrapper submodule for it. + import jax.experimental.array_api as jnp + namespaces.add(jnp) elif hasattr(x, '__array_namespace__'): namespaces.add(x.__array_namespace__(api_version=api_version)) else: @@ -142,7 +159,7 @@ def _check_device(xp, device): if device not in ["cpu", None]: raise ValueError(f"Unsupported device for NumPy: {device!r}") -# device() is not on numpy.ndarray and and to_device() is not on numpy.ndarray +# device() is not on numpy.ndarray and to_device() is not on numpy.ndarray # or cupy.ndarray. They are not included in array objects of this library # because this library just reuses the respective ndarray classes without # wrapping or subclassing them. These helper functions can be used instead of @@ -162,8 +179,17 @@ def device(x: Array, /) -> Device: out: device a ``device`` object (see the "Device Support" section of the array API specification). """ - if _is_numpy_array(x): + if is_numpy_array(x): return "cpu" + if is_jax_array(x): + # JAX has .device() as a method, but it is being deprecated so that it + # can become a property, in accordance with the standard. In order for + # this function to not break when JAX makes the flip, we check for + # both here. + if inspect.ismethod(x.device): + return x.device() + else: + return x.device return x.device # Based on cupy.array_api.Array.to_device @@ -231,24 +257,28 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] .. note:: If ``stream`` is given, the copy operation should be enqueued on the provided ``stream``; otherwise, the copy operation should be enqueued on the default stream/queue. Whether the copy is performed synchronously or asynchronously is implementation-dependent. Accordingly, if synchronization is required to guarantee data safety, this must be clearly explained in a conforming library's documentation. """ - if _is_numpy_array(x): + if is_numpy_array(x): if stream is not None: raise ValueError("The stream argument to to_device() is not supported") if device == 'cpu': return x raise ValueError(f"Unsupported device {device!r}") - elif _is_cupy_array(x): + elif is_cupy_array(x): # cupy does not yet have to_device return _cupy_to_device(x, device, stream=stream) - elif _is_torch_array(x): + elif is_torch_array(x): return _torch_to_device(x, device, stream=stream) - elif _is_dask_array(x): + elif is_dask_array(x): if stream is not None: raise ValueError("The stream argument to to_device() is not supported") # TODO: What if our array is on the GPU already? if device == 'cpu': return x raise ValueError(f"Unsupported device {device!r}") + elif is_jax_array(x): + # This import adds to_device to x + import jax.experimental.array_api # noqa: F401 + return x.to_device(device, stream=stream) return x.to_device(device, stream=stream) def size(x): diff --git a/tests/_helpers.py b/tests/_helpers.py index 69952118..e05ae86c 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -1,9 +1,20 @@ from importlib import import_module +import sys + import pytest -def import_or_skip_cupy(library): - if "cupy" in library: +def import_(library, wrapper=False): + if library == 'cupy': return pytest.importorskip(library) + if 'jax' in library and sys.version_info <= (3, 8): + pytest.skip('JAX array API support does not support Python 3.8') + + if wrapper: + if 'jax' in library: + library = 'jax.experimental.array_api' + else: + library = 'array_api_compat.' + library + return import_module(library) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 2c596d70..7aaef971 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -1,3 +1,6 @@ +import subprocess +import sys + import numpy as np import pytest import torch @@ -5,13 +8,12 @@ import array_api_compat from array_api_compat import array_namespace -from ._helpers import import_or_skip_cupy - +from ._helpers import import_ -@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"]) @pytest.mark.parametrize("api_version", [None, "2021.12"]) def test_array_namespace(library, api_version): - xp = import_or_skip_cupy(library) + xp = import_(library) array = xp.asarray([1.0, 2.0, 3.0]) namespace = array_api_compat.array_namespace(array, api_version=api_version) @@ -21,9 +23,31 @@ def test_array_namespace(library, api_version): else: if library == "dask.array": assert namespace == array_api_compat.dask.array + elif library == "jax.numpy": + import jax.experimental.array_api + assert namespace == jax.experimental.array_api else: assert namespace == getattr(array_api_compat, library) + # Check that array_namespace works even if jax.experimental.array_api + # hasn't been imported yet (it monkeypatches __array_namespace__ + # onto JAX arrays, but we should support them regardless). The only way to + # do this is to use a subprocess, since we cannot un-import it and another + # test probably already imported it. + if library == "jax.numpy" and sys.version_info >= (3, 9): + code = f"""\ +import sys +import jax.numpy +import array_api_compat +array = jax.numpy.asarray([1.0, 2.0, 3.0]) + +assert 'jax.experimental.array_api' not in sys.modules +namespace = array_api_compat.array_namespace(array, api_version={api_version!r}) + +import jax.experimental.array_api +assert namespace == jax.experimental.array_api +""" + subprocess.run([sys.executable, "-c", code], check=True) def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace([1])) diff --git a/tests/test_common.py b/tests/test_common.py index bfaf58d2..b84dfdde 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,10 +1,49 @@ -import numpy as np +from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, # noqa: F401 + is_dask_array, is_jax_array) + +from array_api_compat import is_array_api_obj, device, to_device + +from ._helpers import import_ + import pytest +import numpy as np from numpy.testing import assert_allclose -from array_api_compat import to_device +is_functions = { + 'numpy': 'is_numpy_array', + 'cupy': 'is_cupy_array', + 'torch': 'is_torch_array', + 'dask.array': 'is_dask_array', + 'jax.numpy': 'is_jax_array', +} + +@pytest.mark.parametrize('library', is_functions.keys()) +@pytest.mark.parametrize('func', is_functions.values()) +def test_is_xp_array(library, func): + lib = import_(library) + is_func = globals()[func] + + x = lib.asarray([1, 2, 3]) + + assert is_func(x) == (func == is_functions[library]) -from ._helpers import import_or_skip_cupy + assert is_array_api_obj(x) + +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"]) +def test_device(library): + if library == "dask.array": + pytest.xfail("device() needs to be fixed for dask") + + xp = import_(library, wrapper=True) + + # We can't test much for device() and to_device() other than that + # x.to_device(x.device) works. + + x = xp.asarray([1, 2, 3]) + dev = device(x) + + x2 = to_device(x, dev) + assert device(x) == device(x2) @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) @@ -13,8 +52,8 @@ def test_to_device_host(library): # for DtoH transfers; ensure that we support a portable # shim for common array libs # see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919 - xp = import_or_skip_cupy("array_api_compat." + library) - + xp = import_(library, wrapper=True) + expected = np.array([1, 2, 3]) x = xp.asarray([1, 2, 3]) x = to_device(x, "cpu") diff --git a/tests/test_isdtype.py b/tests/test_isdtype.py index c27334da..f4c245f4 100644 --- a/tests/test_isdtype.py +++ b/tests/test_isdtype.py @@ -5,7 +5,7 @@ import pytest -from ._helpers import import_or_skip_cupy +from ._helpers import import_ # Check the known dtypes by their string names @@ -64,9 +64,9 @@ def isdtype_(dtype_, kind): assert type(res) is bool # noqa: E721 return res -@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"]) def test_isdtype_spec_dtypes(library): - xp = import_or_skip_cupy('array_api_compat.' + library) + xp = import_(library, wrapper=True) isdtype = xp.isdtype @@ -98,10 +98,10 @@ def test_isdtype_spec_dtypes(library): 'bfloat16', ] -@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"]) @pytest.mark.parametrize("dtype_", additional_dtypes) def test_isdtype_additional_dtypes(library, dtype_): - xp = import_or_skip_cupy('array_api_compat.' + library) + xp = import_(library, wrapper=True) isdtype = xp.isdtype