Skip to content
This repository has been archived by the owner on Oct 12, 2023. It is now read-only.

Commit

Permalink
Merge pull request jax-ml#10546 from jakevdp:unravel-indices
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 446553390
  • Loading branch information
jax authors committed May 4, 2022
2 parents a8c6742 + 3c2d2b2 commit 7297115
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 27 deletions.
22 changes: 11 additions & 11 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,22 +808,22 @@ def ravel_multi_index(multi_index, dims, mode='raise', order='C'):

_UNRAVEL_INDEX_DOC = """\
Unlike numpy's implementation of unravel_index, negative indices are accepted
and out-of-bounds indices are clipped.
and out-of-bounds indices are clipped into the valid range.
"""

@_wraps(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC)
def unravel_index(indices, shape):
_check_arraylike("unravel_index", indices)
sizes = append(array(shape), 1)
cumulative_sizes = cumprod(sizes[::-1])[::-1]
total_size = cumulative_sizes[0]
# Clip so raveling and unraveling an oob index will not change the behavior
clipped_indices = clip(indices, -total_size, total_size - 1)
# Add enough trailing dims to avoid conflict with clipped_indices
cumulative_sizes = expand_dims(cumulative_sizes, range(1, 1 + _ndim(indices)))
clipped_indices = expand_dims(clipped_indices, axis=0)
idx = clipped_indices % cumulative_sizes[:-1] // cumulative_sizes[1:]
return tuple(idx)
shape = atleast_1d(shape)
if shape.ndim != 1:
raise ValueError("unravel_index: shape should be a scalar or 1D sequence.")
out_indices = [None] * len(shape)
for i, s in reversed(list(enumerate(shape))):
indices, out_indices[i] = divmod(indices, s)
oob_pos = indices > 0
oob_neg = indices < -1
return tuple(where(oob_pos, s - 1, where(oob_neg, 0, i))
for s, i in zip(shape, out_indices))

@_wraps(np.resize)
@partial(jit, static_argnames=('new_shape',))
Expand Down
36 changes: 20 additions & 16 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4216,22 +4216,26 @@ def jnp_fun(a, c):
else:
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.parameters(
(0, (2, 1, 3)),
(5, (2, 1, 3)),
(0, ()),
(np.array([0, 1, 2]), (2, 2)),
(np.array([[[0, 1], [2, 3]]]), (2, 2)))
def testUnravelIndex(self, flat_index, shape):
args_maker = lambda: (flat_index, shape)
np_fun = jtu.with_jax_dtype_defaults(np.unravel_index, use_defaults=not hasattr(flat_index, 'dtype'))
self._CheckAgainstNumpy(np_fun, jnp.unravel_index, args_maker)
self._CompileAndCheck(jnp.unravel_index, args_maker)

def testUnravelIndexOOB(self):
self.assertEqual(jnp.unravel_index(2, (2,)), (1,))
self.assertEqual(jnp.unravel_index(-2, (2, 1, 3,)), (1, 0, 1))
self.assertEqual(jnp.unravel_index(-3, (2,)), (0,))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_idx={}".format(shape,
jtu.format_shape_dtype_string(idx_shape, dtype)),
"shape": shape, "idx_shape": idx_shape, "dtype": dtype}
for shape in nonempty_nonscalar_array_shapes
for dtype in int_dtypes
for idx_shape in all_shapes))
def testUnravelIndex(self, shape, idx_shape, dtype):
size = prod(shape)
rng = jtu.rand_int(self.rng(), low=-((2 * size) // 3), high=(2 * size) // 3)

def np_fun(index, shape):
# Adjust out-of-bounds behavior to match jax's documented behavior.
index = np.clip(index, -size, size - 1)
index = np.where(index < 0, index + size, index)
return np.unravel_index(index, shape)
jnp_fun = jnp.unravel_index
args_maker = lambda: [rng(idx_shape, dtype), shape]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

def testAstype(self):
rng = self.rng()
Expand Down

0 comments on commit 7297115

Please sign in to comment.