Skip to content

Commit

Permalink
Deprecate non-array inputs to jnp.array_equal & jnp.array_equiv
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Nov 28, 2023
1 parent c855bb0 commit c9c238b
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ Remember to align the itemized text with the first line of an item within a list
It currently is converted to NaN, and in the future will raise a {obj}`TypeError`.
* Passing the `condition`, `x`, and `y` parameters to `jax.numpy.where` by
keyword arguments has been deprecated, to match `numpy.where`.
* Passing non-arraylike arguments to {func}`jax.numpy.array_equal` and
{func}`jax.numpy.array_equiv` is deprecated and now raises a {obj}`DeprecationWaning`.
In the future this will raise a {obj}`TypeError`.


## jaxlib 0.4.21
Expand Down
4 changes: 4 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2296,6 +2296,8 @@ def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore

@util._wraps(np.array_equal)
def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array:
# TODO(jakevdp): Non-array input deprecated 2023-11-23; change to error.
util.check_arraylike("array_equal", a1, a2, emit_warning=True)
try:
a1, a2 = asarray(a1), asarray(a2)
except Exception:
Expand All @@ -2310,6 +2312,8 @@ def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array:

@util._wraps(np.array_equiv)
def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array:
# TODO(jakevdp): Non-array input deprecated 2023-11-23; change to error.
util.check_arraylike("array_equiv", a1, a2, emit_warning=True)
try:
a1, a2 = asarray(a1), asarray(a2)
except Exception:
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/numpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def check_arraylike(fun_name: str, *args: Any, emit_warning=False, stacklevel=3)
if not _arraylike(arg))
msg = f"{fun_name} requires ndarray or scalar arguments, got {type(arg)} at position {pos}."
if emit_warning:
warnings.warn(msg + "In a future JAX release this will be an error.",
warnings.warn(msg + " In a future JAX release this will be an error.",
category=DeprecationWarning, stacklevel=stacklevel)
else:
raise TypeError(msg.format(fun_name, type(arg), pos))
Expand Down

0 comments on commit c9c238b

Please sign in to comment.