diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 59ccec0bbb57..35542919e14e 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -808,22 +808,22 @@ def ravel_multi_index(multi_index, dims, mode='raise', order='C'): _UNRAVEL_INDEX_DOC = """\ Unlike numpy's implementation of unravel_index, negative indices are accepted -and out-of-bounds indices are clipped. +and out-of-bounds indices are clipped into the valid range. """ @_wraps(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC) def unravel_index(indices, shape): _check_arraylike("unravel_index", indices) - sizes = append(array(shape), 1) - cumulative_sizes = cumprod(sizes[::-1])[::-1] - total_size = cumulative_sizes[0] - # Clip so raveling and unraveling an oob index will not change the behavior - clipped_indices = clip(indices, -total_size, total_size - 1) - # Add enough trailing dims to avoid conflict with clipped_indices - cumulative_sizes = expand_dims(cumulative_sizes, range(1, 1 + _ndim(indices))) - clipped_indices = expand_dims(clipped_indices, axis=0) - idx = clipped_indices % cumulative_sizes[:-1] // cumulative_sizes[1:] - return tuple(idx) + shape = atleast_1d(shape) + if shape.ndim != 1: + raise ValueError("unravel_index: shape should be a scalar or 1D sequence.") + out_indices = [None] * len(shape) + for i, s in reversed(list(enumerate(shape))): + indices, out_indices[i] = divmod(indices, s) + oob_pos = indices > 0 + oob_neg = indices < -1 + return tuple(where(oob_pos, s - 1, where(oob_neg, 0, i)) + for s, i in zip(shape, out_indices)) @_wraps(np.resize) @partial(jit, static_argnames=('new_shape',)) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 9d270d3e4b75..2b6234d0447b 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -4216,22 +4216,26 @@ def jnp_fun(a, c): else: self._CompileAndCheck(jnp_fun, args_maker) - @parameterized.parameters( - (0, (2, 1, 3)), - (5, (2, 1, 3)), - (0, ()), - (np.array([0, 1, 2]), (2, 2)), - (np.array([[[0, 1], [2, 3]]]), (2, 2))) - def testUnravelIndex(self, flat_index, shape): - args_maker = lambda: (flat_index, shape) - np_fun = jtu.with_jax_dtype_defaults(np.unravel_index, use_defaults=not hasattr(flat_index, 'dtype')) - self._CheckAgainstNumpy(np_fun, jnp.unravel_index, args_maker) - self._CompileAndCheck(jnp.unravel_index, args_maker) - - def testUnravelIndexOOB(self): - self.assertEqual(jnp.unravel_index(2, (2,)), (1,)) - self.assertEqual(jnp.unravel_index(-2, (2, 1, 3,)), (1, 0, 1)) - self.assertEqual(jnp.unravel_index(-3, (2,)), (0,)) + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_idx={}".format(shape, + jtu.format_shape_dtype_string(idx_shape, dtype)), + "shape": shape, "idx_shape": idx_shape, "dtype": dtype} + for shape in nonempty_nonscalar_array_shapes + for dtype in int_dtypes + for idx_shape in all_shapes)) + def testUnravelIndex(self, shape, idx_shape, dtype): + size = prod(shape) + rng = jtu.rand_int(self.rng(), low=-((2 * size) // 3), high=(2 * size) // 3) + + def np_fun(index, shape): + # Adjust out-of-bounds behavior to match jax's documented behavior. + index = np.clip(index, -size, size - 1) + index = np.where(index < 0, index + size, index) + return np.unravel_index(index, shape) + jnp_fun = jnp.unravel_index + args_maker = lambda: [rng(idx_shape, dtype), shape] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) def testAstype(self): rng = self.rng()