Skip to content

Commit

Permalink
jax.numpy indexing: lower to dynamic_slice for more cases
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Apr 6, 2023
1 parent b926e04 commit 3eac0c0
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 29 deletions.
127 changes: 98 additions & 29 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@

from jax._src import api_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src.api_util import _ensure_index_tuple
from jax._src.array import ArrayImpl
from jax._src.core import ShapedArray, ConcreteArray
from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
_sort_le_comparator, PrecisionLike)
Expand All @@ -58,11 +60,11 @@
from jax._src.numpy import ufuncs
from jax._src.numpy import util
from jax._src.numpy.vectorize import vectorize
from jax._src.ops import scatter
from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
from jax._src.util import (unzip2, subvals, safe_zip,
from jax._src.util import (unzip2, unzip3, subvals, safe_zip,
ceil_of_ratio, partition_list,
canonicalize_axis as _canonicalize_axis)
from jax._src.array import ArrayImpl

newaxis = None
T = TypeVar('T')
Expand Down Expand Up @@ -3936,39 +3938,106 @@ def replace(tup, val):

### Indexing

def _is_integer_index(idx: Any) -> bool:
return isinstance(idx, (int, np.integer)) and not isinstance(idx, (bool, np.bool_))

def _is_simple_reverse_slice(idx: Any) -> bool:
return (isinstance(idx, slice) and
idx.start is idx.stop is None and
isinstance(idx.step, int) and idx.step == -1)

def _is_valid_integer_index_for_slice(idx, size, mode):
if size == 0:
return False
if _is_integer_index(idx):
return -size <= idx < size
try:
shape, dtype = np.shape(idx), _dtype(idx)
except:
return False
if shape == () and np.issubdtype(dtype, np.integer):
# For dynamic integer indices, dynamic_slice semantics require index clipping:
return mode in [None, 'promise_inbounds', 'clip']
return False

def _is_contiguous_slice(idx):
return (isinstance(idx, slice) and
(idx.start is None or _is_integer_index(idx.start)) and
(idx.stop is None or _is_integer_index(idx.stop)) and
(idx.step is None or _is_integer_index(idx.step) and idx.step == 1))

def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: Optional[str]) -> Optional[Array]:
# attempt to compute _rewriting_take via lax.slice(); return None if not possible.
idx = idx if isinstance(idx, tuple) else (idx,)

if not _all(isinstance(i, int) for i in arr.shape):
return None
if len(idx) > arr.ndim:
return None
if _any(i is None for i in idx):
return None # TODO(jakevdp): handle newaxis case

simple_revs = {i for i, ind in enumerate(idx) if _is_simple_reverse_slice(ind)}
int_indices = {i for i, (ind, size) in enumerate(zip(idx, arr.shape))
if _is_valid_integer_index_for_slice(ind, size, mode)}
contiguous_slices = {i for i, ind in enumerate(idx) if _is_contiguous_slice(ind)}

# For sharded inputs, partial slices lead to incorrect sharding semantics, so
# we fall back to gather
# TODO(yashkatariya): fix dynamic_slice with sharding
is_sharded = isinstance(arr, ArrayImpl) and not dispatch.is_single_device_sharding(arr.sharding)
all_empty_slices = _all(idx[i].indices(arr.shape[i]) == (0, arr.shape[i], 1)
for i in contiguous_slices)
if is_sharded and (int_indices or not all_empty_slices):
return None

if len(simple_revs) + len(int_indices) + len(contiguous_slices) != len(idx):
return None

if simple_revs:
arr = lax.rev(arr, tuple(simple_revs))
idx = tuple(slice(None) if i in simple_revs else ind
for i, ind in enumerate(idx))
contiguous_slices |= simple_revs

if not int_indices and all_empty_slices:
return arr

idx += (arr.ndim - len(idx)) * (slice(None),)
start_indices: Sequence[ArrayLike] = []
slice_sizes: Sequence[int] = []

for ind, size in safe_zip(idx, arr.shape):
if isinstance(ind, slice):
start, stop, step = ind.indices(size)
assert step == 1 # checked above
start_indices.append(start)
slice_sizes.append(_max(0, stop - start))
else:
assert np.issubdtype(_dtype(ind), np.integer) # checked above
assert np.shape(ind) == () # checked above
start_indices.append(ind)
slice_sizes.append(1)
# We must be careful with dtypes because dynamic_slice requires all
# start indices to have matching types.
start_indices = util.promote_dtypes(*start_indices)
arr = lax.dynamic_slice(arr, start_indices=start_indices, slice_sizes=slice_sizes)
if int_indices:
arr = lax.squeeze(arr, tuple(int_indices))
return arr


