DeviceArray.__iter__ returns DeviceArrays, without host sync #3821

Closed · 1 commit

Conversation

@mattjj (Collaborator) commented Jul 22, 2020

cross-ref #1583 #3330

The current (before this PR) implementation of DeviceArray.__iter__, from #965, pulls the array back to the host and does the slicing on the resulting NumPy array, effectively:

# current implementation
class DeviceArray:
  def __iter__(self):
    return iter(self._value)  # self._value is the cached ndarray version of the array

There are three issues we want to address here:

  1. list(device_array) shouldn't return a list of numpy.ndarrays; it should return a list of DeviceArrays, regardless of how those are computed (this issue came up in List comprehensions and for-loops on xla.DeviceArrays return numpy.ndarrays #1583 and in some JAX user chats recently);
  2. key, subkey = random.split(key) in op-by-op mode shouldn't incur device synchronization, but it currently does because tuple unpacking effectively calls random.split(key).__iter__() (see the snippet after this list; this issue came up in Unexpected / unintuitive behaviour of tuple(DeviceArray) #3330 and in JAX chats);
  3. from Make DeviceArray.__iter__ and __reversed__ forward to _value. #965, we don't want expressions like list(device_array) or list(jnp.arange(10000)) to be slow to evaluate; before #965 there could be a several-second compilation time and relatively slow (~100 ms) execution time.
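
For context on item (2): tuple unpacking iterates over the right-hand side, so with the old __iter__ the unpack itself forced a device-to-host transfer. A minimal illustration:

from jax import random

key = random.PRNGKey(0)
# The line below is effectively:
#   it = iter(random.split(key)); key = next(it); subkey = next(it)
# With the old __iter__, iter() materialized the array on the host.
key, subkey = random.split(key)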

To address (1) we could either keep the bounce via NumPy and add some device_puts, or keep all the operations on the device. To address (2) we want to keep all the operations on the device. Currently that means we need to compile and execute XLA computations (rather than just doing something in the runtime).

To address (3), i.e. to keep things performant, most importantly we don't want to incur big compilation times, which basically means we don't want to compile computations with large output arity. We also don't want to dispatch lots of XLA computations, to keep execution time reasonable. We can balance the two by chunking things, so that for chunk size C and array size N we end up compiling O(1) computations, none having output arity more than C, and dispatching about 2 N / C computations (for each chunk, one to slice it out of the original array and one to explode it).
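
To make the scheme concrete, here is a minimal sketch of chunked on-device iteration (my illustration, not the code in this PR; CHUNK_SIZE, _slice_chunk, _unstack_chunk, and device_iter are hypothetical names):

import jax
import jax.numpy as jnp
from functools import partial

CHUNK_SIZE = 100  # the chunk size C above

@partial(jax.jit, static_argnums=2)
def _slice_chunk(x, start, size):
  # `start` is traced, so this compiles once per distinct `size` (at most
  # two sizes occur: CHUNK_SIZE and the final remainder), i.e. O(1)
  # compilations overall.
  return jax.lax.dynamic_slice_in_dim(x, start, size, axis=0)

@jax.jit
def _unstack_chunk(chunk):
  # "Explode" a chunk into its rows; output arity is at most CHUNK_SIZE,
  # which keeps compile times bounded.
  return tuple(chunk[i] for i in range(chunk.shape[0]))

def device_iter(x):
  # One slice plus one unstack per chunk: about 2 N / C dispatches for an
  # array of N rows, with every intermediate staying on the device.
  n = x.shape[0]
  for start in range(0, n, CHUNK_SIZE):
    size = min(CHUNK_SIZE, n - start)
    yield from _unstack_chunk(_slice_chunk(x, start, size))

For example, list(device_iter(jnp.arange(10000))) would dispatch roughly 2 * 10000 / 100 = 200 computations under this scheme.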

I timed the current and new implementations this way:

import time
import jax.numpy as jnp
from jax.interpreters import xla

def clear_caches():
  xla.xla_primitive_callable.cache_clear()
  xla._xla_callable.cache_clear()

def time_things(x):
  # time to grab the first element, nothing compiled
  times = []
  for _ in range(100):
    clear_caches()
    tic = time.time()
    next(iter(x))
    times.append(time.time() - tic)
  print(f"time to grab first element, uncompiled: {sum(times) / len(times) * 1000} ms")

  # time to grab first element, compiled
  times = []
  for _ in range(100):
    tic = time.time()
    next(iter(x))
    times.append(time.time() - tic)
  print(f"time to grab first element, compiled: {sum(times) / len(times) * 1000} ms")

  # time to grab the whole thing, nothing compiled
  times = []
  for _ in range(100):
    clear_caches()
    tic = time.time()
    list(x)
    times.append(time.time() - tic)
  print(f"total time to grab all elements, uncompiled: {sum(times) / len(times) * 1000} ms")
  print(f"average time to grab each element, uncompiled: {sum(times) / len(times) / x.shape[0] * 1000} ms")

  # time to grab the whole thing, compiled
  times = []
  for _ in range(100):
    tic = time.time()
    list(x)
    times.append(time.time() - tic)
  print(f"total time to grab all elements, compiled: {sum(times) / len(times) * 1000} ms")
  print(f"average time to grab each element, compiled: {sum(times) / len(times) / x.shape[0] * 1000} ms")

small = jnp.arange(10) + 1
big = jnp.arange(10000) + 1

Notice we are not calling block_until_ready, so we are measuring only dispatch times. The numbers that come out are pretty noisy.
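
(If we instead wanted to include device execution time, we would block on each result before reading the clock, e.g. next(iter(x)).block_until_ready().)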

For the current implementation, on a TPUv3 internal colab we see this:

time_things(small)
time to grab first element, uncompiled: 0.001392364501953125 ms
time to grab first element, compiled: 0.001277923583984375 ms
total time to grab all elements, uncompiled: 0.003528594970703125 ms
average time to grab each element, uncompiled: 0.0003528594970703125 ms
total time to grab all elements, compiled: 0.0027894973754882812 ms
average time to grab each element, compiled: 0.0002789497375488281 ms

time_things(big)
time to grab first element, uncompiled: 0.0012826919555664062 ms
time to grab first element, compiled: 0.0011134147644042969 ms
total time to grab all elements, uncompiled: 0.7117271423339844 ms
average time to grab each element, uncompiled: 7.117271423339843e-05 ms
total time to grab all elements, compiled: 0.7156133651733398 ms
average time to grab each element, compiled: 7.156133651733398e-05 ms

For the new implementation with chunk_size=100, on a TPUv3 internal colab we see this:

time_things(small)
time to grab first element, uncompiled: 12.47842788696289 ms
time to grab first element, compiled: 0.1662755012512207 ms
total time to grab all elements, uncompiled: 12.541697025299072 ms
average time to grab each element, uncompiled: 1.2541697025299074 ms
total time to grab all elements, compiled: 0.16188859939575195 ms
average time to grab each element, compiled: 0.016188859939575195 ms

time_things(big)
time to grab first element, uncompiled: 88.96037578582764 ms
time to grab first element, compiled: 0.809943675994873 ms
total time to grab all elements, uncompiled: 182.55109786987305 ms
average time to grab each element, uncompiled: 0.018255109786987307 ms
total time to grab all elements, compiled: 80.00826835632324 ms
average time to grab each element, compiled: 0.008000826835632323 ms

I think the ballparks here seem acceptable, so we meet desideratum (3) while avoiding the host sync per (2). Compared to #965, we avoid any multi-second compile times, even though we end up paying 80 ms to compute list(big) rather than <1 ms. I suspect these times will get better when we improve jit dispatch time, since a significant fraction of the trace is spent on Python overhead (anything that's not an Execute bar):

[profiler trace image]

Here's one other benchmark:

from jax import random
import time

def time_things2():
  key = random.PRNGKey(0)
  key, _ = random.split(key)  # warm-up: compile random.split outside the timed loop

  tic = time.time()
  for _ in range(1000):
    key, subkey = random.split(key)
  toc = time.time()
  print(f"{toc - tic} s")  # total wall time for 1000 splits, in seconds

With the current implementation, on TPU we get:

time_things2()
0.27765631675720215 s

With the new implementation, on TPU we get:

time_things2()
0.20100140571594238 s

So perhaps we're saving something (roughly 0.08 ms per split) from avoiding the sync, even though there's no other device work going on here.

fixes #1583

@mattjj requested a review from hawkinsp on July 23, 2020 02:19
@mattjj marked this pull request as ready for review on July 23, 2020 02:20
@mattjj changed the title from "sketch alternative DeviceArray.__iter__ impl" to "DeviceArray.__iter__ returns DeviceArrays, without host sync" on Jul 23, 2020
Review thread on jax/interpreters/xla.py (outdated diff):

-      return self._value.__iter__()
+      device = self.device_buffer.device()
+      if device is None or device.platform == 'cpu':
+        return iter(self._value)
A collaborator commented:
If this still returns NumPy arrays, I suspect that might be undesirable because of, e.g., different promotion semantics.

However, it might make sense in the short term until some of the overheads improve.

@mattjj (Collaborator Author) replied:
Yeah, good point. We should revise this case so that it returns CPU DeviceArrays, maybe.

@mattjj (Collaborator Author) commented Nov 17, 2020:
On CPU, with return iter(self._value), running the timing script from the OP on the big array, we get:

time to grab first element, uncompiled: 0.002505779266357422 ms
time to grab first element, compiled: 0.0018811225891113281 ms
total time to grab all elements, uncompiled: 0.696418285369873 ms
average time to grab each element, uncompiled: 6.96418285369873e-05 ms
total time to grab all elements, compiled: 0.6600356101989746 ms
average time to grab each element, compiled: 6.600356101989746e-05 ms

If we instead keep doing the slicing in NumPy but device_put each result, i.e. we do something like this:

      device = self.device_buffer.device()
      if device is None or device.platform == 'cpu':
        # do the slicing in NumPy for better performance
        device = device or jax.devices('cpu')[0]
        aval = ShapedArray(self.aval.shape[1:], self.aval.dtype,
                           self.aval.weak_type)
        lexpr = lazy.array(aval.shape)
        return (make_device_array(aval, device, lexpr, *device_put(x, device))
                for x in self._value)

Then we get this timing:

time to grab first element, uncompiled: 0.01703500747680664 ms
time to grab first element, compiled: 0.015168190002441406 ms
total time to grab all elements, uncompiled: 79.03305768966675 ms
average time to grab each element, uncompiled: 0.007903305768966674 ms
total time to grab all elements, compiled: 79.05462265014648 ms
average time to grab each element, compiled: 0.007905462265014648 ms

So a 100x slowdown to return DeviceArrays rather than ndarrays on CPU. Perhaps we'll be able to speed that up soon, but I wanted to bounce this off you anyway. WDYT?

@mattjj (Collaborator Author) commented:
Updated timings, being slightly less dumb:

time to grab first element, uncompiled: 0.01026153564453125 ms
time to grab first element, compiled: 0.009059906005859375 ms
total time to grab all elements, uncompiled: 42.66105651855469 ms
average time to grab each element, uncompiled: 0.004266105651855468 ms
total time to grab all elements, compiled: 42.526516914367676 ms
average time to grab each element, compiled: 0.004252651691436768 ms

@mattjj (Collaborator Author) commented Apr 10, 2021:
New times, after jit dispatch time upgrades in jaxlib==0.1.65 (remember these are noisy!).

With special treatment of CPU (which is part of what this comment thread was about):

time to grab first element, uncompiled: 0.013675689697265625 ms
time to grab first element, compiled: 0.012180805206298828 ms
total time to grab all elements, uncompiled: 38.03548336029053 ms
average time to grab each element, uncompiled: 0.003803548336029053 ms
total time to grab all elements, compiled: 37.87814140319824 ms
average time to grab each element, compiled: 0.0037878141403198244 ms

Without special treatment of CPU:

time to grab first element, uncompiled: 0.013167858123779297 ms
time to grab first element, compiled: 0.012054443359375 ms
total time to grab all elements, uncompiled: 37.4117112159729 ms
average time to grab each element, uncompiled: 0.00374117112159729 ms
total time to grab all elements, compiled: 37.27412462234497 ms
average time to grab each element, compiled: 0.0037274124622344975 ms

So, no need for special treatment of CPU now!

@jekbradbury (Contributor) commented:

I think this is worth merging?

@mattjj force-pushed the issue1583 branch 2 times, most recently from d8510c2 to 4f573c1 on November 17, 2020 00:12
@mattjj added the pull ready (Ready for copybara import and testing) label on Nov 24, 2020
copybara-service bot pushed a commit to google/flax that referenced this pull request Dec 2, 2020
jax-ml/jax#3821.

The idea of the JAX change is in part that DeviceArray.__iter__ should return
DeviceArrays. Before #3821, it returned numpy.ndarrays. One main motivation is
performance: it avoids a host sync. A secondary motivation is type consistency.

However, that caused this line of Flax example code to trigger a NumPy bug,
discussed in this thread:
jax-ml/jax#620 (comment)
Basically, x[i] where x is a numpy.ndarray and i is a JAX DeviceArray _of
length 10 or less_ causes NumPy to interpret i as a non-array sequence (e.g. a
tuple) rather than as an array, leading to an error like "IndexError: too many
indices for array". The workaround employed here is to write x[i, ...] instead
of x[i], which bypasses the NumPy bug.

PiperOrigin-RevId: 345140147
copybara-service bot pushed a commit to google/flax that referenced this pull request Dec 2, 2020 (same message as above; PiperOrigin-RevId: 345160314)
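
As an aside, here is a minimal sketch of the NumPy indexing pitfall these commit messages describe (behavior depends on the NumPy version; x and i are illustrative values):

import numpy as np
import jax.numpy as jnp

x = np.arange(12.0).reshape(3, 4)  # a plain numpy.ndarray
i = jnp.array([0, 2])              # a short DeviceArray used as an index

# On affected NumPy versions, x[i] treats the short array-like as a
# sequence of separate indices and raises
# "IndexError: too many indices for array".
y = x[i, ...]  # workaround: trailing Ellipsis forces array-style indexing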
copybara-service bot pushed a commit to google-deepmind/dm-haiku that referenced this pull request Dec 2, 2020
Before JAX's #3821, DeviceArray.__iter__ returned numpy.ndarray instances. As a
result, calling np.testing.assert_equal would do an array equality check:

https://github.com/numpy/numpy/blob/92ebe1e9a6aeb47a881a1226b08218175776f9ea/numpy/testing/_private/utils.py#L341

After #3821, DeviceArray.__iter__ returns DeviceArray instances, and though
these expose a __array__ method, the numpy logic linked above doesn't trigger.
Ultimately that leads to an error like "ValueError: The truth value of an array
with more than one element is ambiguous. Use a.any() or a.all()".

The fix here is to call np.testing.assert_array_equal directly, rather than to
rely on the np.testing.assert_equal wrapper to call it.

PiperOrigin-RevId: 345232915
Change-Id: I4506314c27c051330bace980c8dde34e91034a40
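
A minimal sketch of the failure mode and fix described in this commit message (values are illustrative):

import numpy as np
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])  # DeviceArrays, not numpy.ndarrays
b = jnp.array([1.0, 2.0, 3.0])

# np.testing.assert_equal special-cases numpy.ndarray inputs; with two
# DeviceArrays it falls through to a plain `==` comparison, which calls
# bool() on a length-3 array and raises the ambiguous-truth-value error.
np.testing.assert_array_equal(a, b)  # the fix: compare as arrays directly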
@jakevdp (Collaborator) commented Aug 25, 2021:

I'm trying to revive this... one issue is that the conditional implementation of __hash__ here falls afoul of the change in #7347

I don't think it's possible to have 0-dimensional arrays be hashable without re-breaking the issue fixed in #7347.

@mattjj (Collaborator Author) commented Oct 16, 2021:

Closing in favor of #8043

@mattjj closed this on Oct 16, 2021
@mattjj deleted the issue1583 branch on October 16, 2021 00:48
Labels: cla: yes · pull ready (Ready for copybara import and testing)
Merging this pull request may close: List comprehensions and for-loops on xla.DeviceArrays return numpy.ndarrays (#1583)
5 participants