From 58320e2c89ae00c6af70a0ae26107fb26f6d1ba8 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 3 May 2022 12:12:06 -0700 Subject: [PATCH] jnp.unravel_index: avoid overflow for large dimension sizes --- jax/_src/numpy/lax_numpy.py | 22 +++++++++++----------- tests/lax_numpy_test.py | 5 ++++- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 59ccec0bbb57..35542919e14e 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -808,22 +808,22 @@ def ravel_multi_index(multi_index, dims, mode='raise', order='C'): _UNRAVEL_INDEX_DOC = """\ Unlike numpy's implementation of unravel_index, negative indices are accepted -and out-of-bounds indices are clipped. +and out-of-bounds indices are clipped into the valid range. """ @_wraps(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC) def unravel_index(indices, shape): _check_arraylike("unravel_index", indices) - sizes = append(array(shape), 1) - cumulative_sizes = cumprod(sizes[::-1])[::-1] - total_size = cumulative_sizes[0] - # Clip so raveling and unraveling an oob index will not change the behavior - clipped_indices = clip(indices, -total_size, total_size - 1) - # Add enough trailing dims to avoid conflict with clipped_indices - cumulative_sizes = expand_dims(cumulative_sizes, range(1, 1 + _ndim(indices))) - clipped_indices = expand_dims(clipped_indices, axis=0) - idx = clipped_indices % cumulative_sizes[:-1] // cumulative_sizes[1:] - return tuple(idx) + shape = atleast_1d(shape) + if shape.ndim != 1: + raise ValueError("unravel_index: shape should be a scalar or 1D sequence.") + out_indices = [None] * len(shape) + for i, s in reversed(list(enumerate(shape))): + indices, out_indices[i] = divmod(indices, s) + oob_pos = indices > 0 + oob_neg = indices < -1 + return tuple(where(oob_pos, s - 1, where(oob_neg, 0, i)) + for s, i in zip(shape, out_indices)) @_wraps(np.resize) @partial(jit, static_argnames=('new_shape',)) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 9d270d3e4b75..4ab6ab2a9982 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -4221,7 +4221,10 @@ def jnp_fun(a, c): (5, (2, 1, 3)), (0, ()), (np.array([0, 1, 2]), (2, 2)), - (np.array([[[0, 1], [2, 3]]]), (2, 2))) + (np.array([[[0, 1], [2, 3]]]), (2, 2)), + # regression test for https://github.com/google/jax/issues/10540 + (np.arange(5), (201_996, 201_996)), # prod(shape) overflows int32. + ) def testUnravelIndex(self, flat_index, shape): args_maker = lambda: (flat_index, shape) np_fun = jtu.with_jax_dtype_defaults(np.unravel_index, use_defaults=not hasattr(flat_index, 'dtype'))