Skip to content

Commit

Permalink
Deprecate non-array inputs to jnp.array_equal & jnp.array_equiv
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Nov 28, 2023
1 parent c855bb0 commit 43cb68a
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 10 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ Remember to align the itemized text with the first line of an item within a list
It currently is converted to NaN, and in the future will raise a {obj}`TypeError`.
* Passing the `condition`, `x`, and `y` parameters to `jax.numpy.where` by
keyword arguments has been deprecated, to match `numpy.where`.
* Passing non-arraylike arguments to {func}`jax.numpy.array_equal` and
{func}`jax.numpy.array_equiv` is deprecated and now raises a {obj}`DeprecationWaning`.
In the future this will raise a {obj}`TypeError`.


## jaxlib 0.4.21
Expand Down
4 changes: 4 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2296,6 +2296,8 @@ def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore

@util._wraps(np.array_equal)
def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array:
# TODO(jakevdp): Non-array input deprecated 2023-11-23; change to error.
util.check_arraylike("array_equal", a1, a2, emit_warning=True)
try:
a1, a2 = asarray(a1), asarray(a2)
except Exception:
Expand All @@ -2310,6 +2312,8 @@ def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array:

@util._wraps(np.array_equiv)
def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array:
# TODO(jakevdp): Non-array input deprecated 2023-11-23; change to error.
util.check_arraylike("array_equiv", a1, a2, emit_warning=True)
try:
a1, a2 = asarray(a1), asarray(a2)
except Exception:
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/numpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def check_arraylike(fun_name: str, *args: Any, emit_warning=False, stacklevel=3)
if not _arraylike(arg))
msg = f"{fun_name} requires ndarray or scalar arguments, got {type(arg)} at position {pos}."
if emit_warning:
warnings.warn(msg + "In a future JAX release this will be an error.",
warnings.warn(msg + " In a future JAX release this will be an error.",
category=DeprecationWarning, stacklevel=stacklevel)
else:
raise TypeError(msg.format(fun_name, type(arg), pos))
Expand Down
20 changes: 11 additions & 9 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,11 @@ def testLoad(self, dtype, allow_pickle):

def testArrayEqualExamples(self):
# examples from the array_equal() docstring.
self.assertTrue(jnp.array_equal([1, 2], [1, 2]))
self.assertTrue(jnp.array_equal(np.array([1, 2]), np.array([1, 2])))
self.assertFalse(jnp.array_equal([1, 2], [1, 2, 3]))
self.assertFalse(jnp.array_equal([1, 2], [1, 4]))
with self.assertWarnsRegex(DeprecationWarning, "array_equal requires ndarray or scalar arguments.*"):
self.assertTrue(jnp.array_equal([1, 2], [1, 2]))
self.assertFalse(jnp.array_equal([1, 2], [1, 2, 3]))
self.assertFalse(jnp.array_equal([1, 2], [1, 4]))

a = np.array([1, np.nan])
self.assertFalse(jnp.array_equal(a, a))
Expand All @@ -205,12 +206,13 @@ def testArrayEqualExamples(self):

def testArrayEquivExamples(self):
# examples from the array_equiv() docstring.
self.assertTrue(jnp.array_equiv([1, 2], [1, 2]))
self.assertFalse(jnp.array_equiv([1, 2], [1, 3]))
with jax.numpy_rank_promotion('allow'):
self.assertTrue(jnp.array_equiv([1, 2], [[1, 2], [1, 2]]))
self.assertFalse(jnp.array_equiv([1, 2], [[1, 2, 1, 2], [1, 2, 1, 2]]))
self.assertFalse(jnp.array_equiv([1, 2], [[1, 2], [1, 3]]))
with self.assertWarnsRegex(DeprecationWarning, "array_equiv requires ndarray or scalar arguments.*"):
self.assertTrue(jnp.array_equiv([1, 2], [1, 2]))
self.assertFalse(jnp.array_equiv([1, 2], [1, 3]))
with jax.numpy_rank_promotion('allow'):
self.assertTrue(jnp.array_equiv([1, 2], [[1, 2], [1, 2]]))
self.assertFalse(jnp.array_equiv([1, 2], [[1, 2, 1, 2], [1, 2, 1, 2]]))
self.assertFalse(jnp.array_equiv([1, 2], [[1, 2], [1, 3]]))

def testArrayModule(self):
if numpy_dispatch is None:
Expand Down

0 comments on commit 43cb68a

Please sign in to comment.