Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve documentation for jnp.searchsorted #21487

Merged
merged 1 commit into from
May 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 58 additions & 9 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7159,18 +7159,67 @@ def _searchsorted_via_compare_all(sorted_arr: Array, query: Array, side: str, dt
return comparisons.sum(dtype=dtype, axis=0)


@util.implements(np.searchsorted,
extra_params=_dedent("""
method : str
One of 'scan' (default), 'scan_unrolled', 'sort' or 'compare_all'. Controls the method used by the
implementation: 'scan' tends to be more performant on CPU (particularly when ``a`` is
very large), 'scan_unrolled' is more performant on GPU at the expense of additional compile time,
'sort' is often more performant on accelerator backends like GPU and TPU
(particularly when ``v`` is very large), and 'compare_all' can be most performant
when ``a`` is very small."""))
@partial(jit, static_argnames=('side', 'method'))
def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left',
sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array:
"""Perform a binary search within a sorted array.

JAX implementation of :func:`numpy.searchsorted`.

This will return the indices within a sorted array ``a`` where values in ``v``
can be inserted to maintain its sort order.

Args:
a: one-dimensional array, assumed to be in sorted order unless ``sorter`` is specified.
v: N-dimensional array of query values
side: ``'left'`` (default) or ``'right'``; specifies whether insertion indices will be
to the left or the right in case of ties.
sorter: optional array of indices specifying the sort order of ``a``. If specified,
then the algorithm assumes that ``a[sorter]`` is in sorted order.
method: one of ``'scan'`` (default), ``'scan_unrolled'``, ``'sort'`` or ``'compare_all'``.
See *Note* below.

Returns:
Array of insertion indices of shape ``v.shape``.

Note:
The ``method`` argument controls the algorithm used to compute the insertion indices.

- ``'scan'`` (the default) tends to be more performant on CPU, particularly when ``a`` is
very large.
- ``'scan_unrolled'`` is more performant on GPU at the expense of additional compile time.
- ``'sort'`` is often more performant on accelerator backends like GPU and TPU, particularly
when ``v`` is very large.
- ``'compare_all'`` tends to be the most performant when ``a`` is very small.

Examples:
Searching for a single value:

>>> a = jnp.array([1, 2, 2, 3, 4, 5, 5])
>>> jnp.searchsorted(a, 2)
Array(1, dtype=int32)
>>> jnp.searchsorted(a, 2, side='right')
Array(3, dtype=int32)

Searching for a batch of values:

>>> vals = jnp.array([0, 3, 8, 1.5, 2])
>>> jnp.searchsorted(a, vals)
Array([0, 3, 7, 1, 1], dtype=int32)

Optionally, the ``sorter`` argument can be used to find insertion indices into
an array sorted via :func:`jax.numpy.argsort`:

>>> a = jnp.array([4, 3, 5, 1, 2])
>>> sorter = jnp.argsort(a)
>>> jnp.searchsorted(a, vals, sorter=sorter)
Array([0, 2, 5, 1, 1], dtype=int32)

The result is equivalent to passing the sorted array:

>>> jnp.searchsorted(jnp.sort(a), vals)
Array([0, 2, 5, 1, 1], dtype=int32)
"""
if sorter is None:
util.check_arraylike("searchsorted", a, v)
else:
Expand Down