From 46297dccaf9db98370fb48bee43942374b4f7b45 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 4 Apr 2023 08:16:59 -0700 Subject: [PATCH] checkify: catch OOB errors in dynamic_slice This will allow checkify tests to continue working properly after #15377 --- jax/_src/checkify.py | 16 ++++++++++++++++ tests/checkify_test.py | 17 +++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 369de3845e7e..23267f6ce372 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -587,6 +587,22 @@ def isnan(x): error_checks[_prim] = functools.partial(nan_error_check, _prim) +def dynamic_slice_error_check(error, enabled_errors, operand, *start_indices, slice_sizes): + out = lax.dynamic_slice_p.bind(operand, *start_indices, slice_sizes=slice_sizes) + + if OOBError not in enabled_errors: + return error, out + + operand_dims = np.array(operand.shape) + slice_sizes = np.array(slice_sizes) + start_indices = jnp.array(start_indices) + oob_mask = (start_indices < 0) | (start_indices + slice_sizes > operand_dims) + + payload = oob_payload(oob_mask, start_indices, range(operand.ndim), operand.shape) + error = assert_func(error, jnp.any(oob_mask), OOBError(summary(), "dynamic_slice", operand.shape, payload)) + return error, out +error_checks[lax.dynamic_slice_p] = dynamic_slice_error_check + def gather_error_check(error, enabled_errors, operand, start_indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 8600270b41a3..57765268923d 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -181,6 +181,23 @@ def raises_oob(fn, idx, *expected_strs): raises_oob(multi_idx, (5, -9), "index 5", axis0_msg) raises_oob(multi_idx, ((0, 9), 0), "index 9", axis0_msg) + def test_dynamic_slice_oobs(self): + def raises_oob(fn, x, idx, *expected_strs): + err, _ = checkify.checkify(jax.jit(fn), errors=checkify.index_checks)(x, idx) + error_txt = err.get() + self.assertIsNotNone(error_txt) + self.assertStartsWith(error_txt, "out-of-bounds indexing") + for s in expected_strs: + self.assertIn(s, error_txt) + + x = jnp.ones((2, 3, 7)) + raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (2, 0, 0), 'index 2') + raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (-3, 0, 0), 'index -1') + raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (0, 3, 0), 'index 3') + raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (0, -5, 0), 'index -2') + raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (0, 1, 8), 'index 8') + raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (0, 1, -10), 'index -3') + @jtu.sample_product(jit=[False, True]) def test_jit_ordering(self, jit): def f(x, i):