From dca23d4d8f6fea58ce258900f9eb4dfbe20efbeb Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Mon, 10 Apr 2023 12:02:59 -0700
Subject: [PATCH] jax.numpy indexing: lower to dynamic_slice for more cases

---
 jax/_src/numpy/lax_numpy.py      | 139 ++++++++++++++++++++++++++-------
 tests/array_test.py              |   6 +-
 tests/lax_numpy_indexing_test.py |  36 +++++++++
 3 files changed, 147 insertions(+), 34 deletions(-)

diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 140410144e41..8bea9baf7c29 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -47,8 +47,10 @@
 
 from jax._src import api_util
 from jax._src import core
+from jax._src import dispatch
 from jax._src import dtypes
 from jax._src.api_util import _ensure_index_tuple
+from jax._src.array import ArrayImpl
 from jax._src.core import ShapedArray, ConcreteArray
 from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
                               _sort_le_comparator, PrecisionLike)
@@ -58,11 +60,11 @@
 from jax._src.numpy import ufuncs
 from jax._src.numpy import util
 from jax._src.numpy.vectorize import vectorize
+from jax._src.ops import scatter
 from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
-from jax._src.util import (unzip2, subvals, safe_zip,
+from jax._src.util import (unzip2, unzip3, subvals, safe_zip,
                            ceil_of_ratio, partition_list,
                            canonicalize_axis as _canonicalize_axis)
-from jax._src.array import ArrayImpl
 
 newaxis = None
 T = TypeVar('T')
@@ -3968,39 +3970,118 @@ def replace(tup, val):
 
 ### Indexing
 
+def _is_integer_index(idx: Any) -> bool:
+  """Return True if idx is a Python or NumPy integer scalar (booleans excluded)."""
+  return isinstance(idx, (int, np.integer)) and not isinstance(idx, (bool, np.bool_))
+
+def _is_simple_reverse_slice(idx: Any) -> bool:
+  """Return True if idx is a slice with no start/stop and an integer step of -1."""
+  return (isinstance(idx, slice) and
+          idx.start is idx.stop is None and
+          isinstance(idx.step, int) and idx.step == -1)
+
+def _is_valid_integer_index_for_slice(idx, size, mode):
+  """Return True if idx may index an axis of length size via dynamic_slice."""
+  if size == 0:
+    return False
+  if _is_integer_index(idx):
+    return -size <= idx < size
+  try:
+    shape, dtype = np.shape(idx), _dtype(idx)
+  except Exception:  # idx has no usable shape/dtype, so it is not an integer index
+    return False
+  if shape == () and np.issubdtype(dtype, np.integer):
+    # For dynamic integer indices, dynamic_slice semantics require index clipping:
+    return mode in [None, 'promise_inbounds', 'clip']
+  return False
+
+def _is_contiguous_slice(idx):
+  """Return True if idx is a slice with integer-or-None bounds and unit step."""
+  return (isinstance(idx, slice) and
+          (idx.start is None or _is_integer_index(idx.start)) and
+          (idx.stop is None or _is_integer_index(idx.stop)) and
+          (idx.step is None or (_is_integer_index(idx.step) and idx.step == 1)))
+
+def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: Optional[str]) -> Optional[Array]:
+  """Attempt to compute ``arr[idx]`` with ``lax.rev`` and ``lax.dynamic_slice``.
+
+  Returns None when this fast path does not apply, e.g. if ``idx`` contains
+  anything other than valid integers, unit-step slices and ``::-1`` reversals,
+  or if ``arr`` is sharded across devices and a partial slice is requested.
+  """
+  idx = idx if isinstance(idx, tuple) else (idx,)
+
+  if not _all(isinstance(i, int) for i in arr.shape):
+    return None
+  if len(idx) > arr.ndim:
+    return None
+  if _any(i is None for i in idx):
+    return None  # TODO(jakevdp): handle newaxis case
+
+  simple_revs = {i for i, ind in enumerate(idx) if _is_simple_reverse_slice(ind)}
+  int_indices = {i for i, (ind, size) in enumerate(zip(idx, arr.shape))
+                 if _is_valid_integer_index_for_slice(ind, size, mode)}
+  contiguous_slices = {i for i, ind in enumerate(idx) if _is_contiguous_slice(ind)}
+
+  # For sharded inputs, indexing (like x[0]) and partial slices (like x[:2] as
+  # opposed to x[:]) lead to incorrect sharding semantics when computed via
+  # dynamic_slice, so we fall back to gather.
+  # TODO(yashkatariya): fix dynamic_slice with sharding
+  is_sharded = (isinstance(arr, ArrayImpl) and
+                not dispatch.is_single_device_sharding(arr.sharding))
+  has_partial_slices = _any(idx[i].indices(arr.shape[i]) != (0, arr.shape[i], 1)
+                            for i in contiguous_slices)
+  if is_sharded and (int_indices or has_partial_slices):
+    return None
+
+  if len(simple_revs) + len(int_indices) + len(contiguous_slices) != len(idx):
+    return None
+
+  if simple_revs:
+    arr = lax.rev(arr, tuple(simple_revs))
+    idx = tuple(slice(None) if i in simple_revs else ind
+                for i, ind in enumerate(idx))
+    contiguous_slices |= simple_revs
+
+  if not (int_indices or has_partial_slices):
+    return arr
+
+  idx += (arr.ndim - len(idx)) * (slice(None),)
+  start_indices: Sequence[ArrayLike] = []
+  slice_sizes: Sequence[int] = []
+
+  for ind, size in safe_zip(idx, arr.shape):
+    if isinstance(ind, slice):
+      start, stop, step = ind.indices(size)
+      assert step == 1  # checked above
+      start_indices.append(start)
+      slice_sizes.append(_max(0, stop - start))
+    else:
+      assert np.issubdtype(_dtype(ind), np.integer)  # checked above
+      assert np.shape(ind) == ()  # checked above
+      start_indices.append(ind)
+      slice_sizes.append(1)
+  # We must be careful with dtypes because dynamic_slice requires all
+  # start indices to have matching types.
+  if len(start_indices) > 1:
+    start_indices = util.promote_dtypes(*start_indices)
+  arr = lax.dynamic_slice(arr, start_indices=start_indices, slice_sizes=slice_sizes)
+  if int_indices:
+    arr = lax.squeeze(arr, tuple(int_indices))
+  return arr
+
+
 def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
                     mode=None, fill_value=None):
   # Computes arr[idx].
   # All supported cases of indexing can be implemented as an XLA gather,
   # followed by an optional reverse and broadcast_in_dim.
 
