Skip to content

Commit

Permalink
Lower scatter to dynamic_update_slice
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Mar 31, 2023
1 parent 4bca098 commit a8ad117
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 0 deletions.
6 changes: 6 additions & 0 deletions jax/_src/numpy/array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,12 @@ def set(self, values, *, indices_are_sorted=False, unique_indices=False,
See :mod:`jax.ops` for details.
"""
result = scatter._try_scatter_update_via_dynamic_slice(
self.array, self.index, values,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)
if result is not None:
return result
return scatter._scatter_update(self.array, self.index, values, lax.scatter,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)
Expand Down
55 changes: 55 additions & 0 deletions jax/_src/ops/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,61 @@
Numeric = Union[Array, Scalar]


def _is_integer_index(idx: Any) -> bool:
return isinstance(idx, (int, np.integer)) and not isinstance(idx, (bool, np.bool_))

def _is_inbound_integer_index(idx: Any, size: int) -> bool:
return _is_integer_index(idx) and -size <= idx < size

def _is_simple_slice(idx: Any) -> bool:
return (isinstance(idx, slice) and
(idx.start is None or _is_integer_index(idx.start)) and
(idx.stop is None or _is_integer_index(idx.stop)) and
(idx.step is None or idx.step == 1))

def _try_scatter_update_via_dynamic_slice(
x, idx, y, indices_are_sorted,
unique_indices, mode=None, normalize_indices=True):
x = jnp.asarray(x)
y = jnp.asarray(y)

# attempt to compute _scatter_update via lax.dynamic_update_slice(); return None if not possible.
idx = idx if isinstance(idx, tuple) else (idx,)
if not all(isinstance(i, int) for i in y.shape):
return None
if not all(isinstance(i, int) for i in x.shape):
return None
if len(idx) > x.ndim:
return None
if not all(_is_inbound_integer_index(i, size) or _is_simple_slice(i)
for i, size in zip(idx, x.shape)):
return None
idx = tuple(idx) + (slice(None),) * (x.ndim - len(idx))
expected_shape = tuple(slc.indices(d)[1] - slc.indices(d)[0]
for slc, d in util.safe_zip(idx, x.shape)
if isinstance(slc, slice))
try:
y = jnp.broadcast_to(y, expected_shape)
except ValueError:
return None

new_shape = tuple(slc.indices(d)[1] - slc.indices(d)[0] if isinstance(slc, slice) else 1
for slc, d in util.safe_zip(idx, x.shape))
broadcast_indices = tuple(i for i, slc in enumerate(idx) if isinstance(slc, slice))
y = lax.broadcast_in_dim(y, new_shape, broadcast_indices)
start_indices = tuple(slc.indices(d)[0] if isinstance(slc, slice) else slc
for slc, d in util.safe_zip(idx, x.shape))
dtype = lax.dtype(x)
weak_type = dtypes.is_weakly_typed(x)
if dtype != lax.dtype(y) and dtype != dtypes.result_type(x, y):
# TODO(jakevdp): change this to an error after the deprecation period.
warnings.warn("scatter inputs have incompatible types: cannot safely cast "
f"value from dtype={lax.dtype(y)} to dtype={lax.dtype(x)}. "
"In future JAX releases this will result in an error.",
FutureWarning)
out = lax.dynamic_update_slice(x, y.astype(dtype), start_indices)
return lax_internal._convert_element_type(out, dtype, weak_type)

def _scatter_update(x, idx, y, scatter_op, indices_are_sorted,
unique_indices, mode=None, normalize_indices=True):
"""Helper for indexed updates.
Expand Down

0 comments on commit a8ad117

Please sign in to comment.