diff --git a/CHANGELOG.md b/CHANGELOG.md index 5421e3e42b22..bfc0f30502a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,9 @@ Remember to align the itemized text with the first line of an item within a list It currently is converted to NaN, and in the future will raise a {obj}`TypeError`. * Passing the `condition`, `x`, and `y` parameters to `jax.numpy.where` by keyword arguments has been deprecated, to match `numpy.where`. + * Passing non-arraylike arguments to {func}`jax.numpy.array_equal` and + {func}`jax.numpy.array_equiv` is deprecated and now raises a {obj}`DeprecationWaning`. + In the future this will raise a {obj}`TypeError`. ## jaxlib 0.4.21 diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 8e10cd0fc0d2..49decd143bec 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2296,6 +2296,8 @@ def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore @util._wraps(np.array_equal) def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array: + # TODO(jakevdp): Non-array input deprecated 2023-11-23; change to error. + util.check_arraylike("array_equal", a1, a2, emit_warning=True) try: a1, a2 = asarray(a1), asarray(a2) except Exception: @@ -2310,6 +2312,8 @@ def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array: @util._wraps(np.array_equiv) def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array: + # TODO(jakevdp): Non-array input deprecated 2023-11-23; change to error. + util.check_arraylike("array_equiv", a1, a2, emit_warning=True) try: a1, a2 = asarray(a1), asarray(a2) except Exception: diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 9c2c40a4281e..8c78f7702509 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -327,7 +327,7 @@ def check_arraylike(fun_name: str, *args: Any, emit_warning=False, stacklevel=3) if not _arraylike(arg)) msg = f"{fun_name} requires ndarray or scalar arguments, got {type(arg)} at position {pos}." if emit_warning: - warnings.warn(msg + "In a future JAX release this will be an error.", + warnings.warn(msg + " In a future JAX release this will be an error.", category=DeprecationWarning, stacklevel=stacklevel) else: raise TypeError(msg.format(fun_name, type(arg), pos))