From a2a05f7fe07a43767831b37e5c6b3d5ffb05beb9 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 28 Oct 2021 15:21:06 -0700 Subject: [PATCH] [JAX] Update JAX users in preparation for a change that makes iteration over a JAX array return JAX arrays, instead of NumPy arrays. See https://github.com/google/jax/pull/8043 for context as to why we are making this change. The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays. In particular: * Unlike NumPy scalars 0-dimensional JAX arrays are not hashable. This change updates users to call `.tolist()` or `np.asarray(...)` when the output of iterating over a JAX array is hashed, used as a dictionary key, or passed to `set(...)`. In some instances, we can just call `numpy` functions instead of `jax.numpy` functions to build the array in the first place. * This change confuses Pandas and PIL when a JAX array is converted to a Pandas dataframe or a PIL image. For now, cast JAX arrays to a NumPy array first before passing them into those libraries. * We now need to use `numpy.testing.assert_array_equal` instead of `numpy.testing.assert_equal` to compare JAX arrays. PiperOrigin-RevId: 406230100 --- distrax/_src/distributions/distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distrax/_src/distributions/distribution.py b/distrax/_src/distributions/distribution.py index a724628..d81d80b 100644 --- a/distrax/_src/distributions/distribution.py +++ b/distrax/_src/distributions/distribution.py @@ -280,7 +280,7 @@ def to_batch_shape_index( A new index that is only applied on the batch shape. """ try: - new_index = [x[index] for x in jnp.indices(batch_shape)] + new_index = [x[index] for x in np.indices(batch_shape)] return tuple(new_index) except IndexError as e: raise IndexError(f'Batch shape `{batch_shape}` not compatible with index '