Skip to content

Commit

Permalink
Remove jax.ops.index... functions.
Browse files Browse the repository at this point in the history
These functions have been deprecated and have issued a DeprecationWarning since jax 0.2.22 in October 2021.
  • Loading branch information
hawkinsp committed Feb 24, 2022
1 parent 3948fde commit f51a05a
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 342 deletions.
14 changes: 14 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,20 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.3.2 (Unreleased)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.3.1...main).
* Changes:
* `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by
default. To recover the previous behavior, use the `jax.test_util.with_config`
decorator:
```python
@jtu.with_config(jax_numpy_rank_promotion='allow')
class MyTestCase(jtu.JaxTestCase):
...
```
* The functions `jax.ops.index_update`, `jax.ops.index_add`, which were
deprecated in 0.2.22, have been removed. Please use
[the `.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html)
instead, e.g., `x.at[idx].set(y)`.


## jaxlib 0.3.1 (Unreleased)
* Changes
Expand Down
68 changes: 6 additions & 62 deletions docs/jax.ops.rst
Original file line number Diff line number Diff line change
@@ -1,75 +1,19 @@

jax.ops package
=================
===============

.. currentmodule:: jax.ops

.. automodule:: jax.ops

.. _syntactic-sugar-for-ops:

Indexed update operators
------------------------
The functions ``jax.ops.index_update``, ``jax.ops.index_add``, etc., which were
deprecated in JAX 0.2.22, have been removed. Please use the
:attr:`jax.numpy.ndarray.at` property on JAX arrays instead.

JAX is intended to be used with a functional style of programming, and
does not support NumPy-style indexed assignment directly. Instead, JAX provides
alternative pure functional operators for indexed updates to arrays.

JAX array types have a property ``at``, which can be used as
follows (where ``idx`` is a NumPy index expression).

========================= ===================================================
Alternate syntax Equivalent in-place expression
========================= ===================================================
``x.at[idx].get()`` ``x[idx]``
``x.at[idx].set(y)`` ``x[idx] = y``
``x.at[idx].add(y)`` ``x[idx] += y``
``x.at[idx].multiply(y)`` ``x[idx] *= y``
``x.at[idx].divide(y)`` ``x[idx] /= y``
``x.at[idx].power(y)`` ``x[idx] **= y``
``x.at[idx].min(y)`` ``x[idx] = np.minimum(x[idx], y)``
``x.at[idx].max(y)`` ``x[idx] = np.maximum(x[idx], y)``
========================= ===================================================

None of these expressions modify the original `x`; instead they return
a modified copy of `x`. However, inside a :py:func:`jit` compiled function,
expressions like ``x = x.at[idx].set(y)`` are guaranteed to be applied in-place.

Unlike NumPy in-place operations such as :code:`x[idx] += y`, if multiple
indices refer to the same location, all updates will be applied (NumPy would
only apply the last update, rather than applying all updates.) The order
in which conflicting updates are applied is implementation-defined and may be
nondeterministic (e.g., due to concurrency on some hardware platforms).

By default, JAX assumes that all indices are in-bounds. There is experimental
support for giving more precise semantics to out-of-bounds indexed accesses,
via the ``mode`` parameter to functions such as ``get`` and ``set``. Valid
values for ``mode`` include ``"clip"``, which means that out-of-bounds indices
will be clamped into range, and ``"fill"``/``"drop"``, which are aliases and
mean that out-of-bounds reads will be filled with a scalar ``fill_value``,
and out-of-bounds writes will be discarded.


Indexed update functions (deprecated)
-------------------------------------

The following functions are aliases for the ``x.at[idx].set(y)``
style operators. Use the ``x.at[idx]`` operators instead.

.. autosummary::
:toctree: _autosummary

index
index_update
index_add
index_mul
index_min
index_max



Other operators
---------------
Segment reduction operators
---------------------------

.. autosummary::
:toctree: _autosummary
Expand Down
274 changes: 0 additions & 274 deletions jax/_src/ops/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

# Helpers for indexed updates.

import warnings
import sys
from typing import Any, Callable, Optional, Sequence, Tuple, Union

Expand Down Expand Up @@ -113,279 +112,6 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
return lax._convert_element_type(out, dtype, weak_type)


class _Indexable(object):
"""Helper object for building indexes for indexed update functions.
.. deprecated:: 0.2.22
Prefer the use of :attr:`jax.numpy.ndarray.at`. If an explicit index
is needed, use :func:`jax.numpy.index_exp`.
This is a singleton object that overrides the :code:`__getitem__` method
to return the index it is passed.
>>> jax.ops.index[1:2, 3, None, ..., ::2]
(slice(1, 2, None), 3, None, Ellipsis, slice(None, None, 2))
"""
__slots__ = ()

