Skip to content

Commit

Permalink
[JAX] Update JAX users in preparation for a change that makes iterati…
Browse files Browse the repository at this point in the history
…on over a JAX array return JAX arrays, instead of NumPy arrays.

See jax-ml/jax#8043 for context as to why we are making this change.

The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays. In particular:

* Unlike NumPy scalars 0-dimensional JAX arrays are not hashable. This change updates users to call `.tolist()` or `np.asarray(...)` when the output of iterating over a JAX array is hashed, used as a dictionary key, or passed to `set(...)`. In some instances, we can just call `numpy` functions instead of `jax.numpy` functions to build the array in the first place.
* This change confuses Pandas and PIL when a JAX array is converted to a Pandas dataframe or a PIL image. For now, cast JAX arrays to a NumPy array first before passing them into those libraries.
* We now need to use `numpy.testing.assert_array_equal` instead of `numpy.testing.assert_equal` to compare JAX arrays.

PiperOrigin-RevId: 407070913
  • Loading branch information
OTT-JAX authors authored and marcocuturi committed Nov 7, 2021
1 parent 4e08d5d commit 41092c5
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ott/geometry/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(
"""
if (grid_size is not None and x is not None and num_a is not None and
grid_dimension is not None):
self.grid_size = grid_size
self.grid_size = tuple(map(int, grid_size))
self.x = x
self.num_a = num_a
self.grid_dimension = grid_dimension
Expand All @@ -90,7 +90,7 @@ def __init__(
self.num_a = np.prod(np.array(self.grid_size))
self.grid_dimension = len(self.x)
elif grid_size is not None:
self.grid_size = grid_size
self.grid_size = tuple(map(int, grid_size))
self.x = tuple([jnp.linspace(0, 1, n) for n in self.grid_size])
self.num_a = np.prod(np.array(grid_size))
self.grid_dimension = len(self.grid_size)
Expand Down

0 comments on commit 41092c5

Please sign in to comment.