Does asserting on a jnp.ndarray trigger a device-to-host copy on GPU/TPU? #10283
-
Hi all, I am on a GPU platform. I am wondering whether using assert (or if) on JAX arrays copies data from the device to the host, for example:
import jax.numpy as jnp
i = jnp.asarray(1)
steps_num = jnp.asarray(3)
# Will this do a memory copy from device to host?
assert i < steps_num
# if is Python control flow, so I guess it will trigger a memory copy from device to host.
if i < steps_num:
    print("True")
Which way should I use to prevent the memory copy and still do the check?
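A self-contained version of the snippet above that separates the two steps involved: the comparison, which produces a 0-d boolean JAX array, and the conversion of that array to a Python bool, which assert and if need. This is a sketch assuming a recent JAX release, where these values are jax.Array instances rather than the older DeviceArray class; it also runs on CPU. jax.device_get is used only to make the host transfer explicit.

```python
import jax
import jax.numpy as jnp

i = jnp.asarray(1)
steps_num = jnp.asarray(3)

# The comparison is itself a JAX operation: the result is a 0-d bool array
# that lives on the same device as i and steps_num.
cond = i < steps_num
print(type(cond), cond.dtype, cond.shape)

# assert and if both ask Python for the truth value of cond (cond.__bool__()),
# which requires a concrete value on the host.
assert cond
if cond:
    print("True")

# The same device-to-host transfer, done explicitly: returns a NumPy value.
host_value = jax.device_get(cond)
print(bool(host_value))
```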
-
When you do assert xyz or if xyz, Python will call the __bool__ method of the object xyz.

In the case of two DeviceArrays x and y, the statement z = x < y will return another DeviceArray of type bool. The __bool__ method for DeviceArrays is defined here: https://github.com/google/jax/blob/b3a62cd3f2be15a7ed23771b371835e2977961be/jax/_src/device_array.py#L269

This currently forwards to z._value.__bool__, and z._value involves converting the device buffer to a Python buffer on the CPU: https://github.com/google/jax/blob/b3a62cd3f2be15a7ed23771b371835e2977961be/jax/_src/device_array.py#L144-L150

I don't think there's any way around this: Python control flow happens on the CPU, because Python runs on the CPU, so you cannot assert the value of something that doesn't live on the CPU. But in this case only a single byte is transferred, and only at trace-time (not at run-time), so I suspect it's not worth worrying about.
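A pure-Python sketch of the mechanism described above. HypotheticalDeviceArray is an invented stand-in, not JAX's class: its _value property plays the role of the host-side copy and simply counts how often it is requested, and __bool__ forwards to it. It only illustrates that both assert and if go through __bool__; it does not model any caching the real implementation may do.

```python
class HypotheticalDeviceArray:
    """Invented stand-in for a device array; NOT JAX's real class.

    __bool__ forwards to the truth value of a "host copy" (_value), and
    every request for that host copy is counted.
    """

    def __init__(self, value):
        self._device_value = value  # pretend this lives in accelerator memory
        self.host_copies = 0        # how many times the value was "fetched" to the host

    @property
    def _value(self):
        # Stands in for copying the device buffer to the host.
        self.host_copies += 1
        return self._device_value

    def __bool__(self):
        # Truth testing needs a concrete host value.
        return bool(self._value)


z = HypotheticalDeviceArray(True)

assert z   # Python calls z.__bool__() -> one "host copy"
if z:      # Python calls z.__bool__() again -> another "host copy"
    pass

print(z.host_copies)  # 2
```

Note that assert statements are skipped entirely when Python runs with -O, so in that mode only the if would trigger a conversion.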