Skip to content

Commit

Permalink
jax.numpy indexing: lower to dynamic_slice for more cases
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Apr 4, 2023
1 parent 9bb3d86 commit a963919
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 31 deletions.
126 changes: 97 additions & 29 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@

from jax._src import api_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src.api_util import _ensure_index_tuple
from jax._src.array import ArrayImpl
from jax._src.core import ShapedArray, ConcreteArray
from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
_sort_le_comparator, PrecisionLike)
Expand All @@ -58,11 +60,11 @@
from jax._src.numpy import ufuncs
from jax._src.numpy import util
from jax._src.numpy.vectorize import vectorize
from jax._src.ops import scatter
from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
from jax._src.util import (unzip2, subvals, safe_zip,
from jax._src.util import (unzip2, unzip3, subvals, safe_zip,
ceil_of_ratio, partition_list,
canonicalize_axis as _canonicalize_axis)
from jax._src.array import ArrayImpl

newaxis = None
T = TypeVar('T')
Expand Down Expand Up @@ -3898,39 +3900,105 @@ def replace(tup, val):

### Indexing

def _is_integer_index(idx: Any) -> bool:
return isinstance(idx, (int, np.integer)) and not isinstance(idx, (bool, np.bool_))

def _is_simple_reverse_slice(idx: Any) -> bool:
return (isinstance(idx, slice) and
idx.start is idx.stop is None and
isinstance(idx.step, int) and idx.step == -1)

def _is_valid_integer_index_for_slice(idx, size, mode):
if size == 0:
return False
if _is_integer_index(idx):
return -size <= idx < size
try:
shape, dtype = np.shape(idx), _dtype(idx)
except:
return False
if shape == () and np.issubdtype(dtype, np.integer):
# For dynamic integer indices, dynamic_slice semantics require index clipping:
return mode in [None, 'promise_inbounds', 'clip']
return False

def _is_contiguous_slice(idx):
return (isinstance(idx, slice) and
(idx.start is None or _is_integer_index(idx.start)) and
(idx.stop is None or _is_integer_index(idx.stop)) and
(idx.step is None or _is_integer_index(idx.step) and idx.step == 1))

def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: Optional[str]) -> Optional[Array]:
# attempt to compute _rewriting_take via lax.slice(); return None if not possible.
idx = idx if isinstance(idx, tuple) else (idx,)

if not _all(isinstance(i, int) for i in arr.shape):
return None
if len(idx) > arr.ndim:
return None
if _any(i is None for i in idx):
return None # TODO(jakevdp): handle newaxis case

simple_revs = {i for i, ind in enumerate(idx) if _is_simple_reverse_slice(ind)}
int_indices = {i for i, (ind, size) in enumerate(zip(idx, arr.shape))
if _is_valid_integer_index_for_slice(ind, size, mode)}
contiguous_slices = {i for i, ind in enumerate(idx) if _is_contiguous_slice(ind)}

# For sharded inputs, partial slices lead to incorrect sharding semantics, so
# we fall back to gather
# TODO(yashkatariya): fix dynamic_slice with sharding
is_sharded = isinstance(arr, ArrayImpl) and not dispatch.is_single_device_sharding(arr.sharding)
all_empty_slices = _all(idx[i].indices(arr.shape[i]) == (0, arr.shape[i], 1)
for i in contiguous_slices)
if is_sharded and (int_indices or not all_empty_slices):
return None

if len(simple_revs) + len(int_indices) + len(contiguous_slices) != len(idx):
return None

if simple_revs:
arr = lax.rev(arr, tuple(simple_revs))
if not int_indices and all_empty_slices:
return arr
idx = tuple(slice(None) if i in simple_revs else ind
for i, ind in enumerate(idx))
contiguous_slices |= simple_revs

idx += (arr.ndim - len(idx)) * (slice(None),)
start_indices: Sequence[ArrayLike] = []
slice_sizes: Sequence[int] = []

for ind, size in safe_zip(idx, arr.shape):
if isinstance(ind, slice):
start, stop, step = ind.indices(size)
assert step == 1 # checked above
start_indices.append(start)
slice_sizes.append(stop - start)
else:
assert np.issubdtype(_dtype(ind), np.integer) # checked above
assert np.shape(ind) == () # checked above
start_indices.append(ind)
slice_sizes.append(1)
# We must be careful with dtypes because dynamic_slice requires all
# start indices to have matching types.
start_indices = util.promote_dtypes(*start_indices)
arr = lax.dynamic_slice(arr, start_indices=start_indices, slice_sizes=slice_sizes)
if int_indices:
arr = lax.squeeze(arr, tuple(int_indices))
return arr


