From e48e96051f7ea5b50c25daf29333e26fd7e45faf Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Mon, 3 Apr 2023 08:57:32 -0700
Subject: [PATCH] Lower more _rewriting_take cases to lax.slice

---
Reviewer notes (text between the three-dash marker above and the diffstat is
ignored by `git am`):
 * Line breaks and indentation in this copy were reconstructed from a
   whitespace-collapsed paste, using the +/- markers and the hunk line counts.
   Spacing inside removed and context lines may not match the target tree
   byte-for-byte; if the patch is rejected, retry with
   `git apply --ignore-whitespace`.
 * The From: header carries no e-mail address in this copy; it may need to be
   restored before `git am` accepts the patch.
 * The removed special cases deliberately used dynamic_index_in_dim and
   dynamic_slice_in_dim "to avoid slow repeated execution"
   (https://github.com/google/jax/issues/12198); the new path emits a static
   lax.slice instead. NOTE(review): confirm this trade-off is intended.
 * x[0:6:1] and x[:4] on a length-4 array now produce one slice equation
   instead of none; testTrivialGatherIsntGenerated is updated accordingly.

 jax/_src/numpy/lax_numpy.py      | 76 ++++++++++++++++++++------------
 tests/lax_numpy_indexing_test.py | 31 ++++++++++++-
 2 files changed, 76 insertions(+), 31 deletions(-)

diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 26f5602af16b..dd99edfdbf87 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -49,6 +49,7 @@
 from jax._src import core
 from jax._src import dtypes
 from jax._src.api_util import _ensure_index_tuple
+from jax._src.array import ArrayImpl
 from jax._src.core import ShapedArray, ConcreteArray
 from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
                               _sort_le_comparator, PrecisionLike)
@@ -58,11 +59,11 @@
 from jax._src.numpy import ufuncs
 from jax._src.numpy import util
 from jax._src.numpy.vectorize import vectorize
+from jax._src.ops import scatter
 from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
-from jax._src.util import (unzip2, subvals, safe_zip,
+from jax._src.util import (unzip2, unzip3, subvals, safe_zip,
                            ceil_of_ratio, partition_list,
                            canonicalize_axis as _canonicalize_axis)
-from jax._src.array import ArrayImpl
 
 newaxis = None
 T = TypeVar('T')
@@ -3898,39 +3899,56 @@ def replace(tup, val):
 
 ### Indexing
 
+def _is_integer_index(idx: Any) -> bool:
+  return isinstance(idx, (int, np.integer)) and not isinstance(idx, (bool, np.bool_))
+
+def _is_inbound_integer_index(idx: Any, size: int) -> bool:
+  return _is_integer_index(idx) and -size <= idx < size
+
+def _is_nonreversing_static_slice(idx: Any) -> bool:
+  return (isinstance(idx, slice) and
+          _all(i is None or _is_integer_index(i)
+               for i in [idx.start, idx.stop, idx.step]) and
+          (idx.step is None or idx.step > 0))
+
+def _attempt_rewriting_take_via_slice(arr, idx):
+  # attempt to compute _rewriting_take via lax.slice(); return None if not possible.
+  idx = idx if isinstance(idx, tuple) else (idx,)
+  if not _all(isinstance(i, int) for i in arr.shape):
+    return None
+  if len(idx) > arr.ndim:
+    return None
+  if not _all(_is_inbound_integer_index(i, size) or _is_nonreversing_static_slice(i)
+              for i, size in zip(idx, arr.shape)):
+    return None
+  sqeeze_dimensions = [i for i, ind in enumerate(idx) if not isinstance(ind, slice)]
+  idx += (arr.ndim - len(idx)) * (slice(None),)
+  slices = []
+  for ind, size in safe_zip(idx, arr.shape):
+    size = int(size)
+    if isinstance(ind, slice):
+      slices.append(ind.indices(size))
+    else:
+      ind = int(ind)
+      ind = ind + size if ind < 0 else ind
+      slices.append((ind, ind + 1, 1))
+  return lax.squeeze(lax.slice(arr, *unzip3(slices)), sqeeze_dimensions)
+
+
 def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
                     mode=None, fill_value=None):
   # Computes arr[idx].
   # All supported cases of indexing can be implemented as an XLA gather,
   # followed by an optional reverse and broadcast_in_dim.
 
-  # Handle some special cases, falling back if error messages might differ.
-  if (arr.ndim > 0 and isinstance(idx, (int, np.integer)) and
-      not isinstance(idx, (bool, np.bool_)) and isinstance(arr.shape[0], int)):
-    if 0 <= idx < arr.shape[0]:
-      # Use dynamic rather than static index here to avoid slow repeated execution:
-      # See https://github.com/google/jax/issues/12198
-      return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
-  if (arr.ndim > 0 and isinstance(arr.shape[0], int) and
-      isinstance(idx, slice) and
-      (type(idx.start) is int or idx.start is None) and
-      (type(idx.stop) is int or idx.stop is None) and
-      (type(idx.step) is int or idx.step is None)):
-    n = arr.shape[0]
-    start = idx.start if idx.start is not None else 0
-    stop = idx.stop if idx.stop is not None else n
-    step = idx.step if idx.step is not None else 1
-    if (0 <= start < n and 0 <= stop <= n and 0 < step and
-        (start, stop, step) != (0, n, 1)):
-      if _any(isinstance(d, core.Tracer) for d in arr.shape[1:]):
-        if step == 1:  # TODO(mattjj, sharadmv): handle step != 1
-          return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0)
-      elif step == 1:
-        # Use dynamic rather than static slice here to avoid slow repeated execution:
-        # See https://github.com/google/jax/issues/12198
-        return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0)
-      else:
-        return lax.slice_in_dim(arr, start, stop, step)
+  # For simplicity of generated primitives, we call lax.slice in the simplest
+  # cases: i.e. non-dynamic arrays indexed with only integers and slices.
+  # TODO(jakevdp): we could generate slices in other situations as well:
+  # - newaxis -> broadcast + slice
+  # - negative stride -> reverse + slice
+  # - dynamic indices -> dynamic_slice
+  if (result := _attempt_rewriting_take_via_slice(arr, idx)) is not None:
+    return result
 
   # TODO(mattjj,dougalm): expand dynamic shape indexing support
   if jax.config.jax_dynamic_shapes and arr.ndim > 0:
diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py
index 2e3bc3a917ad..3b3e7925ba28 100644
--- a/tests/lax_numpy_indexing_test.py
+++ b/tests/lax_numpy_indexing_test.py
@@ -854,6 +854,30 @@ def testJVPOfGradOfIndexing(self):
     self.assertAllClose(expected, primals)
     self.assertAllClose(np.zeros_like(x), tangents)
 
+  def testSimpleIndexingUsesSlice(self):
+    jaxpr = jax.make_jaxpr(lambda x: x[:2, :2])(jnp.ones((3, 4)))
+    self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
+    self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
+
+    jaxpr = jax.make_jaxpr(lambda x: x[:, 1::2])(jnp.ones((3, 4)))
+    self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
+    self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
+
+    jaxpr = jax.make_jaxpr(lambda x: x[0, :2, 1])(jnp.ones((3, 4, 5)))
+    self.assertEqual(len(jaxpr.jaxpr.eqns), 2)
+    self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
+    self.assertEqual(jaxpr.jaxpr.eqns[1].primitive, lax.squeeze_p)
+
+    jaxpr = jax.make_jaxpr(lambda x: x[0, 0])(jnp.ones((3, 4, 5)))
+    self.assertEqual(len(jaxpr.jaxpr.eqns), 2)
+    self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
+    self.assertEqual(jaxpr.jaxpr.eqns[1].primitive, lax.squeeze_p)
+
+    jaxpr = jax.make_jaxpr(lambda x: x[:, 1])(jnp.ones((3, 4, 5)))
+    self.assertEqual(len(jaxpr.jaxpr.eqns), 2)
+    self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
+    self.assertEqual(jaxpr.jaxpr.eqns[1].primitive, lax.squeeze_p)
+
   def testTrivialGatherIsntGenerated(self):
     # https://github.com/google/jax/issues/1621
     jaxpr = jax.make_jaxpr(lambda x: x[:, None])(np.arange(4))
@@ -861,9 +885,12 @@ def testTrivialGatherIsntGenerated(self):
     self.assertNotIn('gather', str(jaxpr))
 
     jaxpr = jax.make_jaxpr(lambda x: x[0:6:1])(np.arange(4))
-    self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
+    self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
+    self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
+
     jaxpr = jax.make_jaxpr(lambda x: x[:4])(np.arange(4))
-    self.assertEqual(len(jaxpr.jaxpr.eqns), 0)
+    self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
+    self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.slice_p)
 
     jaxpr = jax.make_jaxpr(lambda x: x[::-1])(np.arange(4))
     self.assertEqual(len(jaxpr.jaxpr.eqns), 1)