def __getitem__(self, index):
return index

#: Index object singleton
index = _Indexable()


def index_add(x: Array,
idx: Index,
y: Numeric,
indices_are_sorted: bool = False,
unique_indices: bool = False) -> Array:
"""Pure equivalent of :code:`x[idx] += y`.
.. deprecated:: 0.2.22
Prefer the use of :attr:`jax.numpy.ndarray.at`.
Returns the value of `x` that would result from the
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
x[idx] += y
Note the `index_add` operator is pure; `x` itself is
not modified, instead the new value that `x` would have taken is returned.
Unlike the NumPy code :code:`x[idx] += y`, if multiple indices refer to the
same location the updates will be summed. (NumPy would only apply the last
update, rather than summing the updates.) The order in which conflicting
updates are applied is implementation-defined and may be nondeterministic
(e.g., due to concurrency on some hardware platforms).
Args:
x: an array with the values to be updated.
idx: a Numpy-style index, consisting of `None`, integers, `slice` objects,
ellipses, ndarrays with integer dtypes, or a tuple of the above. A
convenient syntactic sugar for forming indices is via the
:data:`jax.ops.index` object.
y: the array of updates. `y` must be broadcastable to the shape of the
array that would be returned by `x[idx]`.
indices_are_sorted: whether `idx` is known to be sorted
unique_indices: whether `idx` is known to be free of duplicates
Returns:
An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_add(x, jnp.index_exp[2:4, 3:], 6.)
DeviceArray([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 7., 7., 7.],
[1., 1., 1., 7., 7., 7.],
[1., 1., 1., 1., 1., 1.]], dtype=float32)
"""
warnings.warn("index_add is deprecated. Use x.at[idx].add(y) instead.",
DeprecationWarning)
return _scatter_update(
x, idx, y, lax.scatter_add, indices_are_sorted, unique_indices)


def index_mul(x: Array,
idx: Index,
y: Numeric,
indices_are_sorted: bool = False,
unique_indices: bool = False) -> Array:
"""Pure equivalent of :code:`x[idx] *= y`.
.. deprecated:: 0.2.22
Prefer the use of :attr:`jax.numpy.ndarray.at`.
Returns the value of `x` that would result from the
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
x[idx] *= y
Note the `index_mul` operator is pure; `x` itself is
not modified, instead the new value that `x` would have taken is returned.
Unlike the NumPy code :code:`x[idx] *= y`, if multiple indices refer to the
same location the updates will be multiplied. (NumPy would only apply the last
update, rather than multiplying the updates.) The order in which conflicting
updates are applied is implementation-defined and may be nondeterministic
(e.g., due to concurrency on some hardware platforms).
Args:
x: an array with the values to be updated.
idx: a Numpy-style index, consisting of `None`, integers, `slice` objects,
ellipses, ndarrays with integer dtypes, or a tuple of the above. A
convenient syntactic sugar for forming indices is via the
:data:`jax.ops.index` object.
y: the array of updates. `y` must be broadcastable to the shape of the
array that would be returned by `x[idx]`.
indices_are_sorted: whether `idx` is known to be sorted
unique_indices: whether `idx` is known to be free of duplicates
Returns:
An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_mul(x, jnp.index_exp[2:4, 3:], 6.)
DeviceArray([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 6., 6., 6.],
[1., 1., 1., 6., 6., 6.],
[1., 1., 1., 1., 1., 1.]], dtype=float32)
"""
warnings.warn("index_mul is deprecated. Use x.at[idx].mul(y) instead.",
DeprecationWarning)
return _scatter_update(x, idx, y, lax.scatter_mul,
indices_are_sorted, unique_indices)


