DeviceArray.__iter__ returns DeviceArrays, without host sync #8043
All (the pull request submitter and all commit authors) CLAs are signed, but one or more commits were authored or co-authored by someone other than the pull request submitter. We need to confirm that all authors are ok with their commits being contributed to this project. Please have them confirm that by leaving a comment that contains only Note to project maintainer: There may be cases where the author cannot leave a comment, or the comment is not properly detected as consent. In those cases, you can manually confirm consent of the commit author(s), and set the ℹ️ Googlers: Go here for more info. |
…on over a JAX array return JAX arrays, instead of NumPy arrays. See jax-ml/jax#8043 for context as to why we are making this change. The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays. In particular: * Unlike NumPy scalars 0-dimensional JAX arrays are not hashable. This change updates users to call `.tolist()` or `np.asarray(...)` when the output of iterating over a JAX array is hashed, used as a dictionary key, or passed to `set(...)`. In some instances, we can just call `numpy` functions instead of `jax.numpy` functions to build the array in the first place. * This change confuses Pandas and PIL when a JAX array is converted to a Pandas dataframe or a PIL image. For now, cast JAX arrays to a NumPy array first before passing them into those libraries. * We now need to use `numpy.testing.assert_array_equal` instead of `numpy.testing.assert_equal` to compare JAX arrays. PiperOrigin-RevId: 399922762
…on over a JAX array return JAX arrays, instead of NumPy arrays. See jax-ml/jax#8043 for context as to why we are making this change. The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays. In particular: * Unlike NumPy scalars 0-dimensional JAX arrays are not hashable. This change updates users to call `.tolist()` or `np.asarray(...)` when the output of iterating over a JAX array is hashed, used as a dictionary key, or passed to `set(...)`. In some instances, we can just call `numpy` functions instead of `jax.numpy` functions to build the array in the first place. * This change confuses Pandas and PIL when a JAX array is converted to a Pandas dataframe or a PIL image. For now, cast JAX arrays to a NumPy array first before passing them into those libraries. * We now need to use `numpy.testing.assert_array_equal` instead of `numpy.testing.assert_equal` to compare JAX arrays. PiperOrigin-RevId: 399922724
…on over a JAX array return JAX arrays, instead of NumPy arrays. See #8043 for context as to why we are making this change. The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays. PiperOrigin-RevId: 405780628
This is very nice!
…on over a JAX array return JAX arrays, instead of NumPy arrays. See #8043 for context as to why we are making this change. The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays. PiperOrigin-RevId: 405995198
…on over a JAX array return JAX arrays, instead of NumPy arrays. See jax-ml/jax#8043 for context as to why we are making this change. The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays. In particular: * Unlike NumPy scalars 0-dimensional JAX arrays are not hashable. This change updates users to call `.tolist()` or `np.asarray(...)` when the output of iterating over a JAX array is hashed, used as a dictionary key, or passed to `set(...)`. In some instances, we can just call `numpy` functions instead of `jax.numpy` functions to build the array in the first place. * This change confuses Pandas and PIL when a JAX array is converted to a Pandas dataframe or a PIL image. For now, cast JAX arrays to a NumPy array first before passing them into those libraries. * We now need to use `numpy.testing.assert_array_equal` instead of `numpy.testing.assert_equal` to compare JAX arrays. PiperOrigin-RevId: 406150402
…on over a JAX array return JAX arrays, instead of NumPy arrays. See jax-ml/jax#8043 for context as to why we are making this change. The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays. In particular: * Unlike NumPy scalars 0-dimensional JAX arrays are not hashable. This change updates users to call `.tolist()` or `np.asarray(...)` when the output of iterating over a JAX array is hashed, used as a dictionary key, or passed to `set(...)`. In some instances, we can just call `numpy` functions instead of `jax.numpy` functions to build the array in the first place. * This change confuses Pandas and PIL when a JAX array is converted to a Pandas dataframe or a PIL image. For now, cast JAX arrays to a NumPy array first before passing them into those libraries. * We now need to use `numpy.testing.assert_array_equal` instead of `numpy.testing.assert_equal` to compare JAX arrays. PiperOrigin-RevId: 406230100
…on over a JAX array return JAX arrays, instead of NumPy arrays. See #8043 for context as to why we are making this change. The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays. In particular: * Unlike NumPy scalars 0-dimensional JAX arrays are not hashable. This change updates users to call `.tolist()` or `np.asarray(...)` when the output of iterating over a JAX array is hashed, used as a dictionary key, or passed to `set(...)`. In some instances, we can just call `numpy` functions instead of `jax.numpy` functions to build the array in the first place. * This change confuses Pandas and PIL when a JAX array is converted to a Pandas dataframe or a PIL image. For now, cast JAX arrays to a NumPy array first before passing them into those libraries. * We now need to use `numpy.testing.assert_array_equal` instead of `numpy.testing.assert_equal` to compare JAX arrays. PiperOrigin-RevId: 406150403
…on over a JAX array return JAX arrays, instead of NumPy arrays. See #8043 for context as to why we are making this change. The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays. In particular: * Unlike NumPy scalars 0-dimensional JAX arrays are not hashable. This change updates users to call `.tolist()` or `np.asarray(...)` when the output of iterating over a JAX array is hashed, used as a dictionary key, or passed to `set(...)`. In some instances, we can just call `numpy` functions instead of `jax.numpy` functions to build the array in the first place. * This change confuses Pandas and PIL when a JAX array is converted to a Pandas dataframe or a PIL image. For now, cast JAX arrays to a NumPy array first before passing them into those libraries. * We now need to use `numpy.testing.assert_array_equal` instead of `numpy.testing.assert_equal` to compare JAX arrays. PiperOrigin-RevId: 406247725
…on over a JAX array return JAX arrays, instead of NumPy arrays. See jax-ml/jax#8043 for context as to why we are making this change. The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays. In particular: * Unlike NumPy scalars 0-dimensional JAX arrays are not hashable. This change updates users to call `.tolist()` or `np.asarray(...)` when the output of iterating over a JAX array is hashed, used as a dictionary key, or passed to `set(...)`. In some instances, we can just call `numpy` functions instead of `jax.numpy` functions to build the array in the first place. * This change confuses Pandas and PIL when a JAX array is converted to a Pandas dataframe or a PIL image. For now, cast JAX arrays to a NumPy array first before passing them into those libraries. * We now need to use `numpy.testing.assert_array_equal` instead of `numpy.testing.assert_equal` to compare JAX arrays. PiperOrigin-RevId: 407070913
…on over a JAX array return JAX arrays, instead of NumPy arrays. See jax-ml/jax#8043 for context as to why we are making this change. The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays. In particular: * Unlike NumPy scalars 0-dimensional JAX arrays are not hashable. This change updates users to call `.tolist()` or `np.asarray(...)` when the output of iterating over a JAX array is hashed, used as a dictionary key, or passed to `set(...)`. In some instances, we can just call `numpy` functions instead of `jax.numpy` functions to build the array in the first place. * This change confuses Pandas and PIL when a JAX array is converted to a Pandas dataframe or a PIL image. For now, cast JAX arrays to a NumPy array first before passing them into those libraries. * We now need to use `numpy.testing.assert_array_equal` instead of `numpy.testing.assert_equal` to compare JAX arrays. PiperOrigin-RevId: 406860111
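The migration notes in the commit messages above can be illustrated with a small sketch. This is only an example of the workarounds they name (`.tolist()`, `np.asarray(...)`, and `numpy.testing.assert_array_equal`); the variable names are made up for illustration and are not taken from any of the referenced commits.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(3)
items = list(x)  # after this change, each element is a 0-dimensional JAX array

# 0-dimensional JAX arrays are not hashable, so hashing one is expected to fail.
try:
    hash(items[0])
    element_is_hashable = True
except TypeError:
    element_is_hashable = False

# Workarounds named in the commit message: convert before hashing.
keys_from_tolist = set(x.tolist())    # Python ints
keys_from_numpy = set(np.asarray(x))  # NumPy scalars

# Compare JAX arrays with assert_array_equal rather than assert_equal.
np.testing.assert_array_equal(x, np.arange(3))
```

Converting once with `np.asarray(x)` before handing data to libraries such as Pandas or PIL follows the same pattern.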
A note: this change seems to have broken matplotlib's histogram function when a JAX array is passed to it:

```python
import matplotlib.pyplot as plt
import jax.numpy as jnp
import numpy as np

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].hist(np.arange(10))   # NumPy array input
ax[1].hist(jnp.arange(10))  # JAX array input, which hits the problematic code path
```

I believe this should be considered a matplotlib bug (see discussion in #8899 (comment)) but the new behavior here is what resulted in the problematic code path.
Maybe this deserves a the-sharp-bits entry?
matplotlib issue fixed in matplotlib/matplotlib#22016
This is a clone of #3821 rebased on top of current `main`.

Original description:
cross-ref #1583 #3330
The current (before this PR) implementation of `DeviceArray.__iter__` from #965 pulls the array back to the host and does the slicing on the resulting NumPy array. In effect, iteration is forwarded to the host copy of the array.

There are three issues we want to address here:

1. `list(device_array)` shouldn't return a list of `numpy.ndarray`s, and instead should return a list of `DeviceArray`s, regardless of how those are computed (this issue came up in #1583, "List comprehensions and for-loops on xla.DeviceArrays return numpy.ndarrays", and in some JAX user chats recently);
2. `key, subkey = random.split(key)` in op-by-op mode shouldn't incur device synchronization, but it currently does because we effectively call `random.split(key).__iter__()` (this issue came up in #3330, "Unexpected / unintuitive behaviour of tuple(DeviceArray)", and in JAX chats);
3. we don't want `list(device_array)` or `list(jnp.arange(10000))` to be slow to evaluate, where before #965 ("Make DeviceArray.__iter__ and __reversed__ forward to _value.") there could be a several-second compilation time and relatively slow (~100ms) execution time.

To address (1) we could either keep the bounce via NumPy and add some `device_put`s, or keep all the operations on the device. To address (2) we want to keep all the operations on the device. Currently that means we need to compile and execute XLA computations (rather than just doing something in the runtime).
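To make issue (2) concrete, here is a minimal sketch of the unpacking pattern in question. It assumes a recent JAX release (where the array type is exposed as `jax.Array`) and the default PRNG implementation; it only shows that tuple unpacking goes through iteration and yields JAX arrays rather than NumPy arrays.

```python
import numpy as np
import jax
from jax import random

key = random.PRNGKey(0)

# Tuple unpacking iterates over the result of random.split,
# i.e. it goes through the array's __iter__.
key, subkey = random.split(key)

# With this change the unpacked values are JAX arrays, not NumPy arrays.
unpacked_is_jax = isinstance(subkey, jax.Array)
unpacked_is_numpy = isinstance(subkey, np.ndarray)
```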
To address (3), i.e. to keep things performant, most importantly we don't want to incur big compilation times, which basically means we don't want to compile computations with large output arity. We also don't want to dispatch lots of XLA computations, to keep execution time reasonable. We can balance the two by chunking things, so that for chunk size C and array size N we end up compiling O(1) computations, none having output arity more than C, and dispatching about 2 N / C computations (for each chunk, one to slice it out of the original array and one to explode it).
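The chunking idea can be sketched roughly as follows. This is an illustrative approximation and not the code in this PR: `chunked_iter` and `_explode` are hypothetical names, and the sketch assumes iteration along the leading axis only. Each chunk costs two dispatches (one slice, one explode), and `_explode` is compiled once per distinct chunk length, so at most twice here (full chunks plus a possible remainder).

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def _explode(chunk):
    # Hypothetical helper: split a chunk into its rows.
    # The output arity equals the chunk length, so it never exceeds chunk_size.
    return tuple(chunk[i] for i in range(chunk.shape[0]))


def chunked_iter(x, chunk_size=100):
    """Yield the rows of x as JAX arrays without copying x to the host."""
    n = x.shape[0]
    for start in range(0, n, chunk_size):
        size = min(chunk_size, n - start)
        # Dispatch 1: slice the chunk out of the original array on device.
        chunk = lax.dynamic_slice_in_dim(x, start, size)
        # Dispatch 2: explode the chunk into individual elements.
        yield from _explode(chunk)


rows = list(chunked_iter(jnp.arange(250), chunk_size=100))
```

A larger `chunk_size` means fewer dispatches but a higher output arity per compiled computation, which is the trade-off described above.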
I timed the current and new implementation this way:
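A measurement of this kind could be sketched as follows. This is an illustrative reconstruction under assumptions, not the snippet that produced the numbers discussed here: the helper `time_dispatch`, the repeat count, and the size of `small` are made up for the example.

```python
import time
import jax.numpy as jnp


def time_dispatch(fn, *args, repeats=3):
    """Best wall-clock time of fn(*args), measuring dispatch only."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)  # deliberately no block_until_ready on the results
        best = min(best, time.perf_counter() - start)
    return best


small = jnp.arange(10)    # assumed size
big = jnp.arange(10000)   # size mentioned in issue (3)

small_seconds = time_dispatch(list, small)
big_seconds = time_dispatch(list, big)
print(f"list(small): {small_seconds * 1e3:.2f} ms")
print(f"list(big):   {big_seconds * 1e3:.2f} ms")
```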
Notice we are not calling `block_until_ready` and instead just looking at dispatch times. The numbers that come out are pretty noisy.

For the current implementation, on a TPUv3 internal colab we see this:
For the new implementation with chunk_size=100, on a TPUv3 internal colab we see this:
I think the ballparks here seem acceptable, and so meet desideratum (3) while avoiding the host sync per (2). Comparing to #965, we avoid any multi-second compile times even though we end up paying 80ms to compute `list(big)` rather than <1ms. I suspect these times will get better when we improve `jit` dispatch time, since a significant fraction of the trace is spent on Python overhead (anything that's not an Execute bar).

Here's one other benchmark:
With the current implementation, on TPU we get:
With the new implementation, on TPU we get:
So perhaps we're saving something from avoiding the sync here, even though there's no other work going on here.
fixes #1583