diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 833604b9c537..a1046880cd60 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -60,7 +60,10 @@ jobs: os: ubuntu-latest enable-x64: 1 enable-omnistaging: 0 - package-overrides: "none" + # Test experimental NumPy dispatch + # TODO(shoyer): remove cython after + # https://github.com/seberg/numpy-dispatch/pull/5 is merged + package-overrides: "cython git+https://github.com/seberg/numpy-dispatch.git" num_generated_cases: 25 - python-version: 3.6 os: ubuntu-latest diff --git a/jax/core.py b/jax/core.py index 8a5576b916cb..1ff93a45a932 100644 --- a/jax/core.py +++ b/jax/core.py @@ -469,6 +469,9 @@ def __len__(self): def aval(self): raise NotImplementedError("must override") + # Python looks up special methods only on classes, not instances. This means + # these methods needs to be defined explicitly rather than relying on + # __getattr__ (short of using a metaclass). def __neg__(self): return self.aval._neg(self) def __pos__(self): return self.aval._pos(self) def __eq__(self, other): return self.aval._eq(self, other) @@ -528,6 +531,9 @@ def __complex__(self): def __setitem__(self, idx, val): raise TypeError("JAX 'Tracer' objects do not support item assignment") + # NumPy also only looks up special methods on classes. + def __array_module__(self, types): return self.aval._array_module(self, types) + def __getattr__(self, name): # if the aval property raises an AttributeError, gets caught here assert skip_checks or name != "aval" diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 6610a3ece318..c3c43a1131ad 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -36,6 +36,7 @@ import numpy as np import opt_einsum +import jax from jax import jit, custom_jvp from .vectorize import vectorize from ._util import _wraps @@ -4574,6 +4575,21 @@ def _operator_round(number, ndigits=None): setattr(DeviceArray, "nbytes", property(_nbytes)) +# Experimental support for NumPy's module dispatch with NEP-37. +# Currently requires https://github.com/seberg/numpy-dispatch +_JAX_ARRAY_TYPES = (UnshapedArray, DeviceArray, core.Tracer) +_HANDLED_ARRAY_TYPES = _JAX_ARRAY_TYPES + (np.ndarray,) + +def __array_module__(self, types): + if builtins.all(issubclass(t, _HANDLED_ARRAY_TYPES) for t in types): + return jax.numpy + else: + return NotImplemented + +setattr(ShapedArray, "_array_module", staticmethod(__array_module__)) +setattr(DeviceArray, "__array_module__", __array_module__) + + # Extra methods that are handy setattr(ShapedArray, "broadcast", core.aval_method(lax.broadcast)) setattr(ShapedArray, "broadcast_in_dim", core.aval_method(lax.broadcast_in_dim)) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 3a93e52284a9..2684b5454ce1 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -28,6 +28,10 @@ from absl.testing import parameterized import numpy as np +try: + import numpy_dispatch +except ImportError: + numpy_dispatch = None import jax import jax.ops @@ -585,6 +589,27 @@ def testBinaryOperatorDefers(self, op_name, rng_factory, dtype): with self.assertRaises(TypeError): op(arg, other) + def testArrayModule(self): + if numpy_dispatch is None: + raise SkipTest('requires https://github.com/seberg/numpy-dispatch') + + jnp_array = jnp.array(1.0) + np_array = np.array(1.0) + + with numpy_dispatch.ensure_dispatching(): + module = numpy_dispatch.get_array_module(jnp_array) + self.assertIs(module, jnp) + + module = numpy_dispatch.get_array_module(jnp_array, np_array) + self.assertIs(module, jnp) + + def f(x): + module = numpy_dispatch.get_array_module(x) + self.assertIs(module, jnp) + return x + jax.jit(f)(jnp_array) + jax.grad(f)(jnp_array) + @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( {"testcase_name": jtu.format_test_name_suffix(