def index_min(x: Array,
idx: Index,
y: Numeric,
indices_are_sorted: bool = False,
unique_indices: bool = False) -> Array:
"""Pure equivalent of :code:`x[idx] = minimum(x[idx], y)`.
.. deprecated:: 0.2.22
Prefer the use of :attr:`jax.numpy.ndarray.at`.
Returns the value of `x` that would result from the
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
x[idx] = minimum(x[idx], y)
Note the `index_min` operator is pure; `x` itself is
not modified, instead the new value that `x` would have taken is returned.
Unlike the NumPy code :code:`x[idx] = minimum(x[idx], y)`, if multiple indices
refer to the same location the final value will be the overall min. (NumPy
would only look at the last update, rather than all of the updates.)
Args:
x: an array with the values to be updated.
idx: a Numpy-style index, consisting of `None`, integers, `slice` objects,
ellipses, ndarrays with integer dtypes, or a tuple of the above. A
convenient syntactic sugar for forming indices is via the
:data:`jax.ops.index` object.
y: the array of updates. `y` must be broadcastable to the shape of the
array that would be returned by `x[idx]`.
indices_are_sorted: whether `idx` is known to be sorted
unique_indices: whether `idx` is known to be free of duplicates
Returns:
An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_min(x, jnp.index_exp[2:4, 3:], 0.)
DeviceArray([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 0., 0., 0.],
[1., 1., 1., 0., 0., 0.],
[1., 1., 1., 1., 1., 1.]], dtype=float32)
"""
warnings.warn("index_min is deprecated. Use x.at[idx].min(y) instead.",
DeprecationWarning)
return _scatter_update(
x, idx, y, lax.scatter_min, indices_are_sorted, unique_indices)

def index_max(x: Array,
idx: Index,
y: Numeric,
indices_are_sorted: bool = False,
unique_indices: bool = False) -> Array:
"""Pure equivalent of :code:`x[idx] = maximum(x[idx], y)`.
.. deprecated:: 0.2.22
Prefer the use of :attr:`jax.numpy.ndarray.at`.
Returns the value of `x` that would result from the
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
x[idx] = maximum(x[idx], y)
Note the `index_max` operator is pure; `x` itself is
not modified, instead the new value that `x` would have taken is returned.
Unlike the NumPy code :code:`x[idx] = maximum(x[idx], y)`, if multiple indices
refer to the same location the final value will be the overall max. (NumPy
would only look at the last update, rather than all of the updates.)
Args:
x: an array with the values to be updated.
idx: a Numpy-style index, consisting of `None`, integers, `slice` objects,
ellipses, ndarrays with integer dtypes, or a tuple of the above. A
convenient syntactic sugar for forming indices is via the
:data:`jax.ops.index` object.
y: the array of updates. `y` must be broadcastable to the shape of the
array that would be returned by `x[idx]`.
indices_are_sorted: whether `idx` is known to be sorted
unique_indices: whether `idx` is known to be free of duplicates
Returns:
An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_max(x, jnp.index_exp[2:4, 3:], 6.)
DeviceArray([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 6., 6., 6.],
[1., 1., 1., 6., 6., 6.],
[1., 1., 1., 1., 1., 1.]], dtype=float32)
"""
warnings.warn("index_max is deprecated. Use x.at[idx].max(y) instead.",
DeprecationWarning)
return _scatter_update(
x, idx, y, lax.scatter_max, indices_are_sorted, unique_indices)

def index_update(x: Array,
idx: Index,
y: Numeric,
indices_are_sorted: bool = False,
unique_indices: bool = False) -> Array:
"""Pure equivalent of :code:`x[idx] = y`.
.. deprecated:: 0.2.22
Prefer the use of :attr:`jax.numpy.ndarray.at`.
Returns the value of `x` that would result from the
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
x[idx] = y
Note the `index_update` operator is pure; `x` itself is
not modified, instead the new value that `x` would have taken is returned.
Unlike NumPy's :code:`x[idx] = y`, if multiple indices refer to the same
location it is undefined which update is chosen; JAX may choose the order of
updates arbitrarily and nondeterministically (e.g., due to concurrent
updates on some hardware platforms).
Args:
x: an array with the values to be updated.
idx: a Numpy-style index, consisting of `None`, integers, `slice` objects,
ellipses, ndarrays with integer dtypes, or a tuple of the above. A
convenient syntactic sugar for forming indices is via the
:data:`jax.ops.index` object.
y: the array of updates. `y` must be broadcastable to the shape of the
array that would be returned by `x[idx]`.
indices_are_sorted: whether `idx` is known to be sorted
unique_indices: whether `idx` is known to be free of duplicates
Returns:
An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_update(x, jnp.index_exp[::2, 3:], 6.)
DeviceArray([[1., 1., 1., 6., 6., 6.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 6., 6., 6.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 6., 6., 6.]], dtype=float32)
"""
warnings.warn("index_update is deprecated. Use x.at[idx].set(y) instead.",
DeprecationWarning)
return _scatter_update(
x, idx, y, lax.scatter, indices_are_sorted, unique_indices)


def _get_identity(op, dtype):
Expand Down
Loading

0 comments on commit f51a05a

Please sign in to comment.