Skip to content

Commit

Permalink
Ensure that Jax float0 array is recognized
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Apr 15, 2024
1 parent faddb83 commit 9cfbb2b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
19 changes: 17 additions & 2 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,20 @@
import inspect
import warnings

def _is_jax_zero_gradient_array(x):
"""Return True if `x` is a zero-gradient array.
These arrays are a design quirk of Jax that may one day be removed.
See https://github.com/google/jax/issues/20620.
"""
if 'numpy' not in sys.modules or 'jax' not in sys.modules:
return False

import numpy as np
import jax

return isinstance(x, np.ndarray) and x.dtype == jax.float0

def is_numpy_array(x):
"""
Return True if `x` is a NumPy array.
Expand All @@ -44,7 +58,8 @@ def is_numpy_array(x):
import numpy as np

# TODO: Should we reject ndarray subclasses?
return isinstance(x, (np.ndarray, np.generic))
return (isinstance(x, (np.ndarray, np.generic))
and not _is_jax_zero_gradient_array(x))

def is_cupy_array(x):
"""
Expand Down Expand Up @@ -149,7 +164,7 @@ def is_jax_array(x):

import jax

return isinstance(x, jax.Array)
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)

def is_array_api_obj(x):
"""
Expand Down
7 changes: 7 additions & 0 deletions tests/test_array_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
import warnings

import jax
import numpy as np
import pytest
import torch
Expand Down Expand Up @@ -55,6 +56,12 @@ def test_array_namespace(library, api_version, use_compat):
"""
subprocess.run([sys.executable, "-c", code], check=True)

def test_jax_zero_gradient():
jx = jax.numpy.arange(4)
jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
assert (array_api_compat.get_namespace(jax_zero) is
array_api_compat.get_namespace(jx))

def test_array_namespace_errors():
pytest.raises(TypeError, lambda: array_namespace([1]))
pytest.raises(TypeError, lambda: array_namespace())
Expand Down

0 comments on commit 9cfbb2b

Please sign in to comment.