
DeviceArray.__iter__ returns DeviceArrays, without host sync #8043

Merged: 1 commit merged into jax-ml:main on Nov 1, 2021

Conversation

hawkinsp
Collaborator

@hawkinsp hawkinsp commented Sep 29, 2021

This is a clone of #3821, rebased on top of current main.

Original description:

cross-ref #1583 #3330

The current (before this PR) implementation of DeviceArray.__iter__ from #965 pulls the array back to the host and does the slicing on the resulting NumPy array, effectively:

# current implementation
class DeviceArray:
  def __iter__(self):
    return iter(self._value)  # self._value is the cached ndarray version of the array

There are three issues we want to address here:

  1. list(device_array) shouldn't return a list of numpy.ndarrays, and instead should return a list of DeviceArrays, regardless of how those are computed (this issue came up in List comprehensions and for-loops on xla.DeviceArrays return numpy.ndarrays #1583 and in some JAX user chats recently); see the sketch after this list;
  2. key, subkey = random.split(key) in op-by-op mode shouldn't incur device synchronization, but it currently does because we effectively call random.split(key).__iter__() (this issue came up in Unexpected / unintuitive behaviour of tuple(DeviceArray) #3330 and in JAX chats);
  3. from Make DeviceArray.__iter__ and __reversed__ forward to _value. #965, we don't want expressions like list(device_array) or list(jnp.arange(10000)) to be slow to evaluate; before #965 there could be a several-second compilation time and relatively slow (~100ms) execution time.
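
For concreteness, here is a small sketch of the desired post-change behavior (not code from this PR):

import jax.numpy as jnp
from jax import random

x = jnp.arange(3)
elems = list(x)              # a list of JAX DeviceArrays, not numpy.ndarrays
print(type(elems[0]))

key = random.PRNGKey(0)
key, subkey = random.split(key)   # unpacking no longer forces a host sync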

To address (1) we could either keep the bounce via NumPy and add some device_puts, or keep all the operations on the device. To address (2) we want to keep all the operations on the device. Currently that means we need to compile and execute XLA computations (rather than just doing something in the runtime).

To address (3), i.e. to keep things performant, most importantly we don't want to incur big compilation times, which basically means we don't want to compile computations with large output arity. We also don't want to dispatch lots of XLA computations, to keep execution time reasonable. We can balance the two by chunking things, so that for chunk size C and array size N we end up compiling O(1) computations, none having output arity more than C, and dispatching about 2 N / C computations (for each chunk, one to slice it out of the original array and one to explode it).
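
Concretely, a minimal sketch of the chunked approach could look like this (illustrative only; iter_chunked and its structure are not the PR's actual implementation):

import jax
import jax.numpy as jnp
from jax import lax

def iter_chunked(x, chunk_size=100):
  """Yield the rows of x as device arrays, chunk_size rows at a time."""
  n = x.shape[0]
  # One jitted program "explodes" a chunk into per-row arrays; its output arity
  # is at most chunk_size, so compile times stay bounded.
  explode = jax.jit(lambda c: [c[i] for i in range(c.shape[0])])
  for start in range(0, n, chunk_size):
    size = min(chunk_size, n - start)
    chunk = lax.dynamic_slice_in_dim(x, start, size, axis=0)  # slicing stays on device
    yield from explode(chunk)

# e.g. list(iter_chunked(jnp.arange(10000))) dispatches roughly 2 * 10000 / 100 computations.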

I timed the current and new implementations this way:

import time
import jax.numpy as jnp
from jax.interpreters import xla

def clear_caches():
  xla.xla_primitive_callable.cache_clear()
  xla._xla_callable.cache_clear()

def time_things(x):
  # time to grab the first element, nothing compiled
  times = []
  for _ in range(100):
    clear_caches()
    tic = time.time()
    next(iter(x))
    times.append(time.time() - tic)
  print(f"time to grab first element, uncompiled: {sum(times) / len(times) * 1000} ms")

  # time to grab first element, compiled
  times = []
  for _ in range(100):
    tic = time.time()
    next(iter(x))
    times.append(time.time() - tic)
  print(f"time to grab first element, compiled: {sum(times) / len(times) * 1000} ms")

  # time to grab the whole thing, nothing compiled
  times = []
  for _ in range(100):
    clear_caches()
    tic = time.time()
    list(x)
    times.append(time.time() - tic)
  print(f"total time to grab all elements, uncompiled: {sum(times) / len(times) * 1000} ms")
  print(f"average time to grab each element, uncompiled: {sum(times) / len(times) / x.shape[0] * 1000} ms")

  # time to grab the whole thing, compiled
  times = []
  for _ in range(100):
    tic = time.time()
    list(x)
    times.append(time.time() - tic)
  print(f"total time to grab all elements, compiled: {sum(times) / len(times) * 1000} ms")
  print(f"average time to grab each element, compiled: {sum(times) / len(times) / x.shape[0] * 1000} ms")

small = jnp.arange(10) + 1
big = jnp.arange(10000) + 1

Notice we are not calling block_until_ready; we are just looking at dispatch times. The numbers that come out are pretty noisy.
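
(With the new implementation, where the iterator yields DeviceArrays, an end-to-end measurement would instead block on the result before stopping the clock, roughly:)

# End-to-end variant (not used above): block until the element is actually computed.
elem = next(iter(x))
elem.block_until_ready()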

For the current implementation, on a TPUv3 internal colab we see this:

time_things(small)
time to grab first element, uncompiled: 0.001392364501953125 ms
time to grab first element, compiled: 0.001277923583984375 ms
total time to grab all elements, uncompiled: 0.003528594970703125 ms
average time to grab each element, uncompiled: 0.0003528594970703125 ms
total time to grab all elements, compiled: 0.0027894973754882812 ms
average time to grab each element, compiled: 0.0002789497375488281 ms

time_things(big)
time to grab first element, uncompiled: 0.0012826919555664062 ms
time to grab first element, compiled: 0.0011134147644042969 ms
total time to grab all elements, uncompiled: 0.7117271423339844 ms
average time to grab each element, uncompiled: 7.117271423339843e-05 ms
total time to grab all elements, compiled: 0.7156133651733398 ms
average time to grab each element, compiled: 7.156133651733398e-05 ms

For the new implementation with chunk_size=100, on a TPUv3 internal colab we see this:

time_things(small)
time to grab first element, uncompiled: 12.47842788696289 ms
time to grab first element, compiled: 0.1662755012512207 ms
total time to grab all elements, uncompiled: 12.541697025299072 ms
average time to grab each element, uncompiled: 1.2541697025299074 ms
total time to grab all elements, compiled: 0.16188859939575195 ms
average time to grab each element, compiled: 0.016188859939575195 ms

time_things(big)
time to grab first element, uncompiled: 88.96037578582764 ms
time to grab first element, compiled: 0.809943675994873 ms
total time to grab all elements, uncompiled: 182.55109786987305 ms
average time to grab each element, uncompiled: 0.018255109786987307 ms
total time to grab all elements, compiled: 80.00826835632324 ms
average time to grab each element, compiled: 0.008000826835632323 ms

I think the ballparks here seem acceptable, so we meet desideratum (3) while avoiding the host sync per (2). Compared with #965, we avoid any multi-second compile times, even though we end up paying ~80ms to compute list(big) rather than <1ms. I suspect these times will get better when we improve jit dispatch time, since a significant fraction of the trace is spent on Python overhead (anything that's not an Execute bar):

[Profiler trace screenshot: Execute bars interleaved with Python overhead]

Here's one other benchmark:

from jax import random
import time

def time_things2():
  key = random.PRNGKey(0)
  key, _ = random.split(key)

  tic = time.time()
  for _ in range(1000):
    key, subkey = random.split(key)
  toc = time.time()
  # (toc - tic) is the total time in seconds for 1000 splits, which is
  # numerically equal to the average time per split in milliseconds.
  print(f"{(toc - tic)} ms")

With the current implementation, on TPU we get:

time_things2()
0.27765631675720215 ms

With the new implementation, on TPU we get:

time_things2()
0.20100140571594238 ms

So perhaps we're saving something by avoiding the sync, even though there's no other work going on in this benchmark.

fixes #1583

@google-cla

google-cla bot commented Sep 29, 2021

All (the pull request submitter and all commit authors) CLAs are signed, but one or more commits were authored or co-authored by someone other than the pull request submitter.

We need to confirm that all authors are ok with their commits being contributed to this project. Please have them confirm that by leaving a comment that contains only @googlebot I consent. in this pull request.

Note to project maintainer: There may be cases where the author cannot leave a comment, or the comment is not properly detected as consent. In those cases, you can manually confirm consent of the commit author(s), and set the cla label to yes (if enabled on your project).

ℹ️ Googlers: Go here for more info.

@google-cla google-cla bot added the cla: no label Sep 29, 2021
@hawkinsp hawkinsp added cla: yes pull ready Ready for copybara import and testing and removed cla: no labels Sep 29, 2021

@google-cla google-cla bot added cla: no and removed cla: yes labels Sep 30, 2021
@hawkinsp hawkinsp added cla: yes and removed cla: no labels Sep 30, 2021
@hawkinsp hawkinsp changed the title make iter(DeviceArray) return DeviceArrays w/o sync DeviceArray.__iter__ returns DeviceArrays, without host sync Sep 30, 2021
@hawkinsp hawkinsp marked this pull request as ready for review September 30, 2021 13:38
@hawkinsp hawkinsp requested a review from mattjj September 30, 2021 13:38
copybara-service bot pushed a commit to google/jax-cfd that referenced this pull request Sep 30, 2021
…on over a JAX array return JAX arrays, instead of NumPy arrays.

See jax-ml/jax#8043 for context as to why we are making this change.

The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays. In particular:

* Unlike NumPy scalars 0-dimensional JAX arrays are not hashable. This change updates users to call `.tolist()` or `np.asarray(...)` when the output of iterating over a JAX array is hashed, used as a dictionary key, or passed to `set(...)`. In some instances, we can just call `numpy` functions instead of `jax.numpy` functions to build the array in the first place.
* This change confuses Pandas and PIL when a JAX array is converted to a Pandas dataframe or a PIL image. For now, cast JAX arrays to a NumPy array first before passing them into those libraries.
* We now need to use `numpy.testing.assert_array_equal` instead of `numpy.testing.assert_equal` to compare JAX arrays.

PiperOrigin-RevId: 399922762
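
The migrations described in the commit message above look roughly like this in downstream user code (an illustrative sketch; these snippets are not taken from the actual downstream changes):

import numpy as np
import jax.numpy as jnp

x = jnp.arange(3)

# 0-dimensional JAX arrays are not hashable, so convert before using them as dict/set keys:
counts = {int(v): 0 for v in x}          # or v.tolist() / np.asarray(v).item()

# Cast to NumPy before handing arrays to libraries like Pandas or PIL:
np_x = np.asarray(x)

# Compare JAX arrays element-wise in tests:
np.testing.assert_array_equal(x, np.arange(3))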
@cgarciae
Collaborator

This is very nice! split was causing headaches for me, so I had to create my own wrapper that mimicked split but returned a list of DeviceArrays so you could destructure it safely.

@copybara-service copybara-service bot merged commit 335857b into jax-ml:main Nov 1, 2021
@jakevdp
Collaborator

jakevdp commented Dec 11, 2021

A note: this change seems to have broken matplotlib's histogram function when a DeviceArray is passed to it (with most recent matplotlib and JAX releases):

import matplotlib.pyplot as plt
import jax.numpy as jnp
import numpy as np

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].hist(np.arange(10))
ax[1].hist(jnp.arange(10));

[Screenshot: histograms of np.arange(10) (left) and jnp.arange(10) (right); the JAX-array histogram renders incorrectly]

I believe this should be considered a matplotlib bug (see discussion in #8899 (comment)), but the new behavior here is what triggered the problematic code path.
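
Until matplotlib handles this, one workaround (continuing the snippet above) is to cast to NumPy first:

ax[1].hist(np.asarray(jnp.arange(10)))  # cast the JAX array to NumPy before plotting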

@soraros

soraros commented Dec 11, 2021

Maybe this deserves an entry in the-sharp-bits?

@jakevdp
Collaborator

jakevdp commented Mar 3, 2022

matplotlib issue fixed in matplotlib/matplotlib#22016

Labels
cla: yes pull ready Ready for copybara import and testing
Development

Successfully merging this pull request may close these issues.

List comprehensions and for-loops on xla.DeviceArrays return numpy.ndarrays