Skip to content

Commit

Permalink
Merge pull request #15390 from jakevdp:checkify-dynamic-slice
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 521790925
  • Loading branch information
jax authors committed Apr 4, 2023
2 parents 35bfdc6 + 46297dc commit 9bb3d86
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
16 changes: 16 additions & 0 deletions jax/_src/checkify.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,22 @@ def isnan(x):
error_checks[_prim] = functools.partial(nan_error_check, _prim)


def dynamic_slice_error_check(error, enabled_errors, operand, *start_indices, slice_sizes):
out = lax.dynamic_slice_p.bind(operand, *start_indices, slice_sizes=slice_sizes)

if OOBError not in enabled_errors:
return error, out

operand_dims = np.array(operand.shape)
slice_sizes = np.array(slice_sizes)
start_indices = jnp.array(start_indices)
oob_mask = (start_indices < 0) | (start_indices + slice_sizes > operand_dims)

payload = oob_payload(oob_mask, start_indices, range(operand.ndim), operand.shape)
error = assert_func(error, jnp.any(oob_mask), OOBError(summary(), "dynamic_slice", operand.shape, payload))
return error, out
error_checks[lax.dynamic_slice_p] = dynamic_slice_error_check

def gather_error_check(error, enabled_errors, operand, start_indices, *,
dimension_numbers, slice_sizes, unique_indices,
indices_are_sorted, mode, fill_value):
Expand Down
17 changes: 17 additions & 0 deletions tests/checkify_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,23 @@ def raises_oob(fn, idx, *expected_strs):
raises_oob(multi_idx, (5, -9), "index 5", axis0_msg)
raises_oob(multi_idx, ((0, 9), 0), "index 9", axis0_msg)

def test_dynamic_slice_oobs(self):
def raises_oob(fn, x, idx, *expected_strs):
err, _ = checkify.checkify(jax.jit(fn), errors=checkify.index_checks)(x, idx)
error_txt = err.get()
self.assertIsNotNone(error_txt)
self.assertStartsWith(error_txt, "out-of-bounds indexing")
for s in expected_strs:
self.assertIn(s, error_txt)

x = jnp.ones((2, 3, 7))
raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (2, 0, 0), 'index 2')
raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (-3, 0, 0), 'index -1')
raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (0, 3, 0), 'index 3')
raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (0, -5, 0), 'index -2')
raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (0, 1, 8), 'index 8')
raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (0, 1, -10), 'index -3')

@jtu.sample_product(jit=[False, True])
def test_jit_ordering(self, jit):
def f(x, i):
Expand Down

0 comments on commit 9bb3d86

Please sign in to comment.