Skip to content

Commit

Permalink
Improve documentation for jnp.searchsorted
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed May 29, 2024
1 parent dbfb4b3 commit 6c4fc6a
Showing 1 changed file with 58 additions and 9 deletions.
67 changes: 58 additions & 9 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7159,18 +7159,67 @@ def _searchsorted_via_compare_all(sorted_arr: Array, query: Array, side: str, dt
return comparisons.sum(dtype=dtype, axis=0)


@util.implements(np.searchsorted,
extra_params=_dedent("""
method : str
One of 'scan' (default), 'scan_unrolled', 'sort' or 'compare_all'. Controls the method used by the
implementation: 'scan' tends to be more performant on CPU (particularly when ``a`` is
very large), 'scan_unrolled' is more performant on GPU at the expense of additional compile time,
'sort' is often more performant on accelerator backends like GPU and TPU
(particularly when ``v`` is very large), and 'compare_all' can be most performant
when ``a`` is very small."""))
@partial(jit, static_argnames=('side', 'method'))
def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left',
sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array:
"""Perform a binary search within a sorted array.
JAX implementation of :func:`numpy.searchsorted`.
This will return the indices within a sorted array ``a`` where values in ``v``
can be inserted to maintain its sort order.
Args:
a: one-dimensional array, assumed to be in sorted order unless ``sorter`` is specified.
v: N-dimensional array of query values
side: ``'left'`` (default) or ``'right'``; specifies whether insertion indices will be
to the left or the right in case of ties.
sorter: optional array of indices specifying the sort order of ``a``. If specified,
then the algorithm assumes that ``a[sorter]`` is in sorted order.
method: one of ``'scan'`` (default), ``'scan_unrolled'``, ``'sort'`` or ``'compare_all'``.
See *Note* below.
Returns:
Array of insertion indices of shape ``v.shape``.
Note:
The ``method`` argument controls the algorithm used to compute the insertion indices.
- ``'scan'`` (the default) tends to be more performant on CPU, particularly when ``a`` is
very large.
- ``'scan_unrolled'`` is more performant on GPU at the expense of additional compile time.
- ``'sort'`` is often more performant on accelerator backends like GPU and TPU, particularly
when ``v`` is very large.
- ``'compare_all'`` tends to be the most performant when ``a`` is very small.
Examples:
Searching for a single value:
>>> a = jnp.array([1, 2, 2, 3, 4, 5, 5])
>>> jnp.searchsorted(a, 2)
Array(1, dtype=int32)
>>> jnp.searchsorted(a, 2, side='right')
Array(3, dtype=int32)
Searching for a batch of values:
>>> vals = jnp.array([0, 3, 8, 1.5, 2])
>>> jnp.searchsorted(a, vals)
Array([0, 3, 7, 1, 1], dtype=int32)
Optionally, the ``sorter`` argument can be used to find insertion indices into
an array sorted via :func:`jax.numpy.argsort`:
>>> a = jnp.array([4, 3, 5, 1, 2])
>>> sorter = jnp.argsort(a)
>>> jnp.searchsorted(a, vals, sorter=sorter)
Array([0, 2, 5, 1, 1], dtype=int32)
The result is equivalent to passing the sorted array:
>>> jnp.searchsorted(jnp.sort(a), vals)
Array([0, 2, 5, 1, 1], dtype=int32)
"""
if sorter is None:
util.check_arraylike("searchsorted", a, v)
else:
Expand Down

0 comments on commit 6c4fc6a

Please sign in to comment.