Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

jax.numpy indexing: lower to dynamic_slice for more cases #15377

Merged
merged 1 commit into from
Apr 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 101 additions & 29 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@

from jax._src import api_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src.api_util import _ensure_index_tuple
from jax._src.array import ArrayImpl
from jax._src.core import ShapedArray, ConcreteArray
from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
_sort_le_comparator, PrecisionLike)
Expand All @@ -58,11 +60,11 @@
from jax._src.numpy import ufuncs
from jax._src.numpy import util
from jax._src.numpy.vectorize import vectorize
from jax._src.ops import scatter
from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
from jax._src.util import (unzip2, subvals, safe_zip,
from jax._src.util import (unzip2, unzip3, subvals, safe_zip,
ceil_of_ratio, partition_list,
canonicalize_axis as _canonicalize_axis)
from jax._src.array import ArrayImpl

newaxis = None
T = TypeVar('T')
Expand Down Expand Up @@ -3968,39 +3970,109 @@ def replace(tup, val):

### Indexing

def _is_integer_index(idx: Any) -> bool:
return isinstance(idx, (int, np.integer)) and not isinstance(idx, (bool, np.bool_))

def _is_simple_reverse_slice(idx: Any) -> bool:
return (isinstance(idx, slice) and
idx.start is idx.stop is None and
isinstance(idx.step, int) and idx.step == -1)

def _is_valid_integer_index_for_slice(idx, size, mode):
if size == 0:
return False
if _is_integer_index(idx):
return -size <= idx < size
try:
shape, dtype = np.shape(idx), _dtype(idx)
except:
return False
if shape == () and np.issubdtype(dtype, np.integer):
# For dynamic integer indices, dynamic_slice semantics require index clipping:
return mode in [None, 'promise_inbounds', 'clip']
return False

def _is_contiguous_slice(idx):
return (isinstance(idx, slice) and
(idx.start is None or _is_integer_index(idx.start)) and
(idx.stop is None or _is_integer_index(idx.stop)) and
(idx.step is None or (_is_integer_index(idx.step) and idx.step == 1)))

def _attempt_rewriting_take_via_slice(arr: Array, idx: Any,
                                      mode: Optional[str]) -> Optional[Array]:
  """Try to compute ``arr[idx]`` using only rev, dynamic_slice and squeeze.

  Every entry of ``idx`` must be a ``[::-1]`` slice, a unit-step slice, or a
  scalar integer accepted by ``_is_valid_integer_index_for_slice``. When that
  does not hold (or one of the guards below trips) None is returned so the
  caller can use a different strategy.
  """
  index = idx if isinstance(idx, tuple) else (idx,)

  # Only static integer shapes with at most ``ndim`` index entries are handled.
  if not _all(isinstance(dim, int) for dim in arr.shape) or len(index) > arr.ndim:
    return None
  if _any(entry is None for entry in index):
    return None  # TODO(jakevdp): handle newaxis case

  # Classify every index entry by the axis position it applies to.
  rev_axes = set()
  int_axes = set()
  slice_axes = set()
  for axis, (entry, extent) in enumerate(zip(index, arr.shape)):
    if _is_simple_reverse_slice(entry):
      rev_axes.add(axis)
    if _is_valid_integer_index_for_slice(entry, extent, mode):
      int_axes.add(axis)
    if _is_contiguous_slice(entry):
      slice_axes.add(axis)

  # For sharded inputs, indexing (like x[0]) and partial slices (like x[:2] as
  # opposed to x[:]) lead to incorrect sharding semantics when computed via
  # dynamic_slice, so we fall back to gather.
  # TODO(yashkatariya): fix dynamic_slice with sharding
  sharded = (isinstance(arr, ArrayImpl) and
             not dispatch.is_single_device_sharding(arr.sharding))
  partial = _any(index[axis].indices(arr.shape[axis]) != (0, arr.shape[axis], 1)
                 for axis in slice_axes)
  needs_slicing = bool(int_axes) or partial
  if sharded and needs_slicing:
    return None

  # Any entry that fell into none of the three categories is unsupported.
  if len(rev_axes) + len(int_axes) + len(slice_axes) != len(index):
    return None

  if rev_axes:
    # Apply the reversals up front; afterwards those axes are full slices.
    arr = lax.rev(arr, tuple(rev_axes))
    index = tuple(slice(None) if axis in rev_axes else entry
                  for axis, entry in enumerate(index))
    slice_axes |= rev_axes

  if not needs_slicing:
    return arr

  # Pad with full slices so there is exactly one entry per array axis.
  index = index + (arr.ndim - len(index)) * (slice(None),)
  starts: Sequence[ArrayLike] = []
  sizes: Sequence[int] = []
  for entry, extent in safe_zip(index, arr.shape):
    if not isinstance(entry, slice):
      assert np.issubdtype(_dtype(entry), np.integer)  # checked above
      assert np.shape(entry) == ()  # checked above
      starts.append(entry)
      sizes.append(1)
      continue
    lo, hi, stride = entry.indices(extent)
    assert stride == 1  # checked above
    starts.append(lo)
    sizes.append(_max(0, hi - lo))

  # We must be careful with dtypes because dynamic_slice requires all
  # start indices to have matching types.
  if len(starts) > 1:
    starts = util.promote_dtypes(*starts)
  result = lax.dynamic_slice(arr, start_indices=starts, slice_sizes=sizes)
  # Integer-indexed axes were kept with length 1; drop them now.
  return lax.squeeze(result, tuple(int_axes)) if int_axes else result


def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
mode=None, fill_value=None):
# Computes arr[idx].
# All supported cases of indexing can be implemented as an XLA gather,
# followed by an optional reverse and broadcast_in_dim.