def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
mode=None, fill_value=None):
# Computes arr[idx].
# All supported cases of indexing can be implemented as an XLA gather,
# followed by an optional reverse and broadcast_in_dim.

# Handle some special cases, falling back if error messages might differ.
if (arr.ndim > 0 and isinstance(idx, (int, np.integer)) and
not isinstance(idx, (bool, np.bool_)) and isinstance(arr.shape[0], int)):
if 0 <= idx < arr.shape[0]:
# Use dynamic rather than static index here to avoid slow repeated execution:
# See https://github.com/google/jax/issues/12198
return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
if (arr.ndim > 0 and isinstance(arr.shape[0], int) and
isinstance(idx, slice) and
(type(idx.start) is int or idx.start is None) and
(type(idx.stop) is int or idx.stop is None) and
(type(idx.step) is int or idx.step is None)):
n = arr.shape[0]
start = idx.start if idx.start is not None else 0
stop = idx.stop if idx.stop is not None else n
step = idx.step if idx.step is not None else 1
if (0 <= start < n and 0 <= stop <= n and 0 < step and
(start, stop, step) != (0, n, 1)):
if _any(isinstance(d, core.Tracer) for d in arr.shape[1:]):
if step == 1: # TODO(mattjj, sharadmv): handle step != 1
return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0)
elif step == 1:
# Use dynamic rather than static slice here to avoid slow repeated execution:
# See https://github.com/google/jax/issues/12198
return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0)
else:
return lax.slice_in_dim(arr, start, stop, step)
# For simplicity of generated primitives, we call lax.dynamic_slice in the
# simplest cases: i.e. non-dynamic arrays indexed with integers and slices.

if (result := _attempt_rewriting_take_via_slice(arr, idx, mode)) is not None:
return result

# TODO(mattjj,dougalm): expand dynamic shape indexing support
if jax.config.jax_dynamic_shapes and arr.ndim > 0:
Expand Down
37 changes: 37 additions & 0 deletions tests/lax_numpy_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,32 @@ def testJVPOfGradOfIndexing(self):
self.assertAllClose(expected, primals)
self.assertAllClose(np.zeros_like(x), tangents)

def testSimpleIndexingUsesSlice(self):
jaxpr = jax.make_jaxpr(lambda x: x[:2, :2])(jnp.ones((3, 4)))
self.assertEqual(len(jaxpr.jaxpr.eqns), 7)
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.dynamic_slice_p)

jaxpr = jax.make_jaxpr(lambda x: x[0, :2, 1])(jnp.ones((3, 4, 5)))
self.assertEqual(len(jaxpr.jaxpr.eqns), 11)
self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)

jaxpr = jax.make_jaxpr(lambda x: x[0, 0])(jnp.ones((3, 4, 5)))
self.assertEqual(len(jaxpr.jaxpr.eqns), 11)
self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)

jaxpr = jax.make_jaxpr(lambda x: x[:, 1])(jnp.ones((3, 4, 5)))
self.assertEqual(len(jaxpr.jaxpr.eqns), 11)
self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)

# Simple reverses lower to lax.rev_p
jaxpr = jax.make_jaxpr(lambda x: x[:, ::-1])(jnp.ones((3, 4)))
print(jaxpr)
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.rev_p)

def testTrivialGatherIsntGenerated(self):
# https://github.com/google/jax/issues/1621
jaxpr = jax.make_jaxpr(lambda x: x[:, None])(np.arange(4))
Expand All @@ -862,13 +888,24 @@ def testTrivialGatherIsntGenerated(self):

jaxpr = jax.make_jaxpr(lambda x: x[0:6:1])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)

jaxpr = jax.make_jaxpr(lambda x: x[:4])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)

jaxpr = jax.make_jaxpr(lambda x: x[::-1])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.rev_p)

def testOOBEmptySlice(self):
x = jnp.arange(4, dtype='float32')
self.assertArraysEqual(x[1:0], jnp.empty(0, dtype='float32'))
self.assertArraysEqual(x[-2:-10], jnp.empty(0, dtype='float32'))
self.assertArraysEqual(x[5:10], jnp.empty(0, dtype='float32'))

x = jnp.arange(6, dtype='float32').reshape(2, 3)
self.assertArraysEqual(x[-1:-4], jnp.empty((0, 3), dtype='float32'))
self.assertArraysEqual(x[:, 3:2], jnp.empty((2, 0), dtype='float32'))

def testIndexingEmptyDimension(self):
# Issue 2671: XLA error when indexing into dimension of size 0
x = jnp.ones((2, 0))
Expand Down

0 comments on commit 3eac0c0

Please sign in to comment.