diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 30fdad3c049c..620ca4b3af9e 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1063,6 +1063,8 @@ def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, return jitted_interp(x, xp, fp, left, right, period) +_DEPRECATED_WHERE_ARG = object() + @overload # type: ignore[no-overload-impl] def where(condition: ArrayLike, x: Literal[None] = None, y: Literal[None] = None, /, *, size: int | None = None, @@ -1081,32 +1083,6 @@ def where(condition: ArrayLike, x: ArrayLike | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None ) -> Array | tuple[Array, ...]: ... -_DEPRECATED_WHERE_ARG = object() - -@util.implements(np.where, # type: ignore[no-redef] - lax_description=_dedent(""" - At present, JAX does not support JIT-compilation of the single-argument form - of :py:func:`jax.numpy.where` because its output shape is data-dependent. The - three-argument form does not have a data-dependent shape and can be JIT-compiled - successfully. Alternatively, you can use the optional ``size`` keyword to - statically specify the expected size of the output.\n\n - - Special care is needed when the ``x`` or ``y`` input to - :py:func:`jax.numpy.where` could have a value of NaN. - Specifically, when a gradient is taken - with :py:func:`jax.grad` (reverse-mode differentiation), a NaN in either - ``x`` or ``y`` will propagate into the gradient, regardless of the value - of ``condition``. More information on this behavior and workarounds - is available in the JAX FAQ: - https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where"""), - extra_params=_dedent(""" - size : int, optional - Only referenced when ``x`` and ``y`` are ``None``. If specified, the indices of the first - ``size`` elements of the result will be returned. If there are fewer elements than ``size`` - indicates, the return value will be padded with ``fill_value``. 
- fill_value : array_like, optional - When ``size`` is specified and there are fewer than the indicated number of elements, the - remaining elements will be filled with ``fill_value``, which defaults to zero.""")) def where( acondition = None, if_true = None, if_false = None, /, *, size=None, fill_value=None, @@ -1114,6 +1090,65 @@ def where( condition = _DEPRECATED_WHERE_ARG, x = _DEPRECATED_WHERE_ARG, y = _DEPRECATED_WHERE_ARG ) -> Array | tuple[Array, ...]: + """Select elements from two arrays based on a condition. + + JAX implementation of :func:`numpy.where`. + + .. note:: + when only ``condition`` is provided, ``jnp.where(condition)`` is equivalent + to ``jnp.nonzero(condition)``. For that case, refer to the documentation of + :func:`jax.numpy.nonzero`. The docstring below focuses on the case where + ``x`` and ``y`` are specified. + + The three-term version of ``jnp.where`` lowers to :func:`jax.lax.select`. + + Args: + condition: boolean array. Must be broadcast-compatible with ``x`` and ``y`` when + they are specified. + x: arraylike. Should be broadcast-compatible with ``condition`` and ``y``, and + typecast-compatible with ``y``. + y: arraylike. Should be broadcast-compatible with ``condition`` and ``x``, and + typecast-compatible with ``x``. + size: integer, only referenced when ``x`` and ``y`` are ``None``. For details, + see :func:`jax.numpy.nonzero`. + fill_value: only referenced when ``x`` and ``y`` are ``None``. For details, + see :func:`jax.numpy.nonzero`. + + Returns: + An array of dtype ``jnp.result_type(x, y)`` with values drawn from ``x`` where ``condition`` + is ``True``, and from ``y`` where ``condition`` is ``False``. If ``x`` and ``y`` are ``None``, the + function behaves differently; see :func:`jax.numpy.nonzero` for a description of the return + type. 
+ + See Also: + - :func:`jax.numpy.nonzero` + - :func:`jax.numpy.argwhere` + - :func:`jax.lax.select` + + Notes: + Special care is needed when the ``x`` or ``y`` input to :func:`jax.numpy.where` could + have a value of NaN. Specifically, when a gradient is taken with :func:`jax.grad` + (reverse-mode differentiation), a NaN in either ``x`` or ``y`` will propagate into the + gradient, regardless of the value of ``condition``. More information on this behavior + and workarounds is available in the `JAX FAQ + <https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where>`_. + + Examples: + When ``x`` and ``y`` are not provided, ``where`` behaves equivalently to + :func:`jax.numpy.nonzero`: + + >>> x = jnp.arange(10) + >>> jnp.where(x > 4) + (Array([5, 6, 7, 8, 9], dtype=int32),) + >>> jnp.nonzero(x > 4) + (Array([5, 6, 7, 8, 9], dtype=int32),) + + When ``x`` and ``y`` are provided, ``where`` selects between them based on + the specified condition: + + >>> jnp.where(x > 4, x, 0) + Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32) + """ if (condition is not _DEPRECATED_WHERE_ARG or x is not _DEPRECATED_WHERE_ARG or y is not _DEPRECATED_WHERE_ARG): # TODO(phawkins): deprecated Nov 17 2023, remove after deprecation expires. @@ -1431,25 +1466,87 @@ def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, return reductions.all(isclose(a, b, rtol, atol, equal_nan)) -_NONZERO_DOC = """\ -Because the size of the output of ``nonzero`` is data-dependent, the function is not -typically compatible with JIT. The JAX version adds the optional ``size`` argument which -must be specified statically for ``jnp.nonzero`` to be used within some of JAX's -transformations. -""" -_NONZERO_EXTRA_PARAMS = """ -size : int, optional - If specified, the indices of the first ``size`` True elements will be returned. If there are - fewer unique elements than ``size`` indicates, the return value will be padded with ``fill_value``. 
-fill_value : array_like, optional - When ``size`` is specified and there are fewer than the indicated number of elements, the - remaining elements will be filled with ``fill_value``, which defaults to zero. -""" - -@util.implements(np.nonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS) def nonzero(a: ArrayLike, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None ) -> tuple[Array, ...]: + """Return indices of nonzero elements of an array. + + JAX implementation of :func:`numpy.nonzero`. + + Because the size of the output of ``nonzero`` is data-dependent, the function + is not compatible with JIT and other transformations. The JAX version adds + the optional ``size`` argument which must be specified statically for + ``jnp.nonzero`` to be used within JAX's transformations. + + Args: + a: N-dimensional array. + size: optional static integer specifying the number of nonzero entries to + return. If there are more nonzero elements than the specified ``size``, + then indices will be truncated at the end. If there are fewer nonzero + elements than the specified size, then indices will be padded with + ``fill_value``, which defaults to zero. + fill_value: optional padding value when ``size`` is specified. Defaults to 0. + + Returns: + Tuple of JAX Arrays of length ``a.ndim``, containing the indices of each + nonzero value. + + See Also: + - :func:`jax.numpy.flatnonzero` + - :func:`jax.numpy.where` + + Examples: + + One-dimensional array returns a length-1 tuple of indices: + + >>> x = jnp.array([0, 5, 0, 6, 0, 7]) + >>> jnp.nonzero(x) + (Array([1, 3, 5], dtype=int32),) + + Two-dimensional array returns a length-2 tuple of indices: + + >>> x = jnp.array([[0, 5, 0], + ... 
[6, 0, 7]]) + >>> jnp.nonzero(x) + (Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32)) + + In either case, the resulting tuple of indices can be used directly to extract + the nonzero values: + + >>> indices = jnp.nonzero(x) + >>> x[indices] + Array([5, 6, 7], dtype=int32) + + The output of ``nonzero`` has a dynamic shape, because the number of returned + indices depends on the contents of the input array. As such, it is incompatible + with JIT and other JAX transformations: + + >>> x = jnp.array([0, 5, 0, 6, 0, 7]) + >>> jax.jit(jnp.nonzero)(x) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]. + The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations. + + This can be addressed by passing a static ``size`` parameter to specify the + desired output shape: + + >>> nonzero_jit = jax.jit(jnp.nonzero, static_argnames='size') + >>> nonzero_jit(x, size=3) + (Array([1, 3, 5], dtype=int32),) + + If ``size`` does not match the true size, the result will be either truncated or padded: + + >>> nonzero_jit(x, size=2) # size < 3: indices are truncated + (Array([1, 3], dtype=int32),) + >>> nonzero_jit(x, size=5) # size > 3: indices are padded with zeros. 
+ (Array([1, 3, 5, 0, 0], dtype=int32),) + + You can specify a custom fill value for the padding using the ``fill_value`` argument: + + >>> nonzero_jit(x, size=5, fill_value=len(x)) + (Array([1, 3, 5, 6, 6], dtype=int32),) + """ util.check_arraylike("nonzero", a) arr = asarray(a) del a @@ -1476,9 +1573,49 @@ def nonzero(a: ArrayLike, *, size: int | None = None, out = tuple(where(fill_mask, fval, entry) for fval, entry in safe_zip(fill_value_tup, out)) return out -@util.implements(np.flatnonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS) + def flatnonzero(a: ArrayLike, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) -> Array: + """Return indices of nonzero elements in a flattened array. + + JAX implementation of :func:`numpy.flatnonzero`. + + ``jnp.flatnonzero(a)`` is equivalent to ``nonzero(ravel(a))[0]``. For a full + discussion of the parameters to this function, refer to :func:`jax.numpy.nonzero`. + + Args: + a: N-dimensional array. + size: optional static integer specifying the number of nonzero entries to + return. See :func:`jax.numpy.nonzero` for more discussion of this parameter. + fill_value: optional padding value when ``size`` is specified. Defaults to 0. + See :func:`jax.numpy.nonzero` for more discussion of this parameter. + + Returns: + Array containing the indices of each nonzero value in the flattened array. + + See Also: + - :func:`jax.numpy.nonzero` + - :func:`jax.numpy.where` + + Examples: + >>> x = jnp.array([[0, 5, 0], + ... 
[6, 0, 8]]) + >>> jnp.flatnonzero(x) + Array([1, 3, 5], dtype=int32) + + This is equivalent to calling :func:`~jax.numpy.nonzero` on the flattened + array, and extracting the first entry in the resulting tuple: + + >>> jnp.nonzero(x.ravel())[0] + Array([1, 3, 5], dtype=int32) + + The returned indices can be used to extract nonzero entries from the + flattened array: + + >>> indices = jnp.flatnonzero(x) + >>> x.ravel()[indices] + Array([5, 6, 8], dtype=int32) + """ return nonzero(ravel(a), size=size, fill_value=fill_value)[0] @@ -3933,36 +4070,66 @@ def vander( ### Misc -_ARGWHERE_DOC = """\ -Because the size of the output of ``argwhere`` is data-dependent, the function is not -typically compatible with JIT. The JAX version adds the optional ``size`` argument, which -specifies the size of the leading dimension of the output - it must be specified statically -for ``jnp.argwhere`` to be compiled with non-static operands. If ``size`` is specified, -the indices of the first ``size`` True elements will be returned; if there are fewer -nonzero elements than `size` indicates, the index arrays will be zero-padded. -""" - - -@util.implements(np.argwhere, - lax_description=_dedent(""" - Because the size of the output of ``argwhere`` is data-dependent, the function is not - typically compatible with JIT. The JAX version adds the optional ``size`` argument which - must be specified statically for ``jnp.argwhere`` to be used within some of JAX's - transformations."""), - extra_params=_dedent(""" - size : int, optional - If specified, the indices of the first ``size`` True elements will be returned. If there - are fewer results than ``size`` indicates, the return value will be padded with ``fill_value``. 
- fill_value : array_like, optional - When ``size`` is specified and there are fewer than the indicated number of elements, the - remaining elements will be filled with ``fill_value``, which defaults to zero.""")) def argwhere( a: ArrayLike, *, size: int | None = None, fill_value: ArrayLike | None = None, ) -> Array: - result = transpose(vstack(nonzero(a, size=size, fill_value=fill_value))) + """Find the indices of nonzero array elements. + + JAX implementation of :func:`numpy.argwhere`. + + ``jnp.argwhere(a)`` is essentially equivalent to ``jnp.column_stack(jnp.nonzero(a))`` + with special handling for zero-dimensional (i.e. scalar) inputs. + + Because the size of the output of ``argwhere`` is data-dependent, the function is not + typically compatible with JIT. The JAX version adds the optional ``size`` argument, which + specifies the size of the leading dimension of the output - it must be specified statically + for ``jnp.argwhere`` to be compiled with non-static operands. See :func:`jax.numpy.nonzero` + for a full discussion of ``size`` and its semantics. + + Args: + a: array for which to find nonzero elements + size: optional integer specifying statically the number of expected nonzero elements. + This must be specified in order to use ``argwhere`` within JAX transformations like + :func:`jax.jit`. See :func:`jax.numpy.nonzero` for more information. + fill_value: optional array specifying the fill value when ``size`` is specified. + See :func:`jax.numpy.nonzero` for more information. + + Returns: + a two-dimensional array of shape ``[size, a.ndim]``. If ``size`` is not specified as + an argument, it is equal to the number of nonzero elements in ``a``. + + See Also: + - :func:`jax.numpy.where` + - :func:`jax.numpy.nonzero` + + Examples: + Two-dimensional array: + + >>> x = jnp.array([[1, 0, 2], + ... 
[0, 3, 0]]) + >>> jnp.argwhere(x) + Array([[0, 0], + [0, 2], + [1, 1]], dtype=int32) + + Equivalent computation using :func:`jax.numpy.column_stack` and :func:`jax.numpy.nonzero`: + + >>> jnp.column_stack(jnp.nonzero(x)) + Array([[0, 0], + [0, 2], + [1, 1]], dtype=int32) + + Special case for zero-dimensional (i.e. scalar) inputs: + + >>> jnp.argwhere(1) + Array([], shape=(1, 0), dtype=int32) + >>> jnp.argwhere(0) + Array([], shape=(0, 0), dtype=int32) + """ + result = transpose(vstack(nonzero(atleast_1d(a), size=size, fill_value=fill_value))) if ndim(a) == 0: return result[:0].reshape(result.shape[0], 0) return result.reshape(result.shape[0], ndim(a)) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 223e2cd14c07..14c7c657e406 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6017,7 +6017,9 @@ def test_lax_numpy_docstrings(self): # Test that docstring wrapping & transformation didn't fail. # Functions that have their own docstrings & don't wrap numpy. - known_exceptions = {'fromfile', 'fromiter', 'frompyfunc', 'vectorize'} + known_exceptions = { + 'fromfile', 'fromiter', 'frompyfunc', 'vectorize', + 'argwhere', 'where', 'nonzero', 'flatnonzero'} for name in dir(jnp): if name in known_exceptions or name.startswith('_'):