From 46d8c7ff251a74d741b0c3ea52664f21defa4886 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Tue, 4 Apr 2023 14:52:03 -0700
Subject: [PATCH] jax.numpy indexing: lower to dynamic_slice for more cases

For arrays with static shapes, indices made up only of scalar integers,
unit-step slices and plain [::-1] reversals are now handled by the new
_attempt_rewriting_take_via_slice() helper, which emits lax.rev,
lax.dynamic_slice and lax.squeeze. Every other case (newaxis, too many
indices, sharded inputs with partial slices, ...) returns None from the
helper and continues through the existing _rewriting_take path.

Review fixes folded into this patch:
- clamp slice sizes to zero so that empty slices such as x[3:1] do not
  produce a negative slice size;
- catch Exception rather than using a bare except when probing the shape
  and dtype of an index;
- drop a leftover debugging print from testSimpleIndexingUsesSlice.
---
 jax/_src/numpy/lax_numpy.py      | 127 ++++++++++++++++++++++++-------
 tests/lax_numpy_indexing_test.py |  27 +++++++
 2 files changed, 125 insertions(+), 29 deletions(-)

diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 55d6f6aa6d25..aa2b42c7ae9d 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -47,8 +47,10 @@
 
 from jax._src import api_util
 from jax._src import core
+from jax._src import dispatch
 from jax._src import dtypes
 from jax._src.api_util import _ensure_index_tuple
+from jax._src.array import ArrayImpl
 from jax._src.core import ShapedArray, ConcreteArray
 from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
                               _sort_le_comparator, PrecisionLike)
@@ -58,11 +60,11 @@
 from jax._src.numpy import ufuncs
 from jax._src.numpy import util
 from jax._src.numpy.vectorize import vectorize
+from jax._src.ops import scatter
 from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
-from jax._src.util import (unzip2, subvals, safe_zip,
+from jax._src.util import (unzip2, unzip3, subvals, safe_zip,
                            ceil_of_ratio, partition_list,
                            canonicalize_axis as _canonicalize_axis)
-from jax._src.array import ArrayImpl
 
 newaxis = None
 T = TypeVar('T')
@@ -3936,39 +3938,106 @@ def replace(tup, val):
 
 ### Indexing
 
+def _is_integer_index(idx: Any) -> bool:
+  return isinstance(idx, (int, np.integer)) and not isinstance(idx, (bool, np.bool_))
+
+def _is_simple_reverse_slice(idx: Any) -> bool:
+  return (isinstance(idx, slice) and
+          idx.start is idx.stop is None and
+          isinstance(idx.step, int) and idx.step == -1)
+
+def _is_valid_integer_index_for_slice(idx, size, mode):
+  if size == 0:
+    return False
+  if _is_integer_index(idx):
+    return -size <= idx < size
+  try:
+    shape, dtype = np.shape(idx), _dtype(idx)
+  except Exception:
+    return False
+  if shape == () and np.issubdtype(dtype, np.integer):
+    # For dynamic integer indices, dynamic_slice semantics require index clipping:
+    return mode in [None, 'promise_inbounds', 'clip']
+  return False
+
+def _is_contiguous_slice(idx):
+  return (isinstance(idx, slice) and
+          (idx.start is None or _is_integer_index(idx.start)) and
+          (idx.stop is None or _is_integer_index(idx.stop)) and
+          (idx.step is None or _is_integer_index(idx.step) and idx.step == 1))
+
+def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: Optional[str]) -> Optional[Array]:
+  """Try to compute arr[idx] via lax.rev/lax.dynamic_slice; return None if not possible."""
+  idx = idx if isinstance(idx, tuple) else (idx,)
+
+  if not _all(isinstance(i, int) for i in arr.shape):
+    return None
+  if len(idx) > arr.ndim:
+    return None
+  if _any(i is None for i in idx):
+    return None  # TODO(jakevdp): handle newaxis case
+
+  simple_revs = {i for i, ind in enumerate(idx) if _is_simple_reverse_slice(ind)}
+  int_indices = {i for i, (ind, size) in enumerate(zip(idx, arr.shape))
+                 if _is_valid_integer_index_for_slice(ind, size, mode)}
+  contiguous_slices = {i for i, ind in enumerate(idx) if _is_contiguous_slice(ind)}
+
+  # For sharded inputs, partial slices lead to incorrect sharding semantics, so
+  # we fall back to gather
+  # TODO(yashkatariya): fix dynamic_slice with sharding
+  is_sharded = isinstance(arr, ArrayImpl) and not dispatch.is_single_device_sharding(arr.sharding)
+  all_empty_slices = _all(idx[i].indices(arr.shape[i]) == (0, arr.shape[i], 1)
+                          for i in contiguous_slices)
+  if is_sharded and (int_indices or not all_empty_slices):
+    return None
+
+  if len(simple_revs) + len(int_indices) + len(contiguous_slices) != len(idx):
+    return None
+
+  if simple_revs:
+    arr = lax.rev(arr, tuple(simple_revs))
+    idx = tuple(slice(None) if i in simple_revs else ind
+                for i, ind in enumerate(idx))
+    contiguous_slices |= simple_revs
+
+  if not int_indices and all_empty_slices:
+    return arr
+
+  idx += (arr.ndim - len(idx)) * (slice(None),)
+  start_indices: Sequence[ArrayLike] = []
+  slice_sizes: Sequence[int] = []
+
+  for ind, size in safe_zip(idx, arr.shape):
+    if isinstance(ind, slice):
+      start, stop, step = ind.indices(size)
+      assert step == 1  # checked above
+      start_indices.append(start)
+      slice_sizes.append(_max(0, stop - start))
+    else:
+      assert np.issubdtype(_dtype(ind), np.integer)  # checked above
+      assert np.shape(ind) == ()  # checked above
+      start_indices.append(ind)
+      slice_sizes.append(1)
+  # We must be careful with dtypes because dynamic_slice requires all
+  # start indices to have matching types.
+  start_indices = util.promote_dtypes(*start_indices)
+  arr = lax.dynamic_slice(arr, start_indices=start_indices, slice_sizes=slice_sizes)
+  if int_indices:
+    arr = lax.squeeze(arr, tuple(int_indices))
+  return arr
+
+
 def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
                     mode=None, fill_value=None):
   # Computes arr[idx].
   # All supported cases of indexing can be implemented as an XLA gather,
   # followed by an optional reverse and broadcast_in_dim.
 
-  # Handle some special cases, falling back if error messages might differ.
-  if (arr.ndim > 0 and isinstance(idx, (int, np.integer)) and
-      not isinstance(idx, (bool, np.bool_)) and isinstance(arr.shape[0], int)):
-    if 0 <= idx < arr.shape[0]:
-      # Use dynamic rather than static index here to avoid slow repeated execution:
-      # See https://github.com/google/jax/issues/12198
-      return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
-  if (arr.ndim > 0 and isinstance(arr.shape[0], int) and
-      isinstance(idx, slice) and
-      (type(idx.start) is int or idx.start is None) and
-      (type(idx.stop) is int or idx.stop is None) and
-      (type(idx.step) is int or idx.step is None)):
-    n = arr.shape[0]
-    start = idx.start if idx.start is not None else 0
-    stop = idx.stop if idx.stop is not None else n
-    step = idx.step if idx.step is not None else 1
-    if (0 <= start < n and 0 <= stop <= n and 0 < step and
-        (start, stop, step) != (0, n, 1)):
-      if _any(isinstance(d, core.Tracer) for d in arr.shape[1:]):
-        if step == 1:  # TODO(mattjj, sharadmv): handle step != 1
-          return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0)
-      elif step == 1:
-        # Use dynamic rather than static slice here to avoid slow repeated execution:
-        # See https://github.com/google/jax/issues/12198
-        return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0)
-      else:
-        return lax.slice_in_dim(arr, start, stop, step)
+  # For simplicity of generated primitives, we call lax.dynamic_slice in the
+  # simplest cases: i.e. non-dynamic arrays indexed with integers and slices.
+
+  if (result := _attempt_rewriting_take_via_slice(arr, idx, mode)) is not None:
+    return result
 
   # TODO(mattjj,dougalm): expand dynamic shape indexing support
   if jax.config.jax_dynamic_shapes and arr.ndim > 0:
diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py
index 2e3bc3a917ad..292f03c752e0 100644
--- a/tests/lax_numpy_indexing_test.py
+++ b/tests/lax_numpy_indexing_test.py
@@ -854,6 +854,32 @@ def testJVPOfGradOfIndexing(self):
     self.assertAllClose(expected, primals)
     self.assertAllClose(np.zeros_like(x), tangents)
 
+  def testSimpleIndexingUsesSlice(self):
+    jaxpr = jax.make_jaxpr(lambda x: x[:2, :2])(jnp.ones((3, 4)))
+    self.assertEqual(len(jaxpr.jaxpr.eqns), 7)
+    self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.dynamic_slice_p)
+
+    jaxpr = jax.make_jaxpr(lambda x: x[0, :2, 1])(jnp.ones((3, 4, 5)))
+    self.assertEqual(len(jaxpr.jaxpr.eqns), 11)
+    self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
+    self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)
+
+    jaxpr = jax.make_jaxpr(lambda x: x[0, 0])(jnp.ones((3, 4, 5)))
+    self.assertEqual(len(jaxpr.jaxpr.eqns), 11)
+    self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
+    self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)
+
+    jaxpr = jax.make_jaxpr(lambda x: x[:, 1])(jnp.ones((3, 4, 5)))
+    self.assertEqual(len(jaxpr.jaxpr.eqns), 11)
+    self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.dynamic_slice_p)
+    self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)
+
+    # Simple reverses lower to lax.rev_p
+    jaxpr = jax.make_jaxpr(lambda x: x[:, ::-1])(jnp.ones((3, 4)))
+    self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
+    self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.rev_p)
+
+
   def testTrivialGatherIsntGenerated(self):
     # https://github.com/google/jax/issues/1621
     jaxpr = jax.make_jaxpr(lambda x: x[:, None])(np.arange(4))
@@ -862,6 +888,7 @@ def testTrivialGatherIsntGenerated(self):
 
     jaxpr = jax.make_jaxpr(lambda x: x[0:6:1])(np.arange(4))
     self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
+
     jaxpr = jax.make_jaxpr(lambda x: x[:4])(np.arange(4))
     self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
 