From 13dd5e42ccd50ea9185e2583a345320192da634f Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 28 Nov 2023 13:55:18 -0800 Subject: [PATCH] Deprecate non-array inputs to jnp.array_equal & jnp.array_equiv --- CHANGELOG.md | 4 ++++ jax/_src/numpy/lax_numpy.py | 14 ++++++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5421e3e42b22..0a1cdfd4c532 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,10 @@ Remember to align the itemized text with the first line of an item within a list It currently is converted to NaN, and in the future will raise a {obj}`TypeError`. * Passing the `condition`, `x`, and `y` parameters to `jax.numpy.where` by keyword arguments has been deprecated, to match `numpy.where`. + * Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv` + that cannot be converted to a JAX array is deprecated and now raises a + {obj}`DeprecationWaning`. Currently the functions return False, in the future this + will raise an exception. ## jaxlib 0.4.21 diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 8e10cd0fc0d2..d65f72667409 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2298,7 +2298,12 @@ def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array: try: a1, a2 = asarray(a1), asarray(a2) - except Exception: + except Exception as err: + # TODO(jakevdp): Deprecated 2023-11-23; change to error. + warnings.warn("Inputs to array_equal() cannot be coerced to array. " + "Returning False; in the future this will raise an exception.\n" + f"{err!r}", + DeprecationWarning, stacklevel=2) return bool_(False) if shape(a1) != shape(a2): return bool_(False) @@ -2312,7 +2317,12 @@ def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array: def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array: try: a1, a2 = asarray(a1), asarray(a2) - except Exception: + except Exception as err: + # TODO(jakevdp): Deprecated 2023-11-23; change to error. + warnings.warn("Inputs to array_equiv() cannot be coerced to array. " + "Returning False; in the future this will raise an exception.\n" + f"{err!r}", + DeprecationWarning, stacklevel=2) return bool_(False) try: eq = ufuncs.equal(a1, a2)