diff --git a/CHANGELOG.md b/CHANGELOG.md index 146808d9c8fe..4d0e1e3fd03a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,20 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. ## jax 0.3.2 (Unreleased) * [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.1...main). +* Changes: + * `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by + default. To recover the previous behavior, use the `jax.test_util.with_config` + decorator: + ```python + @jtu.with_config(jax_numpy_rank_promotion='allow') + class MyTestCase(jtu.JaxTestCase): + ... + ``` + * The functions `jax.ops.index_update`, `jax.ops.index_add`, which were + deprecated in 0.2.22, have been removed. Please use + [the `.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html) + instead, e.g., `x.at[idx].set(y)`. + ## jaxlib 0.3.1 (Unreleased) * Changes diff --git a/docs/jax.ops.rst b/docs/jax.ops.rst index fa45f4fa6ef6..80f9064f03f2 100644 --- a/docs/jax.ops.rst +++ b/docs/jax.ops.rst @@ -1,6 +1,6 @@ jax.ops package -================= +=============== .. currentmodule:: jax.ops @@ -8,68 +8,12 @@ jax.ops package .. _syntactic-sugar-for-ops: -Indexed update operators ------------------------- +The functions ``jax.ops.index_update``, ``jax.ops.index_add``, etc., which were +deprecated in JAX 0.2.22, have been removed. Please use the +:attr:`jax.numpy.ndarray.at` property on JAX arrays instead. -JAX is intended to be used with a functional style of programming, and -does not support NumPy-style indexed assignment directly. Instead, JAX provides -alternative pure functional operators for indexed updates to arrays. - -JAX array types have a property ``at``, which can be used as -follows (where ``idx`` is a NumPy index expression). - -========================= =================================================== -Alternate syntax Equivalent in-place expression -========================= =================================================== -``x.at[idx].get()`` ``x[idx]`` -``x.at[idx].set(y)`` ``x[idx] = y`` -``x.at[idx].add(y)`` ``x[idx] += y`` -``x.at[idx].multiply(y)`` ``x[idx] *= y`` -``x.at[idx].divide(y)`` ``x[idx] /= y`` -``x.at[idx].power(y)`` ``x[idx] **= y`` -``x.at[idx].min(y)`` ``x[idx] = np.minimum(x[idx], y)`` -``x.at[idx].max(y)`` ``x[idx] = np.maximum(x[idx], y)`` -========================= =================================================== - -None of these expressions modify the original `x`; instead they return -a modified copy of `x`. However, inside a :py:func:`jit` compiled function, -expressions like ``x = x.at[idx].set(y)`` are guaranteed to be applied in-place. - -Unlike NumPy in-place operations such as :code:`x[idx] += y`, if multiple -indices refer to the same location, all updates will be applied (NumPy would -only apply the last update, rather than applying all updates.) The order -in which conflicting updates are applied is implementation-defined and may be -nondeterministic (e.g., due to concurrency on some hardware platforms). - -By default, JAX assumes that all indices are in-bounds. There is experimental -support for giving more precise semantics to out-of-bounds indexed accesses, -via the ``mode`` parameter to functions such as ``get`` and ``set``. Valid -values for ``mode`` include ``"clip"``, which means that out-of-bounds indices -will be clamped into range, and ``"fill"``/``"drop"``, which are aliases and -mean that out-of-bounds reads will be filled with a scalar ``fill_value``, -and out-of-bounds writes will be discarded. - - -Indexed update functions (deprecated) -------------------------------------- - -The following functions are aliases for the ``x.at[idx].set(y)`` -style operators. Use the ``x.at[idx]`` operators instead. - -.. autosummary:: - :toctree: _autosummary - - index - index_update - index_add - index_mul - index_min - index_max - - - -Other operators ---------------- +Segment reduction operators +--------------------------- .. autosummary:: :toctree: _autosummary diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index 6f3c60448f82..2a3002298e81 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -14,7 +14,6 @@ # Helpers for indexed updates. -import warnings import sys from typing import Any, Callable, Optional, Sequence, Tuple, Union @@ -113,279 +112,6 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, return lax._convert_element_type(out, dtype, weak_type) -class _Indexable(object): - """Helper object for building indexes for indexed update functions. - - .. deprecated:: 0.2.22 - Prefer the use of :attr:`jax.numpy.ndarray.at`. If an explicit index - is needed, use :func:`jax.numpy.index_exp`. - - This is a singleton object that overrides the :code:`__getitem__` method - to return the index it is passed. - - >>> jax.ops.index[1:2, 3, None, ..., ::2] - (slice(1, 2, None), 3, None, Ellipsis, slice(None, None, 2)) - """ - __slots__ = () - - def __getitem__(self, index): - return index - -#: Index object singleton -index = _Indexable() - - -def index_add(x: Array, - idx: Index, - y: Numeric, - indices_are_sorted: bool = False, - unique_indices: bool = False) -> Array: - """Pure equivalent of :code:`x[idx] += y`. - - .. deprecated:: 0.2.22 - Prefer the use of :attr:`jax.numpy.ndarray.at`. - - Returns the value of `x` that would result from the - NumPy-style :mod:`indexed assignment `:: - - x[idx] += y - - Note the `index_add` operator is pure; `x` itself is - not modified, instead the new value that `x` would have taken is returned. - - Unlike the NumPy code :code:`x[idx] += y`, if multiple indices refer to the - same location the updates will be summed. (NumPy would only apply the last - update, rather than summing the updates.) The order in which conflicting - updates are applied is implementation-defined and may be nondeterministic - (e.g., due to concurrency on some hardware platforms). - - Args: - x: an array with the values to be updated. - idx: a Numpy-style index, consisting of `None`, integers, `slice` objects, - ellipses, ndarrays with integer dtypes, or a tuple of the above. A - convenient syntactic sugar for forming indices is via the - :data:`jax.ops.index` object. - y: the array of updates. `y` must be broadcastable to the shape of the - array that would be returned by `x[idx]`. - indices_are_sorted: whether `idx` is known to be sorted - unique_indices: whether `idx` is known to be free of duplicates - - Returns: - An array. - - >>> x = jax.numpy.ones((5, 6)) - >>> jax.ops.index_add(x, jnp.index_exp[2:4, 3:], 6.) - DeviceArray([[1., 1., 1., 1., 1., 1.], - [1., 1., 1., 1., 1., 1.], - [1., 1., 1., 7., 7., 7.], - [1., 1., 1., 7., 7., 7.], - [1., 1., 1., 1., 1., 1.]], dtype=float32) - """ - warnings.warn("index_add is deprecated. Use x.at[idx].add(y) instead.", - DeprecationWarning) - return _scatter_update( - x, idx, y, lax.scatter_add, indices_are_sorted, unique_indices) - - -def index_mul(x: Array, - idx: Index, - y: Numeric, - indices_are_sorted: bool = False, - unique_indices: bool = False) -> Array: - """Pure equivalent of :code:`x[idx] *= y`. - - .. deprecated:: 0.2.22 - Prefer the use of :attr:`jax.numpy.ndarray.at`. - - Returns the value of `x` that would result from the - NumPy-style :mod:`indexed assignment `:: - - x[idx] *= y - - Note the `index_mul` operator is pure; `x` itself is - not modified, instead the new value that `x` would have taken is returned. - - Unlike the NumPy code :code:`x[idx] *= y`, if multiple indices refer to the - same location the updates will be multiplied. (NumPy would only apply the last - update, rather than multiplying the updates.) The order in which conflicting - updates are applied is implementation-defined and may be nondeterministic - (e.g., due to concurrency on some hardware platforms). - - Args: - x: an array with the values to be updated. - idx: a Numpy-style index, consisting of `None`, integers, `slice` objects, - ellipses, ndarrays with integer dtypes, or a tuple of the above. A - convenient syntactic sugar for forming indices is via the - :data:`jax.ops.index` object. - y: the array of updates. `y` must be broadcastable to the shape of the - array that would be returned by `x[idx]`. - indices_are_sorted: whether `idx` is known to be sorted - unique_indices: whether `idx` is known to be free of duplicates - - Returns: - An array. - - >>> x = jax.numpy.ones((5, 6)) - >>> jax.ops.index_mul(x, jnp.index_exp[2:4, 3:], 6.) - DeviceArray([[1., 1., 1., 1., 1., 1.], - [1., 1., 1., 1., 1., 1.], - [1., 1., 1., 6., 6., 6.], - [1., 1., 1., 6., 6., 6.], - [1., 1., 1., 1., 1., 1.]], dtype=float32) - """ - warnings.warn("index_mul is deprecated. Use x.at[idx].mul(y) instead.", - DeprecationWarning) - return _scatter_update(x, idx, y, lax.scatter_mul, - indices_are_sorted, unique_indices) - - -def index_min(x: Array, - idx: Index, - y: Numeric, - indices_are_sorted: bool = False, - unique_indices: bool = False) -> Array: - """Pure equivalent of :code:`x[idx] = minimum(x[idx], y)`. - - .. deprecated:: 0.2.22 - Prefer the use of :attr:`jax.numpy.ndarray.at`. - - Returns the value of `x` that would result from the - NumPy-style :mod:`indexed assignment `:: - - x[idx] = minimum(x[idx], y) - - Note the `index_min` operator is pure; `x` itself is - not modified, instead the new value that `x` would have taken is returned. - - Unlike the NumPy code :code:`x[idx] = minimum(x[idx], y)`, if multiple indices - refer to the same location the final value will be the overall min. (NumPy - would only look at the last update, rather than all of the updates.) - - Args: - x: an array with the values to be updated. - idx: a Numpy-style index, consisting of `None`, integers, `slice` objects, - ellipses, ndarrays with integer dtypes, or a tuple of the above. A - convenient syntactic sugar for forming indices is via the - :data:`jax.ops.index` object. - y: the array of updates. `y` must be broadcastable to the shape of the - array that would be returned by `x[idx]`. - indices_are_sorted: whether `idx` is known to be sorted - unique_indices: whether `idx` is known to be free of duplicates - - Returns: - An array. - - >>> x = jax.numpy.ones((5, 6)) - >>> jax.ops.index_min(x, jnp.index_exp[2:4, 3:], 0.) - DeviceArray([[1., 1., 1., 1., 1., 1.], - [1., 1., 1., 1., 1., 1.], - [1., 1., 1., 0., 0., 0.], - [1., 1., 1., 0., 0., 0.], - [1., 1., 1., 1., 1., 1.]], dtype=float32) - """ - warnings.warn("index_min is deprecated. Use x.at[idx].min(y) instead.", - DeprecationWarning) - return _scatter_update( - x, idx, y, lax.scatter_min, indices_are_sorted, unique_indices) - -def index_max(x: Array, - idx: Index, - y: Numeric, - indices_are_sorted: bool = False, - unique_indices: bool = False) -> Array: - """Pure equivalent of :code:`x[idx] = maximum(x[idx], y)`. - - .. deprecated:: 0.2.22 - Prefer the use of :attr:`jax.numpy.ndarray.at`. - - Returns the value of `x` that would result from the - NumPy-style :mod:`indexed assignment `:: - - x[idx] = maximum(x[idx], y) - - Note the `index_max` operator is pure; `x` itself is - not modified, instead the new value that `x` would have taken is returned. - - Unlike the NumPy code :code:`x[idx] = maximum(x[idx], y)`, if multiple indices - refer to the same location the final value will be the overall max. (NumPy - would only look at the last update, rather than all of the updates.) - - Args: - x: an array with the values to be updated. - idx: a Numpy-style index, consisting of `None`, integers, `slice` objects, - ellipses, ndarrays with integer dtypes, or a tuple of the above. A - convenient syntactic sugar for forming indices is via the - :data:`jax.ops.index` object. - y: the array of updates. `y` must be broadcastable to the shape of the - array that would be returned by `x[idx]`. - indices_are_sorted: whether `idx` is known to be sorted - unique_indices: whether `idx` is known to be free of duplicates - - Returns: - An array. - - >>> x = jax.numpy.ones((5, 6)) - >>> jax.ops.index_max(x, jnp.index_exp[2:4, 3:], 6.) - DeviceArray([[1., 1., 1., 1., 1., 1.], - [1., 1., 1., 1., 1., 1.], - [1., 1., 1., 6., 6., 6.], - [1., 1., 1., 6., 6., 6.], - [1., 1., 1., 1., 1., 1.]], dtype=float32) - """ - warnings.warn("index_max is deprecated. Use x.at[idx].max(y) instead.", - DeprecationWarning) - return _scatter_update( - x, idx, y, lax.scatter_max, indices_are_sorted, unique_indices) - -def index_update(x: Array, - idx: Index, - y: Numeric, - indices_are_sorted: bool = False, - unique_indices: bool = False) -> Array: - """Pure equivalent of :code:`x[idx] = y`. - - .. deprecated:: 0.2.22 - Prefer the use of :attr:`jax.numpy.ndarray.at`. - - Returns the value of `x` that would result from the - NumPy-style :mod:`indexed assignment `:: - - x[idx] = y - - Note the `index_update` operator is pure; `x` itself is - not modified, instead the new value that `x` would have taken is returned. - - Unlike NumPy's :code:`x[idx] = y`, if multiple indices refer to the same - location it is undefined which update is chosen; JAX may choose the order of - updates arbitrarily and nondeterministically (e.g., due to concurrent - updates on some hardware platforms). - - Args: - x: an array with the values to be updated. - idx: a Numpy-style index, consisting of `None`, integers, `slice` objects, - ellipses, ndarrays with integer dtypes, or a tuple of the above. A - convenient syntactic sugar for forming indices is via the - :data:`jax.ops.index` object. - y: the array of updates. `y` must be broadcastable to the shape of the - array that would be returned by `x[idx]`. - indices_are_sorted: whether `idx` is known to be sorted - unique_indices: whether `idx` is known to be free of duplicates - - Returns: - An array. - - >>> x = jax.numpy.ones((5, 6)) - >>> jax.ops.index_update(x, jnp.index_exp[::2, 3:], 6.) - DeviceArray([[1., 1., 1., 6., 6., 6.], - [1., 1., 1., 1., 1., 1.], - [1., 1., 1., 6., 6., 6.], - [1., 1., 1., 1., 1., 1.], - [1., 1., 1., 6., 6., 6.]], dtype=float32) - """ - warnings.warn("index_update is deprecated. Use x.at[idx].set(y) instead.", - DeprecationWarning) - return _scatter_update( - x, idx, y, lax.scatter, indices_are_sorted, unique_indices) def _get_identity(op, dtype): diff --git a/jax/ops/__init__.py b/jax/ops/__init__.py index 5a1dca859a12..545142b94eba 100644 --- a/jax/ops/__init__.py +++ b/jax/ops/__init__.py @@ -14,12 +14,6 @@ # flake8: noqa: F401 from jax._src.ops.scatter import ( - index as index, - index_add as index_add, - index_mul as index_mul, - index_update as index_update, - index_min as index_min, - index_max as index_max, segment_sum as segment_sum, segment_prod as segment_prod, segment_min as segment_min,