Skip to content

Commit

Permalink
[JAX] Update JAX users in preparation for a change that makes iterati…
Browse files Browse the repository at this point in the history
…on over a JAX array return JAX arrays, instead of NumPy arrays.

See jax-ml/jax#8043 for context as to why we are making this change.

The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays. In particular:

* Unlike NumPy scalars 0-dimensional JAX arrays are not hashable. This change updates users to call `.tolist()` or `np.asarray(...)` when the output of iterating over a JAX array is hashed, used as a dictionary key, or passed to `set(...)`. In some instances, we can just call `numpy` functions instead of `jax.numpy` functions to build the array in the first place.
* This change confuses Pandas and PIL when a JAX array is converted to a Pandas dataframe or a PIL image. For now, cast JAX arrays to a NumPy array first before passing them into those libraries.
* We now need to use `numpy.testing.assert_array_equal` instead of `numpy.testing.assert_equal` to compare JAX arrays.

PiperOrigin-RevId: 400030119
  • Loading branch information
hawkinsp authored and JAX-CFD authors committed Sep 30, 2021
1 parent 9f815ed commit 3f64ede
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
4 changes: 2 additions & 2 deletions jax_cfd/base/advection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _square_concentration(grid):

def _unit_velocity(grid, velocity_sign=1.):
ndim = grid.ndim
offsets = (jnp.eye(ndim) + jnp.ones([ndim, ndim])) / 2.
offsets = (np.eye(ndim) + np.ones([ndim, ndim])) / 2.
return tuple(
grids.GridArray(velocity_sign * jnp.ones(grid.shape) if ax == 0
else jnp.zeros(grid.shape), tuple(offset), grid)
Expand All @@ -52,7 +52,7 @@ def _unit_velocity(grid, velocity_sign=1.):

def _cos_velocity(grid):
ndim = grid.ndim
offsets = (jnp.eye(ndim) + jnp.ones([ndim, ndim])) / 2.
offsets = (np.eye(ndim) + np.ones([ndim, ndim])) / 2.
mesh = grid.mesh()
v = tuple(grids.GridArray(jnp.cos(mesh[i] * 2. * np.pi), tuple(offset), grid)
for i, offset in enumerate(offsets))
Expand Down
3 changes: 2 additions & 1 deletion jax_cfd/base/forcings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@
from jax_cfd.base import forcings
from jax_cfd.base import grids
from jax_cfd.base import test_util
import numpy as np


def _make_zero_velocity_field(grid):
ndim = grid.ndim
offsets = (jnp.eye(ndim) + jnp.ones([ndim, ndim])) / 2.
offsets = (np.eye(ndim) + np.ones([ndim, ndim])) / 2.
return tuple(
grids.GridArray(jnp.zeros(grid.shape), tuple(offset), grid)
for ax, offset in enumerate(offsets))
Expand Down

0 comments on commit 3f64ede

Please sign in to comment.