Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DOC: improve documentation of where, nonzero, and related functions #20941

Merged
merged 1 commit into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 234 additions & 67 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,8 @@ def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
return jitted_interp(x, xp, fp, left, right, period)


_DEPRECATED_WHERE_ARG = object()

@overload # type: ignore[no-overload-impl]
def where(condition: ArrayLike, x: Literal[None] = None,
y: Literal[None] = None, /, *, size: int | None = None,
Expand All @@ -1081,39 +1083,72 @@ def where(condition: ArrayLike, x: ArrayLike | None = None,
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
) -> Array | tuple[Array, ...]: ...

_DEPRECATED_WHERE_ARG = object()

@util.implements(np.where, # type: ignore[no-redef]
lax_description=_dedent("""
At present, JAX does not support JIT-compilation of the single-argument form
of :py:func:`jax.numpy.where` because its output shape is data-dependent. The
three-argument form does not have a data-dependent shape and can be JIT-compiled
successfully. Alternatively, you can use the optional ``size`` keyword to
statically specify the expected size of the output.\n\n

Special care is needed when the ``x`` or ``y`` input to
:py:func:`jax.numpy.where` could have a value of NaN.
Specifically, when a gradient is taken
with :py:func:`jax.grad` (reverse-mode differentiation), a NaN in either
``x`` or ``y`` will propagate into the gradient, regardless of the value
of ``condition``. More information on this behavior and workarounds
is available in the JAX FAQ:
https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where"""),
extra_params=_dedent("""
size : int, optional
Only referenced when ``x`` and ``y`` are ``None``. If specified, the indices of the first
``size`` elements of the result will be returned. If there are fewer elements than ``size``
indicates, the return value will be padded with ``fill_value``.
fill_value : array_like, optional
When ``size`` is specified and there are fewer than the indicated number of elements, the
remaining elements will be filled with ``fill_value``, which defaults to zero."""))
def where(
acondition = None, if_true = None, if_false = None, /, *,
size=None, fill_value=None,
# Deprecated keyword-only names.
condition = _DEPRECATED_WHERE_ARG, x = _DEPRECATED_WHERE_ARG,
y = _DEPRECATED_WHERE_ARG
) -> Array | tuple[Array, ...]:
"""Select elements from two arrays based on a condition.

JAX implementation of :func:`numpy.where`.

.. note::
when only ``condition`` is provided, ``jnp.where(condition)`` is equivalent
to ``jnp.nonzero(condition)``. For that case, refer to the documentation of
:func:`jax.numpy.nonzero`. The docstring below focuses on the case where
``x`` and ``y`` are specified.

The three-term version of ``jnp.where`` lowers to :func:`jax.lax.select`.

Args:
condition: boolean array. Must be broadcast-compatible with ``x`` and ``y`` when
they are specified.
x: arraylike. Should be broadcast-compatible with ``condition`` and ``y``, and
typecast-compatible with ``y``.
y: arraylike. Should be broadcast-compatible with ``condition`` and ``x``, and
typecast-compatible with ``x``.
size: integer, only referenced when ``x`` and ``y`` are ``None``. For details,
see :func:`jax.numpy.nonzero`.
fill_value: only referenced when ``x`` and ``y`` are ``None``. For details,
see :func:`jax.numpy.nonzero`.

Returns:
An array of dtype ``jnp.result_type(x, y)`` with values drawn from ``x`` where ``condition``
is True, and from ``y`` where condition is ``False. If ``x`` and ``y`` are ``None``, the
function behaves differently; see `:func:`jax.numpy.nonzero` for a description of the return
type.

See Also:
- :func:`jax.numpy.nonzero`
- :func:`jax.numpy.argwhere`
- :func:`jax.lax.select`

Notes:
Special care is needed when the ``x`` or ``y`` input to :func:`jax.numpy.where` could
have a value of NaN. Specifically, when a gradient is taken with :func:`jax.grad`
(reverse-mode differentiation), a NaN in either ``x`` or ``y`` will propagate into the
gradient, regardless of the value of ``condition``. More information on this behavior
and workarounds is available in the `JAX FAQ
<https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where>`_.

Examples:
When ``x`` and ``y`` are not provided, ``where`` behaves equivalently to
:func:`jax.numpy.nonzero`:

>>> x = jnp.arange(10)
>>> jnp.where(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
>>> jnp.nonzero(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)

When ``x`` and ``y`` are provided, ``where`` selects between them based on
the specified condition:

>>> jnp.where(x > 4, x, 0)
Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)
"""
if (condition is not _DEPRECATED_WHERE_ARG or x is not _DEPRECATED_WHERE_ARG
or y is not _DEPRECATED_WHERE_ARG):
# TODO(phawkins): deprecated Nov 17 2023, remove after deprecation expires.
Expand Down Expand Up @@ -1431,25 +1466,87 @@ def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05,
return reductions.all(isclose(a, b, rtol, atol, equal_nan))


_NONZERO_DOC = """\
Because the size of the output of ``nonzero`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
must be specified statically for ``jnp.nonzero`` to be used within some of JAX's
transformations.
"""
_NONZERO_EXTRA_PARAMS = """
size : int, optional
If specified, the indices of the first ``size`` True elements will be returned. If there are
fewer unique elements than ``size`` indicates, the return value will be padded with ``fill_value``.
fill_value : array_like, optional
When ``size`` is specified and there are fewer than the indicated number of elements, the
remaining elements will be filled with ``fill_value``, which defaults to zero.
"""

@util.implements(np.nonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS)
def nonzero(a: ArrayLike, *, size: int | None = None,
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
) -> tuple[Array, ...]:
"""Return indices of nonzero elements of an array.

JAX implementation of :func:`numpy.nonzero`.

Because the size of the output of ``nonzero`` is data-dependent, the function
is not compatible with JIT and other transformations. The JAX version adds
the optional ``size`` argument which must be specified statically for
``jnp.nonzero`` to be used within JAX's transformations.

Args:
a: N-dimensional array.
size: optional static integer specifying the number of nonzero entries to
return. If there are more nonzero elements than the specified ``size``,
then indices will be truncated at the end. If there are fewer nonzero
elements than the specified size, then indices will be padded with
``fill_value``, which defaults to zero.
fill_value: optional padding value when ``size`` is specified. Defaults to 0.

Returns:
Tuple of JAX Arrays of length ``a.ndim``, containing the indices of each
nonzero value.

See also:
- :func:`jax.numpy flatnonzero`
- :func:`jax.numpy.where`

Examples:

One-dimensional array returns a length-1 tuple of indices:

>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jnp.nonzero(x)
(Array([1, 3, 5], dtype=int32),)

Two-dimensional array returns a length-2 tuple of indices:

>>> x = jnp.array([[0, 5, 0],
... [6, 0, 7]])
>>> jnp.nonzero(x)
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))

In either case, the resulting tuple of indices can be used directly to extract
the nonzero values:

>>> indices = jnp.nonzero(x)
>>> x[indices]
Array([5, 6, 7], dtype=int32)

The output of ``nonzero`` has a dynamic shape, because the number of returned
indices depends on the contents of the input array. As such, it is incompatible
with JIT and other JAX transformations:

>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jax.jit(jnp.nonzero)(x) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.

This can be addressed by passing a static ``size`` parameter to specify the
desired output shape:

>>> nonzero_jit = jax.jit(jnp.nonzero, static_argnames='size')
>>> nonzero_jit(x, size=3)
(Array([1, 3, 5], dtype=int32),)

If ``size`` does not match the true size, the result will be either truncated or padded:

>>> nonzero_jit(x, size=2) # size < 3: indices are truncated
(Array([1, 3], dtype=int32),)
>>> nonzero_jit(x, size=5) # size > 3: indices are padded with zeros.
(Array([1, 3, 5, 0, 0], dtype=int32),)

You can specify a custom fill value for the padding using the ``fill_value`` argument:

>>> nonzero_jit(x, size=5, fill_value=len(x))
(Array([1, 3, 5, 6, 6], dtype=int32),)
"""
util.check_arraylike("nonzero", a)
arr = asarray(a)
del a
Expand All @@ -1476,9 +1573,49 @@ def nonzero(a: ArrayLike, *, size: int | None = None,
out = tuple(where(fill_mask, fval, entry) for fval, entry in safe_zip(fill_value_tup, out))
return out

@util.implements(np.flatnonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS)

def flatnonzero(a: ArrayLike, *, size: int | None = None,
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) -> Array:
"""Return indices of nonzero elements in a flattened array

JAX implementation of :func:`numpy.flatnonzero`.

``jnp.flatnonzero(x)`` is equivalent to ``nonzero(ravel(a))[0]``. For a full
discussion of the parameters to this function, refer to :func:`jax.numpy.nonzero`.

Args:
a: N-dimensional array.
size: optional static integer specifying the number of nonzero entries to
return. See :func:`jax.numpy.nonzero` for more discussion of this parameter.
fill_value: optional padding value when ``size`` is specified. Defaults to 0.
See :func:`jax.numpy.nonzero` for more discussion of this parameter.

Returns:
Array containing the indices of each nonzero value in the flattened array.

See Also:
- :func:`jax.numpy.nonzero`
- :func:`jax.numpy.where`

Examples:
>>> x = jnp.array([[0, 5, 0],
... [6, 0, 8]])
>>> jnp.flatnonzero(x)
Array([1, 3, 5], dtype=int32)

This is equivalent to calling :func:`~jax.numpy.nonzero` on the flattened
array, and extracting the first entry in the resulting tuple:

>>> jnp.nonzero(x.ravel())[0]
Array([1, 3, 5], dtype=int32)

The returned indices can be used to extract nonzero entries from the
flattened array:

>>> indices = jnp.flatnonzero(x)
>>> x.ravel()[indices]
Array([5, 6, 8], dtype=int32)
"""
return nonzero(ravel(a), size=size, fill_value=fill_value)[0]


Expand Down Expand Up @@ -3933,36 +4070,66 @@ def vander(

### Misc

_ARGWHERE_DOC = """\
Because the size of the output of ``argwhere`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional ``size`` argument, which
specifies the size of the leading dimension of the output - it must be specified statically
for ``jnp.argwhere`` to be compiled with non-static operands. If ``size`` is specified,
the indices of the first ``size`` True elements will be returned; if there are fewer
nonzero elements than `size` indicates, the index arrays will be zero-padded.
"""


@util.implements(np.argwhere,
lax_description=_dedent("""
Because the size of the output of ``argwhere`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
must be specified statically for ``jnp.argwhere`` to be used within some of JAX's
transformations."""),
extra_params=_dedent("""
size : int, optional
If specified, the indices of the first ``size`` True elements will be returned. If there
are fewer results than ``size`` indicates, the return value will be padded with ``fill_value``.
fill_value : array_like, optional
When ``size`` is specified and there are fewer than the indicated number of elements, the
remaining elements will be filled with ``fill_value``, which defaults to zero."""))
def argwhere(
a: ArrayLike,
*,
size: int | None = None,
fill_value: ArrayLike | None = None,
) -> Array:
result = transpose(vstack(nonzero(a, size=size, fill_value=fill_value)))
"""Find the indices of nonzero array elements

JAX implementation of :func:`numpy.argwhere`.

``jnp.argwhere(x)`` is essentially equivalent to ``jnp.column_stack(jnp.nonzero(x))``
with special handling for zero-dimensional (i.e. scalar) inputs.

Because the size of the output of ``argwhere`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional ``size`` argument, which
specifies the size of the leading dimension of the output - it must be specified statically
for ``jnp.argwhere`` to be compiled with non-static operands. See :func:`jax.numpy.nonzero`
for a full discussion of ``size`` and its semantics.

Args:
a: array for which to find nonzero elements
size: optional integer specifying statically the number of expected nonzero elements.
This must be specified in order to use ``argwhere`` within JAX transformations like
:func:`jax.jit`. See :func:`jax.numpy.nonzero` for more information.
fill_value: optional array specifying the fill value when ``size`` is specified.
See :func:`jax.numpy.nonzero` for more information.

Returns:
a two-dimensional array of shape ``[size, x.ndim]``. If ``size`` is not specified as
an argument, it is equal to the number of nonzero elements in ``x``.

See Also:
- :func:`jax.numpy.where`
- :func:`jax.numpy.nonzero`

Examples:
Two-dimensional array:

>>> x = jnp.array([[1, 0, 2],
... [0, 3, 0]])
>>> jnp.argwhere(x)
Array([[0, 0],
[0, 2],
[1, 1]], dtype=int32)

Equivalent computation using :func:`jax.numpy.column_stack` and :func:`jax.numpy.nonzero`:

>>> jnp.column_stack(jnp.nonzero(x))
Array([[0, 0],
[0, 2],
[1, 1]], dtype=int32)

Special case for zero-dimensional (i.e. scalar) inputs:

>>> jnp.argwhere(1)
Array([], shape=(1, 0), dtype=int32)
>>> jnp.argwhere(0)
Array([], shape=(0, 0), dtype=int32)
"""
result = transpose(vstack(nonzero(atleast_1d(a), size=size, fill_value=fill_value)))
if ndim(a) == 0:
return result[:0].reshape(result.shape[0], 0)
return result.reshape(result.shape[0], ndim(a))
Expand Down
4 changes: 3 additions & 1 deletion tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6017,7 +6017,9 @@ def test_lax_numpy_docstrings(self):
# Test that docstring wrapping & transformation didn't fail.

# Functions that have their own docstrings & don't wrap numpy.
known_exceptions = {'fromfile', 'fromiter', 'frompyfunc', 'vectorize'}
known_exceptions = {
'fromfile', 'fromiter', 'frompyfunc', 'vectorize',
'argwhere', 'where', 'nonzero', 'flatnonzero'}

for name in dir(jnp):
if name in known_exceptions or name.startswith('_'):
Expand Down