DeviceArray.__iter__ returns DeviceArrays, without host sync #3821
Conversation
jax/interpreters/xla.py (outdated diff under discussion):

```python
return self._value.__iter__()
device = self.device_buffer.device()
if device is None or device.platform == 'cpu':
  return iter(self._value)
```
If this still returns NumPy arrays, I suspect that might be undesirable because of, e.g., different promotion semantics.
However, it might make sense in the short term until some of the overheads improve.
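For concreteness, here's a small constructed example (not from the thread) of the kind of promotion difference at stake: with a NumPy element, later arithmetic follows NumPy's promotion rules, while a DeviceArray element follows JAX's.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(6.0, dtype=jnp.float32).reshape(2, 3)
idx = np.arange(3, dtype=np.int32)

# What a NumPy-bounce __iter__ would yield: a numpy.ndarray element.
row_np = np.asarray(x)[0]
print((row_np + idx).dtype)   # float64 -- NumPy promotes float32 + int32 to float64

# What a DeviceArray-returning __iter__ yields: a JAX array element.
row_jax = x[0]
print((row_jax + idx).dtype)  # float32 -- JAX keeps float32
```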
Yeah, good point. We should revise this case so that it returns CPU DeviceArrays, maybe.
On CPU, with `return iter(self._value)`, running the big timing script in the OP using the big array, we get:
time to grab first element, uncompiled: 0.002505779266357422 ms
time to grab first element, compiled: 0.0018811225891113281 ms
total time to grab all elements, uncompiled: 0.696418285369873 ms
average time to grab each element, uncompiled: 6.96418285369873e-05 ms
total time to grab all elements, compiled: 0.6600356101989746 ms
average time to grab each element, compiled: 6.600356101989746e-05 ms
If we instead keep doing the slicing in NumPy but device_put each result, i.e. we do something like this:
```python
device = self.device_buffer.device()
if device is None or device.platform == 'cpu':
  # do the slicing in NumPy for better performance
  device = device or jax.devices('cpu')[0]  # first CPU device
  aval = ShapedArray(self.aval.shape[1:], self.aval.dtype,
                     self.aval.weak_type)
  lexpr = lazy.array(aval.shape)
  return (make_device_array(aval, device, lexpr, *device_put(x, device))
          for x in self._value)
```
Then we get this timing:
time to grab first element, uncompiled: 0.01703500747680664 ms
time to grab first element, compiled: 0.015168190002441406 ms
total time to grab all elements, uncompiled: 79.03305768966675 ms
average time to grab each element, uncompiled: 0.007903305768966674 ms
total time to grab all elements, compiled: 79.05462265014648 ms
average time to grab each element, compiled: 0.007905462265014648 ms
So a 100x slowdown to return DeviceArrays rather than ndarrays on CPU. Perhaps we'll be able to speed that up soon, but I wanted to bounce this off you anyway. WDYT?
Updated timings, being slightly less dumb:
time to grab first element, uncompiled: 0.01026153564453125 ms
time to grab first element, compiled: 0.009059906005859375 ms
total time to grab all elements, uncompiled: 42.66105651855469 ms
average time to grab each element, uncompiled: 0.004266105651855468 ms
total time to grab all elements, compiled: 42.526516914367676 ms
average time to grab each element, compiled: 0.004252651691436768 ms
New times, after jit dispatch time upgrades in jaxlib==0.1.65 (remember these are noisy!).
With special treatment of CPU (which is part of what this comment thread was about):
time to grab first element, uncompiled: 0.013675689697265625 ms
time to grab first element, compiled: 0.012180805206298828 ms
total time to grab all elements, uncompiled: 38.03548336029053 ms
average time to grab each element, uncompiled: 0.003803548336029053 ms
total time to grab all elements, compiled: 37.87814140319824 ms
average time to grab each element, compiled: 0.0037878141403198244 ms
Without special treatment of CPU:
time to grab first element, uncompiled: 0.013167858123779297 ms
time to grab first element, compiled: 0.012054443359375 ms
total time to grab all elements, uncompiled: 37.4117112159729 ms
average time to grab each element, uncompiled: 0.00374117112159729 ms
total time to grab all elements, compiled: 37.27412462234497 ms
average time to grab each element, compiled: 0.0037274124622344975 ms
So, no need for special treatment of CPU now!
I think this is worth merging?
jax-ml/jax#3821. The idea of the JAX change is in part that DeviceArray.__iter__ should return DeviceArrays. Before #3821, it returned numpy.ndarrays. One main motivation is performance: it avoids a host sync. A secondary motivation is type consistency. However, that caused this line of Flax example code to trigger a NumPy bug, discussed in this thread: jax-ml/jax#620 (comment) Basically, x[i] where x is a numpy.ndarray and i is a JAX DeviceArray _of length 10 or less_ causes NumPy to interpret i as a non-array sequence (e.g. a tuple) rather than as an array, leading to an error like "IndexError: too many indices for array". The workaround employed here is to write x[i, ...] instead of x[i], which bypasses the NumPy bug. PiperOrigin-RevId: 345140147
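For illustration, a constructed example of that workaround (not from the commit; whether plain `x[i]` still trips the bug depends on the NumPy and JAX versions in use):

```python
import numpy as np
import jax.numpy as jnp

x = np.arange(20.0)            # a numpy.ndarray
i = jnp.array([0, 2, 4])       # a short JAX index array (length <= 10)

# x[i] can hit the NumPy bug described above: the DeviceArray is treated as a
# non-array sequence of indices, raising "too many indices for array".
# Adding an explicit Ellipsis forces NumPy to treat i as one index array:
print(x[i, ...])
```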
Before JAX's #3821, DeviceArray.__iter__ returned numpy.ndarray instances. As a result, calling np.testing.assert_equal would do an array equality check: https://github.com/numpy/numpy/blob/92ebe1e9a6aeb47a881a1226b08218175776f9ea/numpy/testing/_private/utils.py#L341 After #3821, DeviceArray.__iter__ returns DeviceArray instances, and though these expose a __array__ method, the numpy logic linked above doesn't trigger. Ultimately that leads to an error like "ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()". The fix here is to call np.testing.assert_array_equal directly, rather than to rely on the np.testing.assert_equal wrapper to call it. PiperOrigin-RevId: 345232915 Change-Id: I4506314c27c051330bace980c8dde34e91034a40
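A constructed example of the failure mode and the fix described in that commit message:

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(3)
y = jnp.arange(3)

# With two JAX arrays, np.testing.assert_equal doesn't take its ndarray code
# path, falls through to scalar-style comparison, and can raise "The truth
# value of an array with more than one element is ambiguous".
# Calling the array comparison directly avoids the wrapper's dispatch:
np.testing.assert_array_equal(x, y)
```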
Closing in favor of #8043
cross-ref #1583 #3330
The current (before this PR) implementation of `DeviceArray.__iter__`, from #965, pulls the array back to the host and does the slicing on the resulting NumPy array, effectively `return iter(self._value)`.

There are three issues we want to address here:

1. `list(device_array)` shouldn't return a list of `numpy.ndarray`s, and instead should return a list of `DeviceArray`s, regardless of how those are computed (this issue came up in "List comprehensions and for-loops on xla.DeviceArrays return numpy.ndarrays" #1583 and in some JAX user chats recently);
2. `key, subkey = random.split(key)` in op-by-op mode shouldn't incur device synchronization, but it does currently because we effectively call `random.split(key).__iter__()` (this issue came up in "Unexpected / unintuitive behaviour of tuple(DeviceArray)" #3330 and in JAX chats);
3. we don't want `list(device_array)` or `list(jnp.arange(10000))` to be slow to evaluate, where before "Make DeviceArray.__iter__ and __reversed__ forward to _value." #965 there could be a several-second compilation time and relatively slow (~100ms) execution time.

To address (1) we could either keep the bounce via NumPy and add some device_puts, or keep all the operations on the device. To address (2) we want to keep all the operations on the device. Currently that means we need to compile and execute XLA computations (rather than just doing something in the runtime).
To address (3), i.e. to keep things performant, most importantly we don't want to incur big compilation times, which basically means we don't want to compile computations with large output arity. We also don't want to dispatch lots of XLA computations, to keep execution time reasonable. We can balance the two by chunking things, so that for chunk size C and array size N we end up compiling O(1) computations, none having output arity more than C, and dispatching about 2 N / C computations (for each chunk, one to slice it out of the original array and one to explode it).
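To make the chunking scheme concrete, here is a rough standalone sketch of the idea (not the PR's actual implementation, which lives on `DeviceArray` itself; the name `iter_chunked` and the eager handling of the ragged tail are my own choices):

```python
import jax
import jax.numpy as jnp
from jax import lax

def iter_chunked(x, chunk_size=100):
    """Yield the elements along x's leading axis without leaving the device.

    For each full chunk we dispatch one computation that slices the chunk out
    of x and one that "explodes" it into per-element arrays, so we run about
    2 * N / chunk_size executions while compiling only O(1) distinct programs,
    none with output arity larger than chunk_size.
    """
    n = x.shape[0]

    # Reused for every full chunk: the start index is a dynamic argument,
    # so jit compiles this once regardless of how many chunks there are.
    @jax.jit
    def slice_chunk(x, start):
        return lax.dynamic_slice_in_dim(x, start, chunk_size, axis=0)

    # One program with output arity chunk_size that splits a chunk apart.
    @jax.jit
    def explode(chunk):
        return [chunk[i] for i in range(chunk_size)]

    full = n - n % chunk_size
    for start in range(0, full, chunk_size):
        for elt in explode(slice_chunk(x, start)):
            yield elt
    # Ragged tail: handle the last few elements with ordinary indexing.
    for i in range(full, n):
        yield x[i]

# Example: every yielded element is a JAX array, never a numpy.ndarray.
x = jnp.arange(12.0).reshape(6, 2)
print([type(e).__name__ for e in iter_chunked(x, chunk_size=4)])
```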
I timed the current and new implementation this way:
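The timing script itself isn't reproduced here; a rough sketch of a harness that would print output of the same shape (the array size, and reading "compiled" as "array produced by a jitted computation", are my assumptions) is:

```python
import time
import jax
import jax.numpy as jnp

big = jnp.ones((10000, 100))                  # stand-in for the "big array"
big_compiled = jax.jit(lambda a: a + 0)(big)  # same data, produced by a compiled computation

def time_first(x):
    t0 = time.perf_counter()
    next(iter(x))
    return (time.perf_counter() - t0) * 1000  # ms

def time_all(x):
    t0 = time.perf_counter()
    elts = list(x)
    return (time.perf_counter() - t0) * 1000, len(elts)

print("time to grab first element, uncompiled:", time_first(big), "ms")
print("time to grab first element, compiled:", time_first(big_compiled), "ms")
total, n = time_all(big)
print("total time to grab all elements, uncompiled:", total, "ms")
print("average time to grab each element, uncompiled:", total / n, "ms")
total, n = time_all(big_compiled)
print("total time to grab all elements, compiled:", total, "ms")
print("average time to grab each element, compiled:", total / n, "ms")
```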
Notice we are not calling `block_until_ready` and instead just looking at dispatch times. The numbers that come out are pretty noisy.

For the current implementation, on a TPUv3 internal colab we see this:
For the new implementation with chunk_size=100, on a TPUv3 internal colab we see this:
I think the ballparks here seem acceptable, and so meet desideratum (3) while avoiding the host sync per (2). Comparing to #965, we avoid any multi-second compile times even though we end up paying 80ms to compute `list(big)` rather than <1ms. I suspect these times will get better when we improve `jit` dispatch time, since a significant fraction of the trace is spent on Python overhead (anything that's not an Execute bar).

Here's one other benchmark:
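The benchmark code isn't shown here either; given desideratum (2), one plausible shape for it is timing repeated `random.split` unpacking with no other work in the loop, e.g.:

```python
import time
from jax import random

key = random.PRNGKey(0)

# Repeatedly unpack random.split(key); the unpacking goes through
# DeviceArray.__iter__, so before this change each iteration also
# forced a device-to-host transfer.
t0 = time.perf_counter()
for _ in range(1000):
    key, subkey = random.split(key)
key.block_until_ready()
print("1000 splits:", (time.perf_counter() - t0) * 1000, "ms")
```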
With the current implementation, on TPU we get:
With the new implementation, on TPU we get:
So perhaps we're saving something from avoiding the sync here, even though there's no other work going on here.
fixes #1583