From a8ad11706e4e89de2a1a0440bf3c398722b2562a Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 31 Mar 2023 15:14:39 -0700 Subject: [PATCH] Lower scatter to dynamic_update_slice --- jax/_src/numpy/array_methods.py | 6 ++++ jax/_src/ops/scatter.py | 55 +++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index dd7b8be010ac..b40fe7e4cc61 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -521,6 +521,12 @@ def set(self, values, *, indices_are_sorted=False, unique_indices=False, See :mod:`jax.ops` for details. """ + result = scatter._try_scatter_update_via_dynamic_slice( + self.array, self.index, values, + indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, mode=mode) + if result is not None: + return result return scatter._scatter_update(self.array, self.index, values, lax.scatter, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode) diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index b3e70793373f..a08153be3ff0 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -42,6 +42,61 @@ Numeric = Union[Array, Scalar] +def _is_integer_index(idx: Any) -> bool: + return isinstance(idx, (int, np.integer)) and not isinstance(idx, (bool, np.bool_)) + +def _is_inbound_integer_index(idx: Any, size: int) -> bool: + return _is_integer_index(idx) and -size <= idx < size + +def _is_simple_slice(idx: Any) -> bool: + return (isinstance(idx, slice) and + (idx.start is None or _is_integer_index(idx.start)) and + (idx.stop is None or _is_integer_index(idx.stop)) and + (idx.step is None or idx.step == 1)) + +def _try_scatter_update_via_dynamic_slice( + x, idx, y, indices_are_sorted, + unique_indices, mode=None, normalize_indices=True): + x = jnp.asarray(x) + y = jnp.asarray(y) + + # attempt to compute _scatter_update via lax.dynamic_update_slice(); return None if not possible. + idx = idx if isinstance(idx, tuple) else (idx,) + if not all(isinstance(i, int) for i in y.shape): + return None + if not all(isinstance(i, int) for i in x.shape): + return None + if len(idx) > x.ndim: + return None + if not all(_is_inbound_integer_index(i, size) or _is_simple_slice(i) + for i, size in zip(idx, x.shape)): + return None + idx = tuple(idx) + (slice(None),) * (x.ndim - len(idx)) + expected_shape = tuple(slc.indices(d)[1] - slc.indices(d)[0] + for slc, d in util.safe_zip(idx, x.shape) + if isinstance(slc, slice)) + try: + y = jnp.broadcast_to(y, expected_shape) + except ValueError: + return None + + new_shape = tuple(slc.indices(d)[1] - slc.indices(d)[0] if isinstance(slc, slice) else 1 + for slc, d in util.safe_zip(idx, x.shape)) + broadcast_indices = tuple(i for i, slc in enumerate(idx) if isinstance(slc, slice)) + y = lax.broadcast_in_dim(y, new_shape, broadcast_indices) + start_indices = tuple(slc.indices(d)[0] if isinstance(slc, slice) else slc + for slc, d in util.safe_zip(idx, x.shape)) + dtype = lax.dtype(x) + weak_type = dtypes.is_weakly_typed(x) + if dtype != lax.dtype(y) and dtype != dtypes.result_type(x, y): + # TODO(jakevdp): change this to an error after the deprecation period. + warnings.warn("scatter inputs have incompatible types: cannot safely cast " + f"value from dtype={lax.dtype(y)} to dtype={lax.dtype(x)}. " + "In future JAX releases this will result in an error.", + FutureWarning) + out = lax.dynamic_update_slice(x, y.astype(dtype), start_indices) + return lax_internal._convert_element_type(out, dtype, weak_type) + def _scatter_update(x, idx, y, scatter_op, indices_are_sorted, unique_indices, mode=None, normalize_indices=True): """Helper for indexed updates.