-  # Handle some special cases, falling back if error messages might differ.
-  if (arr.ndim > 0 and isinstance(idx, (int, np.integer)) and
-      not isinstance(idx, (bool, np.bool_)) and isinstance(arr.shape[0], int)):
-    if 0 <= idx < arr.shape[0]:
-      # Use dynamic rather than static index here to avoid slow repeated execution:
-      # See https://github.com/google/jax/issues/12198
-      return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
-  if (arr.ndim > 0 and isinstance(arr.shape[0], int) and
-      isinstance(idx, slice) and
-      (type(idx.start) is int or idx.start is None) and
-      (type(idx.stop) is int or idx.stop is None) and
-      (type(idx.step) is int or idx.step is None)):
-    n = arr.shape[0]
-    start = idx.start if idx.start is not None else 0
-    stop = idx.stop if idx.stop is not None else n
-    step = idx.step if idx.step is not None else 1
-    if (0 <= start < n and 0 <= stop <= n and 0 < step and
-        (start, stop, step) != (0, n, 1)):
-      if _any(isinstance(d, core.Tracer) for d in arr.shape[1:]):
-        if step == 1:  # TODO(mattjj, sharadmv): handle step != 1
-          return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0)
-      elif step == 1:
-        # Use dynamic rather than static slice here to avoid slow repeated execution:
-        # See https://github.com/google/jax/issues/12198
-        return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0)
-      else:
-        return lax.slice_in_dim(arr, start, stop, step)
+  # For simplicity of generated primitives, we call lax.dynamic_slice in the
+  # simplest cases: i.e. non-dynamic arrays indexed with integers and slices.
+
+  if (result := _attempt_rewriting_take_via_slice(arr, idx, mode)) is not None:
+    return result
 
   # TODO(mattjj,dougalm): expand dynamic shape indexing support
   if jax.config.jax_dynamic_shapes and arr.ndim > 0:
diff --git a/tests/array_test.py b/tests/array_test.py
index 9ea5596d259e..adc663c597f0 100644
--- a/tests/array_test.py
+++ b/tests/array_test.py
@@ -505,11 +505,7 @@ def _check(out, inp, shard_shape):
     _check(arr[0:6:1], np_inp[0:6:1], (2, 2, 1))
     _check(arr[:4], np_inp[:4], (2, 2, 1))
     _check(arr[::-1], np_inp[::-1], (2, 2, 1))
-
-    # TODO(yashkatariya): This returns a replicated output because the int
-    # indexing in `_rewriting_take` goes via `dynamic_index_in_dim` rather than
-    # `_gather`.
-    # _check(arr[1], np_inp[1], (2, 1))
+    _check(arr[1], np_inp[1], (2, 1))
 
   def test_array_getitem_replicated_multi_device(self):
     global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py
index a2c3c5bd854e..b53ae726e752 100644
--- a/tests/lax_numpy_indexing_test.py
+++ b/tests/lax_numpy_indexing_test.py
@@ -854,6 +854,31 @@ def testJVPOfGradOfIndexing(self):
     self.assertAllClose(expected, primals)
     self.assertAllClose(np.zeros_like(x), tangents)
 
+  def testSimpleIndexingUsesSlice(self):
+    jaxpr = jax.make_jaxpr(lambda x: x[:2, :2])(jnp.ones((3, 4)))
+    self.assertEqual(len(jaxpr.jaxpr.eqns), 7)
+    self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.dynamic_slice_p)
+
+    jaxpr = jax.make_jaxpr(lambda x: x[0, :2, 1])(jnp.ones((3, 4, 5)))
+    self.assertEqual(len(jaxpr.jaxpr.eqns), 11)
+    self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
+    self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)
+
+    jaxpr = jax.make_jaxpr(lambda x: x[0, 0])(jnp.ones((3, 4, 5)))
+    self.assertEqual(len(jaxpr.jaxpr.eqns), 11)
+    self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
+    self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)
+
+    jaxpr = jax.make_jaxpr(lambda x: x[:, 1])(jnp.ones((3, 4, 5)))
+    self.assertEqual(len(jaxpr.jaxpr.eqns), 11)
+    self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
+    self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)
+
+    # Simple reverses lower to lax.rev_p
+    jaxpr = jax.make_jaxpr(lambda x: x[:, ::-1])(jnp.ones((3, 4)))
+    self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
+    self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.rev_p)
+
   def testTrivialGatherIsntGenerated(self):
     # https://github.com/google/jax/issues/1621
     jaxpr = jax.make_jaxpr(lambda x: x[:, None])(np.arange(4))
@@ -862,6 +887,7 @@ def testTrivialGatherIsntGenerated(self):
 
     jaxpr = jax.make_jaxpr(lambda x: x[0:6:1])(np.arange(4))
     self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
+
     jaxpr = jax.make_jaxpr(lambda x: x[:4])(np.arange(4))
     self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
 
@@ -869,6 +895,16 @@ def testTrivialGatherIsntGenerated(self):
     self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
     self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.rev_p)
 
+  def testOOBEmptySlice(self):
+    x = jnp.arange(4, dtype='float32')
+    self.assertArraysEqual(x[1:0], jnp.empty(0, dtype='float32'))
+    self.assertArraysEqual(x[-2:-10], jnp.empty(0, dtype='float32'))
+    self.assertArraysEqual(x[5:10], jnp.empty(0, dtype='float32'))
+
+    x = jnp.arange(6, dtype='float32').reshape(2, 3)
+    self.assertArraysEqual(x[-1:-4], jnp.empty((0, 3), dtype='float32'))
+    self.assertArraysEqual(x[:, 3:2], jnp.empty((2, 0), dtype='float32'))
+
   def testIndexingEmptyDimension(self):
     # Issue 2671: XLA error when indexing into dimension of size 0
     x = jnp.ones((2, 0))