Skip to content

Commit

Permalink
DOC: improve documentation of where, nonzero, and related functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Apr 26, 2024
1 parent 8c2425e commit 84b2c6f
Show file tree
Hide file tree
Showing 2 changed files with 237 additions and 68 deletions.
301 changes: 234 additions & 67 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,8 @@ def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
return jitted_interp(x, xp, fp, left, right, period)


_DEPRECATED_WHERE_ARG = object()

@overload # type: ignore[no-overload-impl]
def where(condition: ArrayLike, x: Literal[None] = None,
y: Literal[None] = None, /, *, size: int | None = None,
Expand All @@ -1081,39 +1083,72 @@ def where(condition: ArrayLike, x: ArrayLike | None = None,
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
) -> Array | tuple[Array, ...]: ...

_DEPRECATED_WHERE_ARG = object()

@util.implements(np.where, # type: ignore[no-redef]
lax_description=_dedent("""
At present, JAX does not support JIT-compilation of the single-argument form
of :py:func:`jax.numpy.where` because its output shape is data-dependent. The
three-argument form does not have a data-dependent shape and can be JIT-compiled
successfully. Alternatively, you can use the optional ``size`` keyword to
statically specify the expected size of the output.\n\n
Special care is needed when the ``x`` or ``y`` input to
:py:func:`jax.numpy.where` could have a value of NaN.
Specifically, when a gradient is taken
with :py:func:`jax.grad` (reverse-mode differentiation), a NaN in either
``x`` or ``y`` will propagate into the gradient, regardless of the value
of ``condition``. More information on this behavior and workarounds
is available in the JAX FAQ:
https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where"""),
extra_params=_dedent("""
size : int, optional
Only referenced when ``x`` and ``y`` are ``None``. If specified, the indices of the first
``size`` elements of the result will be returned. If there are fewer elements than ``size``
indicates, the return value will be padded with ``fill_value``.
fill_value : array_like, optional
When ``size`` is specified and there are fewer than the indicated number of elements, the
remaining elements will be filled with ``fill_value``, which defaults to zero."""))
def where(
acondition = None, if_true = None, if_false = None, /, *,
size=None, fill_value=None,
# Deprecated keyword-only names.
condition = _DEPRECATED_WHERE_ARG, x = _DEPRECATED_WHERE_ARG,
y = _DEPRECATED_WHERE_ARG
) -> Array | tuple[Array, ...]:
"""Select elements from two arrays based on a condition.
JAX implementation of :func:`numpy.where`.
.. note::
when only ``condition`` is provided, ``jnp.where(condition)`` is equivalent
to ``jnp.nonzero(condition)``. For that case, refer to the documentation of
:func:`jax.numpy.nonzero`. The docstring below focuses on the case where
``x`` and ``y`` are specified.
The three-term version of ``jnp.where`` lowers to :func:`jax.lax.select`.
Args:
condition: boolean array. Must be broadcast-compatible with ``x`` and ``y`` when
they are specified.
x: arraylike. Should be broadcast-compatible with ``condition`` and ``y``, and
typecast-compatible with ``y``.
y: arraylike. Should be broadcast-compatible with ``condition`` and ``x``, and
typecast-compatible with ``x``.
size: integer, only referenced when ``x`` and ``y`` are ``None``. For details,
see :func:`jax.numpy.nonzero`.
fill_value: only referenced when ``x`` and ``y`` are ``None``. For details,
see :func:`jax.numpy.nonzero`.
Returns:
An array of dtype ``jnp.result_type(x, y)`` with values drawn from ``x`` where ``condition``
is True, and from ``y`` where condition is ``False. If ``x`` and ``y`` are ``None``, the
function behaves differently; see `:func:`jax.numpy.nonzero` for a description of the return
type.
See Also:
- :func:`jax.numpy.nonzero`
- :func:`jax.numpy.argwhere`
- :func:`jax.lax.select`
Notes:
Special care is needed when the ``x`` or ``y`` input to :func:`jax.numpy.where` could
have a value of NaN. Specifically, when a gradient is taken with :func:`jax.grad`
(reverse-mode differentiation), a NaN in either ``x`` or ``y`` will propagate into the
gradient, regardless of the value of ``condition``. More information on this behavior
and workarounds is available in the `JAX FAQ
<https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where>`_.
Examples:
When ``x`` and ``y`` are not provided, ``where`` behaves equivalently to
:func:`jax.numpy.nonzero`:
>>> x = jnp.arange(10)
>>> jnp.where(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
>>> jnp.nonzero(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
When ``x`` and ``y`` are provided, ``where`` selects between them based on
the specified condition:
>>> jnp.where(x > 4, x, 0)
Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)
"""
if (condition is not _DEPRECATED_WHERE_ARG or x is not _DEPRECATED_WHERE_ARG
or y is not _DEPRECATED_WHERE_ARG):
# TODO(phawkins): deprecated Nov 17 2023, remove after deprecation expires.
Expand Down Expand Up @@ -1431,25 +1466,87 @@ def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05,
return reductions.all(isclose(a, b, rtol, atol, equal_nan))


_NONZERO_DOC = """\
Because the size of the output of ``nonzero`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
must be specified statically for ``jnp.nonzero`` to be used within some of JAX's
transformations.
"""
_NONZERO_EXTRA_PARAMS = """
size : int, optional
If specified, the indices of the first ``size`` True elements will be returned. If there are
fewer unique elements than ``size`` indicates, the return value will be padded with ``fill_value``.
fill_value : array_like, optional
When ``size`` is specified and there are fewer than the indicated number of elements, the
remaining elements will be filled with ``fill_value``, which defaults to zero.
"""

@util.implements(np.nonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS)
def nonzero(a: ArrayLike, *, size: int | None = None,
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
) -> tuple[Array, ...]:
"""Return indices of nonzero elements of an array.
JAX implementation of :func:`numpy.nonzero`.
Because the size of the output of ``nonzero`` is data-dependent, the function
is not compatible with JIT and other transformations. The JAX version adds
the optional ``size`` argument which must be specified statically for
``jnp.nonzero`` to be used within JAX's transformations.
Args:
a: N-dimensional array.
size: optional static integer specifying the number of nonzero entries to
return. If there are more nonzero elements than the specified ``size``,
then indices will be truncated at the end. If there are fewer nonzero
elements than the specified size, then indices will be padded with
``fill_value``, which defaults to zero.
fill_value: optional padding value when ``size`` is specified. Defaults to 0.
Returns:
Tuple of JAX Arrays of length ``a.ndim``, containing the indices of each
nonzero value.
See also:
- :func:`jax.numpy flatnonzero`
- :func:`jax.numpy.where`
Examples:
One-dimensional array returns a length-1 tuple of indices:
>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jnp.nonzero(x)
(Array([1, 3, 5], dtype=int32),)
Two-dimensional array returns a length-2 tuple of indices:
>>> x = jnp.array([[0, 5, 0],
... [6, 0, 7]])
>>> jnp.nonzero(x)
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))
In either case, the resulting tuple of indices can be used directly to extract
the nonzero values:
>>> indices = jnp.nonzero(x)
>>> x[indices]
Array([5, 6, 7], dtype=int32)
The output of ``nonzero`` has a dynamic shape, because the number of returned
indices depends on the contents of the input array. As such, it is incompatible
with JIT and other JAX transformations:
>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jax.jit(jnp.nonzero)(x) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
This can be addressed by passing a static ``size`` parameter to specify the
desired output shape:
>>> nonzero_jit = jax.jit(jnp.nonzero, static_argnames='size')
>>> nonzero_jit(x, size=3)
(Array([1, 3, 5], dtype=int32),)
If ``size`` does not match the true size, the result will be either truncated or padded:
>>> nonzero_jit(x, size=2) # size < 3: indices are truncated
(Array([1, 3], dtype=int32),)
>>> nonzero_jit(x, size=5) # size > 3: indices are padded with zeros.
(Array([1, 3, 5, 0, 0], dtype=int32),)
You can specify a custom fill value for the padding using the ``fill_value`` argument:
>>> nonzero_jit(x, size=5, fill_value=len(x))
(Array([1, 3, 5, 6, 6], dtype=int32),)
"""
util.check_arraylike("nonzero", a)
arr = asarray(a)
del a
Expand All @@ -1476,9 +1573,49 @@ def nonzero(a: ArrayLike, *, size: int | None = None,
out = tuple(where(fill_mask, fval, entry) for fval, entry in safe_zip(fill_value_tup, out))
return out

@util.implements(np.flatnonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS)

def flatnonzero(a: ArrayLike, *, size: int | None = None,
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) -> Array:
"""Return indices of nonzero elements in a flattened array
JAX implementation of :func:`numpy.flatnonzero`.
``jnp.flatnonzero(x)`` is equivalent to ``nonzero(ravel(a))[0]``. For a full
discussion of the parameters to this function, refer to :func:`jax.numpy.nonzero`.
Args:
a: N-dimensional array.
size: optional static integer specifying the number of nonzero entries to
return. See :func:`jax.numpy.nonzero` for more discussion of this parameter.
fill_value: optional padding value when ``size`` is specified. Defaults to 0.
See :func:`jax.numpy.nonzero` for more discussion of this parameter.
Returns:
Array containing the indices of each nonzero value in the flattened array.
See Also:
- :func:`jax.numpy.nonzero`
- :func:`jax.numpy.where`
Examples:
>>> x = jnp.array([[0, 5, 0],
... [6, 0, 8]])
>>> jnp.flatnonzero(x)
Array([1, 3, 5], dtype=int32)
This is equivalent to calling :func:`~jax.numpy.nonzero` on the flattened
array, and extracting the first entry in the resulting tuple:
>>> jnp.nonzero(x.ravel())[0]
Array([1, 3, 5], dtype=int32)
The returned indices can be used to extract nonzero entries from the
flattened array:
>>> indices = jnp.flatnonzero(x)
>>> x.ravel()[indices]
Array([5, 6, 8], dtype=int32)
"""
return nonzero(ravel(a), size=size, fill_value=fill_value)[0]


Expand Down Expand Up @@ -3933,36 +4070,66 @@ def vander(

### Misc

_ARGWHERE_DOC = """\
Because the size of the output of ``argwhere`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional ``size`` argument, which
specifies the size of the leading dimension of the output - it must be specified statically
for ``jnp.argwhere`` to be compiled with non-static operands. If ``size`` is specified,
the indices of the first ``size`` True elements will be returned; if there are fewer
nonzero elements than `size` indicates, the index arrays will be zero-padded.
"""


@util.implements(np.argwhere,
lax_description=_dedent("""
Because the size of the output of ``argwhere`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
must be specified statically for ``jnp.argwhere`` to be used within some of JAX's
transformations."""),
extra_params=_dedent("""
size : int, optional
If specified, the indices of the first ``size`` True elements will be returned. If there
are fewer results than ``size`` indicates, the return value will be padded with ``fill_value``.
fill_value : array_like, optional
When ``size`` is specified and there are fewer than the indicated number of elements, the
remaining elements will be filled with ``fill_value``, which defaults to zero."""))
def argwhere(
a: ArrayLike,
*,
size: int | None = None,
fill_value: ArrayLike | None = None,
) -> Array:
result = transpose(vstack(nonzero(a, size=size, fill_value=fill_value)))
"""Find the indices of nonzero array elements
JAX implementation of :func:`numpy.argwhere`.
``jnp.argwhere(x)`` is essentially equivalent to ``jnp.column_stack(jnp.nonzero(x))``
with special handling for zero-dimensional (i.e. scalar) inputs.
Because the size of the output of ``argwhere`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional ``size`` argument, which
specifies the size of the leading dimension of the output - it must be specified statically
for ``jnp.argwhere`` to be compiled with non-static operands. See :func:`jax.numpy.nonzero`
for a full discussion of ``size`` and its semantics.
Args:
a: array for which to find nonzero elements
size: optional integer specifying statically the number of expected nonzero elements.
This must be specified in order to use ``argwhere`` within JAX transformations like
:func:`jax.jit`. See :func:`jax.numpy.nonzero` for more information.
fill_value: optional array specifying the fill value when ``size`` is specified.
See :func:`jax.numpy.nonzero` for more information.
Returns:
a two-dimensional array of shape ``[size, x.ndim]``. If ``size`` is not specified as
an argument, it is equal to the number of nonzero elements in ``x``.
See Also:
- :func:`jax.numpy.where`
- :func:`jax.numpy.nonzero`
Examples:
Two-dimensional array:
>>> x = jnp.array([[1, 0, 2],
... [0, 3, 0]])
>>> jnp.argwhere(x)
Array([[0, 0],
[0, 2],
[1, 1]], dtype=int32)
Equivalent computation using :func:`jax.numpy.column_stack` and :func:`jax.numpy.nonzero`:
>>> jnp.column_stack(jnp.nonzero(x))
Array([[0, 0],
[0, 2],
[1, 1]], dtype=int32)
Special case for zero-dimensional (i.e. scalar) inputs:
>>> jnp.argwhere(1)
Array([], shape=(1, 0), dtype=int32)
>>> jnp.argwhere(0)
Array([], shape=(0, 0), dtype=int32)
"""
result = transpose(vstack(nonzero(atleast_1d(a), size=size, fill_value=fill_value)))
if ndim(a) == 0:
return result[:0].reshape(result.shape[0], 0)
return result.reshape(result.shape[0], ndim(a))
Expand Down
4 changes: 3 additions & 1 deletion tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6017,7 +6017,9 @@ def test_lax_numpy_docstrings(self):
# Test that docstring wrapping & transformation didn't fail.

# Functions that have their own docstrings & don't wrap numpy.
known_exceptions = {'fromfile', 'fromiter', 'frompyfunc', 'vectorize'}
known_exceptions = {
'fromfile', 'fromiter', 'frompyfunc', 'vectorize',
'argwhere', 'where', 'nonzero', 'flatnonzero'}

for name in dir(jnp):
if name in known_exceptions or name.startswith('_'):
Expand Down

0 comments on commit 84b2c6f

Please sign in to comment.