[JAX] Update JAX users in preparation for a change that makes iteration over a JAX array return JAX arrays, instead of NumPy arrays. #58
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
[JAX] Update JAX users in preparation for a change that makes iteration over a JAX array return JAX arrays, instead of NumPy arrays.
See jax-ml/jax#8043 for context as to why we are making this change.
The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays. In particular:
.tolist()
ornp.asarray(...)
when the output of iterating over a JAX array is hashed, used as a dictionary key, or passed toset(...)
. In some instances, we can just callnumpy
functions instead ofjax.numpy
functions to build the array in the first place.numpy.testing.assert_array_equal
instead ofnumpy.testing.assert_equal
to compare JAX arrays.