# Handle some special cases, falling back if error messages might differ.
if (arr.ndim > 0 and isinstance(idx, (int, np.integer)) and
not isinstance(idx, (bool, np.bool_)) and isinstance(arr.shape[0], int)):
if 0 <= idx < arr.shape[0]:
# Use dynamic rather than static index here to avoid slow repeated execution:
# See https://github.com/google/jax/issues/12198
return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
if (arr.ndim > 0 and isinstance(arr.shape[0], int) and
isinstance(idx, slice) and
(type(idx.start) is int or idx.start is None) and
(type(idx.stop) is int or idx.stop is None) and
(type(idx.step) is int or idx.step is None)):
n = arr.shape[0]
start = idx.start if idx.start is not None else 0
stop = idx.stop if idx.stop is not None else n
step = idx.step if idx.step is not None else 1
if (0 <= start < n and 0 <= stop <= n and 0 < step and
(start, stop, step) != (0, n, 1)):
if _any(isinstance(d, core.Tracer) for d in arr.shape[1:]):
if step == 1: # TODO(mattjj, sharadmv): handle step != 1
return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0)
elif step == 1:
# Use dynamic rather than static slice here to avoid slow repeated execution:
# See https://github.com/google/jax/issues/12198
return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0)
else:
return lax.slice_in_dim(arr, start, stop, step)
# For simplicity of generated primitives, we call lax.dynamic_slice in the
# simplest cases: i.e. non-dynamic arrays indexed with integers and slices.

if (result := _attempt_rewriting_take_via_slice(arr, idx, mode)) is not None:
return result

# TODO(mattjj,dougalm): expand dynamic shape indexing support
if jax.config.jax_dynamic_shapes and arr.ndim > 0:
Expand Down
6 changes: 1 addition & 5 deletions tests/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,11 +505,7 @@ def _check(out, inp, shard_shape):
_check(arr[0:6:1], np_inp[0:6:1], (2, 2, 1))
_check(arr[:4], np_inp[:4], (2, 2, 1))
_check(arr[::-1], np_inp[::-1], (2, 2, 1))

# TODO(yashkatariya): This returns a replicated output because the int
# indexing in `_rewriting_take` goes via `dynamic_index_in_dim` rather than
# `_gather`.
# _check(arr[1], np_inp[1], (2, 1))
_check(arr[1], np_inp[1], (2, 1))

def test_array_getitem_replicated_multi_device(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
Expand Down
37 changes: 37 additions & 0 deletions tests/lax_numpy_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,32 @@ def testJVPOfGradOfIndexing(self):
self.assertAllClose(expected, primals)
self.assertAllClose(np.zeros_like(x), tangents)

def testSimpleIndexingUsesSlice(self):
  """Static int/slice indexing should lower to dynamic_slice (+ squeeze)."""
  jaxpr = jax.make_jaxpr(lambda x: x[:2, :2])(jnp.ones((3, 4)))
  self.assertEqual(len(jaxpr.jaxpr.eqns), 7)
  self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.dynamic_slice_p)

  jaxpr = jax.make_jaxpr(lambda x: x[0, :2, 1])(jnp.ones((3, 4, 5)))
  self.assertEqual(len(jaxpr.jaxpr.eqns), 11)
  self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
  self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)

  jaxpr = jax.make_jaxpr(lambda x: x[0, 0])(jnp.ones((3, 4, 5)))
  self.assertEqual(len(jaxpr.jaxpr.eqns), 11)
  self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
  self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)

  jaxpr = jax.make_jaxpr(lambda x: x[:, 1])(jnp.ones((3, 4, 5)))
  self.assertEqual(len(jaxpr.jaxpr.eqns), 11)
  self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
  self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)

  # Simple reverses lower to lax.rev_p
  # (The leftover debugging print of the jaxpr was removed here.)
  jaxpr = jax.make_jaxpr(lambda x: x[:, ::-1])(jnp.ones((3, 4)))
  self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
  self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.rev_p)

def testTrivialGatherIsntGenerated(self):
# https://github.com/google/jax/issues/1621
jaxpr = jax.make_jaxpr(lambda x: x[:, None])(np.arange(4))
Expand All @@ -862,13 +888,24 @@ def testTrivialGatherIsntGenerated(self):

jaxpr = jax.make_jaxpr(lambda x: x[0:6:1])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)

jaxpr = jax.make_jaxpr(lambda x: x[:4])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 0)

jaxpr = jax.make_jaxpr(lambda x: x[::-1])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.rev_p)

def testOOBEmptySlice(self):
  """Slices whose bounds select nothing must produce empty results."""
  vec = jnp.arange(4, dtype='float32')
  empty_vec_slices = (slice(1, 0), slice(-2, -10), slice(5, 10))
  for sl in empty_vec_slices:
    self.assertArraysEqual(vec[sl], jnp.empty(0, dtype='float32'))

  mat = jnp.arange(6, dtype='float32').reshape(2, 3)
  empty_mat_cases = (
      ((slice(-1, -4),), (0, 3)),
      ((slice(None), slice(3, 2)), (2, 0)),
  )
  for index, expected_shape in empty_mat_cases:
    # A 1-tuple index is unwrapped so the first case is exactly mat[-1:-4].
    key = index[0] if len(index) == 1 else index
    self.assertArraysEqual(mat[key],
                           jnp.empty(expected_shape, dtype='float32'))

def testIndexingEmptyDimension(self):
# Issue 2671: XLA error when indexing into dimension of size 0
x = jnp.ones((2, 0))
Expand Down