Skip to content

Commit

Permalink
Lower more _rewriting_take cases to lax.slice
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Apr 3, 2023
1 parent 607c7c1 commit e48e960
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 31 deletions.
76 changes: 47 additions & 29 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from jax._src import core
from jax._src import dtypes
from jax._src.api_util import _ensure_index_tuple
from jax._src.array import ArrayImpl
from jax._src.core import ShapedArray, ConcreteArray
from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
_sort_le_comparator, PrecisionLike)
Expand All @@ -58,11 +59,11 @@
from jax._src.numpy import ufuncs
from jax._src.numpy import util
from jax._src.numpy.vectorize import vectorize
from jax._src.ops import scatter
from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
from jax._src.util import (unzip2, subvals, safe_zip,
from jax._src.util import (unzip2, unzip3, subvals, safe_zip,
ceil_of_ratio, partition_list,
canonicalize_axis as _canonicalize_axis)
from jax._src.array import ArrayImpl

newaxis = None
T = TypeVar('T')
Expand Down Expand Up @@ -3898,39 +3899,56 @@ def replace(tup, val):

### Indexing

def _is_integer_index(idx: Any) -> bool:
return isinstance(idx, (int, np.integer)) and not isinstance(idx, (bool, np.bool_))

def _is_inbound_integer_index(idx: Any, size: int) -> bool:
return _is_integer_index(idx) and -size <= idx < size

def _is_nonreversing_static_slice(idx: Any) -> bool:
return (isinstance(idx, slice) and
_all(i is None or _is_integer_index(i)
for i in [idx.start, idx.stop, idx.step]) and
(idx.step is None or idx.step > 0))

def _attempt_rewriting_take_via_slice(arr, idx):
# attempt to compute _rewriting_take via lax.slice(); return None if not possible.
idx = idx if isinstance(idx, tuple) else (idx,)
if not _all(isinstance(i, int) for i in arr.shape):
return None
if len(idx) > arr.ndim:
return None
if not _all(_is_inbound_integer_index(i, size) or _is_nonreversing_static_slice(i)
for i, size in zip(idx, arr.shape)):
return None
sqeeze_dimensions = [i for i, ind in enumerate(idx) if not isinstance(ind, slice)]
idx += (arr.ndim - len(idx)) * (slice(None),)
slices = []
for ind, size in safe_zip(idx, arr.shape):
size = int(size)
if isinstance(ind, slice):
slices.append(ind.indices(size))
else:
ind = int(ind)
ind = ind + size if ind < 0 else ind
slices.append((ind, ind + 1, 1))
return lax.squeeze(lax.slice(arr, *unzip3(slices)), sqeeze_dimensions)


def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
mode=None, fill_value=None):
# Computes arr[idx].
# All supported cases of indexing can be implemented as an XLA gather,
# followed by an optional reverse and broadcast_in_dim.

# Handle some special cases, falling back if error messages might differ.
if (arr.ndim > 0 and isinstance(idx, (int, np.integer)) and
not isinstance(idx, (bool, np.bool_)) and isinstance(arr.shape[0], int)):
if 0 <= idx < arr.shape[0]:
# Use dynamic rather than static index here to avoid slow repeated execution:
# See https://github.com/google/jax/issues/12198
return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
if (arr.ndim > 0 and isinstance(arr.shape[0], int) and
isinstance(idx, slice) and
(type(idx.start) is int or idx.start is None) and
(type(idx.stop) is int or idx.stop is None) and
(type(idx.step) is int or idx.step is None)):
n = arr.shape[0]
start = idx.start if idx.start is not None else 0
stop = idx.stop if idx.stop is not None else n
step = idx.step if idx.step is not None else 1
if (0 <= start < n and 0 <= stop <= n and 0 < step and
(start, stop, step) != (0, n, 1)):
if _any(isinstance(d, core.Tracer) for d in arr.shape[1:]):
if step == 1: # TODO(mattjj, sharadmv): handle step != 1
return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0)
elif step == 1:
# Use dynamic rather than static slice here to avoid slow repeated execution:
# See https://github.com/google/jax/issues/12198
return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0)
else:
return lax.slice_in_dim(arr, start, stop, step)
# For simplicity of generated primitives, we call lax.slice in the simplest
# cases: i.e. non-dynamic arrays indexed with only integers and slices.
# TODO(jakevdp): we could generate slices in other situations as well:
# - newaxis -> broadcast + slice
# - negative stride -> reverse + slice
# - dynamic indices -> dynamic_slice
if (result := _attempt_rewriting_take_via_slice(arr, idx)) is not None:
return result

# TODO(mattjj,dougalm): expand dynamic shape indexing support
if jax.config.jax_dynamic_shapes and arr.ndim > 0:
Expand Down
31 changes: 29 additions & 2 deletions tests/lax_numpy_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,16 +854,43 @@ def testJVPOfGradOfIndexing(self):
self.assertAllClose(expected, primals)
self.assertAllClose(np.zeros_like(x), tangents)

def testSimpleIndexingUsesSlice(self):
jaxpr = jax.make_jaxpr(lambda x: x[:2, :2])(jnp.ones((3, 4)))
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)

jaxpr = jax.make_jaxpr(lambda x: x[:, 1::2])(jnp.ones((3, 4)))
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)

jaxpr = jax.make_jaxpr(lambda x: x[0, :2, 1])(jnp.ones((3, 4, 5)))
self.assertEqual(len(jaxpr.jaxpr.eqns), 2)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
self.assertEqual(jaxpr.jaxpr.eqns[1].primitive, lax.squeeze_p)

jaxpr = jax.make_jaxpr(lambda x: x[0, 0])(jnp.ones((3, 4, 5)))
self.assertEqual(len(jaxpr.jaxpr.eqns), 2)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
self.assertEqual(jaxpr.jaxpr.eqns[1].primitive, lax.squeeze_p)

jaxpr = jax.make_jaxpr(lambda x: x[:, 1])(jnp.ones((3, 4, 5)))
self.assertEqual(len(jaxpr.jaxpr.eqns), 2)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
self.assertEqual(jaxpr.jaxpr.eqns[1].primitive, lax.squeeze_p)

def testTrivialGatherIsntGenerated(self):
# https://github.com/google/jax/issues/1621
jaxpr = jax.make_jaxpr(lambda x: x[:, None])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
self.assertNotIn('gather', str(jaxpr))

jaxpr = jax.make_jaxpr(lambda x: x[0:6:1])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)

jaxpr = jax.make_jaxpr(lambda x: x[:4])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)

jaxpr = jax.make_jaxpr(lambda x: x[::-1])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
Expand Down

0 comments on commit e48e960

Please sign in to comment.