def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
mode=None, fill_value=None):
# Computes arr[idx].
# All supported cases of indexing can be implemented as an XLA gather,
# followed by an optional reverse and broadcast_in_dim.

# Handle some special cases, falling back if error messages might differ.
if (arr.ndim > 0 and isinstance(idx, (int, np.integer)) and
not isinstance(idx, (bool, np.bool_)) and isinstance(arr.shape[0], int)):
if 0 <= idx < arr.shape[0]:
# Use dynamic rather than static index here to avoid slow repeated execution:
# See https://github.com/google/jax/issues/12198
return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
if (arr.ndim > 0 and isinstance(arr.shape[0], int) and
isinstance(idx, slice) and
(type(idx.start) is int or idx.start is None) and
(type(idx.stop) is int or idx.stop is None) and
(type(idx.step) is int or idx.step is None)):
n = arr.shape[0]
start = idx.start if idx.start is not None else 0
stop = idx.stop if idx.stop is not None else n
step = idx.step if idx.step is not None else 1
if (0 <= start < n and 0 <= stop <= n and 0 < step and
(start, stop, step) != (0, n, 1)):
if _any(isinstance(d, core.Tracer) for d in arr.shape[1:]):
if step == 1: # TODO(mattjj, sharadmv): handle step != 1
return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0)
elif step == 1:
# Use dynamic rather than static slice here to avoid slow repeated execution:
# See https://github.com/google/jax/issues/12198
return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0)
else:
return lax.slice_in_dim(arr, start, stop, step)
# For simplicity of generated primitives, we call lax.dynamic_slice in the
# simplest cases: i.e. non-dynamic arrays indexed with integers and slices.

if (result := _attempt_rewriting_take_via_slice(arr, idx, mode)) is not None:
return result

# TODO(mattjj,dougalm): expand dynamic shape indexing support
if jax.config.jax_dynamic_shapes and arr.ndim > 0:
Expand Down
34 changes: 32 additions & 2 deletions tests/lax_numpy_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,16 +854,46 @@ def testJVPOfGradOfIndexing(self):
self.assertAllClose(expected, primals)
self.assertAllClose(np.zeros_like(x), tangents)

def testSimpleIndexingUsesSlice(self):
jaxpr = jax.make_jaxpr(lambda x: x[:2, :2])(jnp.ones((3, 4)))
self.assertEqual(len(jaxpr.jaxpr.eqns), 7)
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.dynamic_slice_p)

jaxpr = jax.make_jaxpr(lambda x: x[0, :2, 1])(jnp.ones((3, 4, 5)))
self.assertEqual(len(jaxpr.jaxpr.eqns), 11)
self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)

jaxpr = jax.make_jaxpr(lambda x: x[0, 0])(jnp.ones((3, 4, 5)))
self.assertEqual(len(jaxpr.jaxpr.eqns), 11)
self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)

jaxpr = jax.make_jaxpr(lambda x: x[:, 1])(jnp.ones((3, 4, 5)))
self.assertEqual(len(jaxpr.jaxpr.eqns), 11)
self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)

# Simple reverses lower to lax.rev_p
jaxpr = jax.make_jaxpr(lambda x: x[:, ::-1])(jnp.ones((3, 4)))
print(jaxpr)
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.rev_p)


def testTrivialGatherIsntGenerated(self):
# https://github.com/google/jax/issues/1621
jaxpr = jax.make_jaxpr(lambda x: x[:, None])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
self.assertNotIn('gather', str(jaxpr))

jaxpr = jax.make_jaxpr(lambda x: x[0:6:1])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
self.assertEqual(len(jaxpr.jaxpr.eqns), 6)
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.dynamic_slice_p)

jaxpr = jax.make_jaxpr(lambda x: x[:4])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
self.assertEqual(len(jaxpr.jaxpr.eqns), 6)
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.dynamic_slice_p)

jaxpr = jax.make_jaxpr(lambda x: x[::-1])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
Expand Down

0 comments on commit a963919

Please sign in to comment.