Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

checkify: catch OOB errors in dynamic_slice #15390

Merged
merged 1 commit into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions jax/_src/checkify.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,22 @@ def isnan(x):
error_checks[_prim] = functools.partial(nan_error_check, _prim)


def dynamic_slice_error_check(error, enabled_errors, operand, *start_indices, slice_sizes):
out = lax.dynamic_slice_p.bind(operand, *start_indices, slice_sizes=slice_sizes)

if OOBError not in enabled_errors:
return error, out

operand_dims = np.array(operand.shape)
slice_sizes = np.array(slice_sizes)
start_indices = jnp.array(start_indices)
oob_mask = (start_indices < 0) | (start_indices + slice_sizes > operand_dims)

payload = oob_payload(oob_mask, start_indices, range(operand.ndim), operand.shape)
error = assert_func(error, jnp.any(oob_mask), OOBError(summary(), "dynamic_slice", operand.shape, payload))
return error, out
error_checks[lax.dynamic_slice_p] = dynamic_slice_error_check

def gather_error_check(error, enabled_errors, operand, start_indices, *,
dimension_numbers, slice_sizes, unique_indices,
indices_are_sorted, mode, fill_value):
Expand Down
17 changes: 17 additions & 0 deletions tests/checkify_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,23 @@ def raises_oob(fn, idx, *expected_strs):
raises_oob(multi_idx, (5, -9), "index 5", axis0_msg)
raises_oob(multi_idx, ((0, 9), 0), "index 9", axis0_msg)

def test_dynamic_slice_oobs(self):
def raises_oob(fn, x, idx, *expected_strs):
err, _ = checkify.checkify(jax.jit(fn), errors=checkify.index_checks)(x, idx)
error_txt = err.get()
self.assertIsNotNone(error_txt)
self.assertStartsWith(error_txt, "out-of-bounds indexing")
for s in expected_strs:
self.assertIn(s, error_txt)

x = jnp.ones((2, 3, 7))
raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (2, 0, 0), 'index 2')
raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (-3, 0, 0), 'index -1')
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Negative indices are a bit weird: because lax.dynamic_slice normalizes negative indices before passing them to dynamic_slice_p.bind(), the index mentioned in the message doesn't match the index passed to dynamic_slice. But this is the same behavior as in gather (e.g. line 168 above), so I think it's working as expected.

raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (0, 3, 0), 'index 3')
raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (0, -5, 0), 'index -2')
raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (0, 1, 8), 'index 8')
raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (0, 1, -10), 'index -3')

@jtu.sample_product(jit=[False, True])
def test_jit_ordering(self, jit):
def f(x, i):
Expand Down