Skip to content
This repository has been archived by the owner on Oct 12, 2023. It is now read-only.

Commit

Permalink
jnp.unravel_index: improve test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed May 4, 2022
1 parent 58320e2 commit 3c2d2b2
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4216,25 +4216,26 @@ def jnp_fun(a, c):
else:
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.parameters(
(0, (2, 1, 3)),
(5, (2, 1, 3)),
(0, ()),
(np.array([0, 1, 2]), (2, 2)),
(np.array([[[0, 1], [2, 3]]]), (2, 2)),
# regression test for https://github.com/google/jax/issues/10540
(np.arange(5), (201_996, 201_996)), # prod(shape) overflows int32.
)
def testUnravelIndex(self, flat_index, shape):
args_maker = lambda: (flat_index, shape)
np_fun = jtu.with_jax_dtype_defaults(np.unravel_index, use_defaults=not hasattr(flat_index, 'dtype'))
self._CheckAgainstNumpy(np_fun, jnp.unravel_index, args_maker)
self._CompileAndCheck(jnp.unravel_index, args_maker)

def testUnravelIndexOOB(self):
self.assertEqual(jnp.unravel_index(2, (2,)), (1,))
self.assertEqual(jnp.unravel_index(-2, (2, 1, 3,)), (1, 0, 1))
self.assertEqual(jnp.unravel_index(-3, (2,)), (0,))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_idx={}".format(shape,
jtu.format_shape_dtype_string(idx_shape, dtype)),
"shape": shape, "idx_shape": idx_shape, "dtype": dtype}
for shape in nonempty_nonscalar_array_shapes
for dtype in int_dtypes
for idx_shape in all_shapes))
def testUnravelIndex(self, shape, idx_shape, dtype):
size = prod(shape)
rng = jtu.rand_int(self.rng(), low=-((2 * size) // 3), high=(2 * size) // 3)

def np_fun(index, shape):
# Adjust out-of-bounds behavior to match jax's documented behavior.
index = np.clip(index, -size, size - 1)
index = np.where(index < 0, index + size, index)
return np.unravel_index(index, shape)
jnp_fun = jnp.unravel_index
args_maker = lambda: [rng(idx_shape, dtype), shape]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

def testAstype(self):
rng = self.rng()
Expand Down

0 comments on commit 3c2d2b2

Please sign in to comment.