Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JAX] Update JAX users in preparation for a change that makes iteration over a JAX array return JAX arrays, instead of NumPy arrays. #58

Merged
merged 1 commit into from
Sep 30, 2021

Conversation

copybara-service[bot]
Copy link

[JAX] Update JAX users in preparation for a change that makes iteration over a JAX array return JAX arrays, instead of NumPy arrays.

See jax-ml/jax#8043 for context as to why we are making this change.

The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays. In particular:

  • Unlike NumPy scalars 0-dimensional JAX arrays are not hashable. This change updates users to call .tolist() or np.asarray(...) when the output of iterating over a JAX array is hashed, used as a dictionary key, or passed to set(...). In some instances, we can just call numpy functions instead of jax.numpy functions to build the array in the first place.
  • This change confuses Pandas and PIL when a JAX array is converted to a Pandas dataframe or a PIL image. For now, cast JAX arrays to a NumPy array first before passing them into those libraries.
  • We now need to use numpy.testing.assert_array_equal instead of numpy.testing.assert_equal to compare JAX arrays.

…on over a JAX array return JAX arrays, instead of NumPy arrays.

See jax-ml/jax#8043 for context as to why we are making this change.

The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays. In particular:

* Unlike NumPy scalars 0-dimensional JAX arrays are not hashable. This change updates users to call `.tolist()` or `np.asarray(...)` when the output of iterating over a JAX array is hashed, used as a dictionary key, or passed to `set(...)`. In some instances, we can just call `numpy` functions instead of `jax.numpy` functions to build the array in the first place.
* This change confuses Pandas and PIL when a JAX array is converted to a Pandas dataframe or a PIL image. For now, cast JAX arrays to a NumPy array first before passing them into those libraries.
* We now need to use `numpy.testing.assert_array_equal` instead of `numpy.testing.assert_equal` to compare JAX arrays.

PiperOrigin-RevId: 400030119
@copybara-service copybara-service bot merged commit 3f64ede into main Sep 30, 2021
@copybara-service copybara-service bot deleted the test_399922762 branch September 30, 2021 21:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant