diff --git a/CHANGELOG.md b/CHANGELOG.md index 5421e3e42b22..bfc0f30502a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,9 @@ Remember to align the itemized text with the first line of an item within a list It currently is converted to NaN, and in the future will raise a {obj}`TypeError`. * Passing the `condition`, `x`, and `y` parameters to `jax.numpy.where` by keyword arguments has been deprecated, to match `numpy.where`. + * Passing non-arraylike arguments to {func}`jax.numpy.array_equal` and + {func}`jax.numpy.array_equiv` is deprecated and now raises a {obj}`DeprecationWaning`. + In the future this will raise a {obj}`TypeError`. ## jaxlib 0.4.21 diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 8e10cd0fc0d2..49decd143bec 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2296,6 +2296,8 @@ def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore @util._wraps(np.array_equal) def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array: + # TODO(jakevdp): Non-array input deprecated 2023-11-23; change to error. + util.check_arraylike("array_equal", a1, a2, emit_warning=True) try: a1, a2 = asarray(a1), asarray(a2) except Exception: @@ -2310,6 +2312,8 @@ def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array: @util._wraps(np.array_equiv) def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array: + # TODO(jakevdp): Non-array input deprecated 2023-11-23; change to error. + util.check_arraylike("array_equiv", a1, a2, emit_warning=True) try: a1, a2 = asarray(a1), asarray(a2) except Exception: diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 9c2c40a4281e..8c78f7702509 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -327,7 +327,7 @@ def check_arraylike(fun_name: str, *args: Any, emit_warning=False, stacklevel=3) if not _arraylike(arg)) msg = f"{fun_name} requires ndarray or scalar arguments, got {type(arg)} at position {pos}." if emit_warning: - warnings.warn(msg + "In a future JAX release this will be an error.", + warnings.warn(msg + " In a future JAX release this will be an error.", category=DeprecationWarning, stacklevel=stacklevel) else: raise TypeError(msg.format(fun_name, type(arg), pos)) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 51cf262eb9a7..9b02c10dfddc 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -188,10 +188,11 @@ def testLoad(self, dtype, allow_pickle): def testArrayEqualExamples(self): # examples from the array_equal() docstring. - self.assertTrue(jnp.array_equal([1, 2], [1, 2])) self.assertTrue(jnp.array_equal(np.array([1, 2]), np.array([1, 2]))) - self.assertFalse(jnp.array_equal([1, 2], [1, 2, 3])) - self.assertFalse(jnp.array_equal([1, 2], [1, 4])) + with self.assertWarnsRegex(DeprecationWarning, "array_equal requires ndarray or scalar arguments.*"): + self.assertTrue(jnp.array_equal([1, 2], [1, 2])) + self.assertFalse(jnp.array_equal([1, 2], [1, 2, 3])) + self.assertFalse(jnp.array_equal([1, 2], [1, 4])) a = np.array([1, np.nan]) self.assertFalse(jnp.array_equal(a, a)) @@ -205,12 +206,13 @@ def testArrayEqualExamples(self): def testArrayEquivExamples(self): # examples from the array_equiv() docstring. - self.assertTrue(jnp.array_equiv([1, 2], [1, 2])) - self.assertFalse(jnp.array_equiv([1, 2], [1, 3])) - with jax.numpy_rank_promotion('allow'): - self.assertTrue(jnp.array_equiv([1, 2], [[1, 2], [1, 2]])) - self.assertFalse(jnp.array_equiv([1, 2], [[1, 2, 1, 2], [1, 2, 1, 2]])) - self.assertFalse(jnp.array_equiv([1, 2], [[1, 2], [1, 3]])) + with self.assertWarnsRegex(DeprecationWarning, "array_equiv requires ndarray or scalar arguments.*"): + self.assertTrue(jnp.array_equiv([1, 2], [1, 2])) + self.assertFalse(jnp.array_equiv([1, 2], [1, 3])) + with jax.numpy_rank_promotion('allow'): + self.assertTrue(jnp.array_equiv([1, 2], [[1, 2], [1, 2]])) + self.assertFalse(jnp.array_equiv([1, 2], [[1, 2, 1, 2], [1, 2, 1, 2]])) + self.assertFalse(jnp.array_equiv([1, 2], [[1, 2], [1, 3]])) def testArrayModule(self): if numpy_dispatch is None: