
Does `assert` on a jnp.ndarray trigger a host-to-device copy on GPU/TPU? #10283

Answered by jakevdp
luweizheng asked this question in General

When you write `assert xyz` or `if xyz`, Python calls the `__bool__` method of the object `xyz`.
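To see that `assert` and `if` both go through `__bool__`, here is a minimal plain-Python sketch. `Flag` is a hypothetical stand-in class (not part of JAX) that counts how often Python asks for its truth value:

```python
class Flag:
    """Hypothetical stand-in object that counts how often Python asks for its truth value."""

    def __init__(self, value):
        self.value = value
        self.bool_calls = 0

    def __bool__(self):
        self.bool_calls += 1
        return bool(self.value)


flag = Flag(True)
assert flag   # Python evaluates bool(flag), i.e. flag.__bool__()
if flag:      # ...and again here
    pass
print(flag.bool_calls)  # 2 (would be 1 if assertions were disabled with python -O)
```

Any side effect inside `__bool__` (for a JAX array, fetching the value from the device) therefore happens on every `assert` or `if`.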

In the case of two DeviceArrays x and y, the statement `z = x < y` returns another DeviceArray, of dtype bool, rather than a Python bool. The `__bool__` method for DeviceArrays is defined here: https://github.com/google/jax/blob/b3a62cd3f2be15a7ed23771b371835e2977961be/jax/_src/device_array.py#L269
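A quick way to check what the comparison produces is sketched below. It assumes a recent JAX release: in the version linked above the result type is `DeviceArray`, while current releases expose the same idea as `jax.Array`.

```python
import jax
import jax.numpy as jnp

x = jnp.array(1.0)
y = jnp.array(2.0)

# The comparison runs on the default device and yields a device-resident
# boolean array, not a Python bool.
z = x < y
print(type(z).__name__, z.dtype, z.shape)
```

Nothing has been copied to the host at this point; the copy only happens once Python needs a concrete truth value.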

This currently forwards to `z._value.__bool__`, and computing `z._value` involves converting the device buffer into a Python buffer on the CPU, i.e. a device-to-host copy: https://github.com/google/jax/blob/b3a62cd3f2be15a7ed23771b371835e2977961be/jax/_src/device_array.py#L144-L150
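The sketch below contrasts the implicit fetch done by `bool(z)` with two explicit ways of copying the same value to the host (`np.asarray` and `jax.device_get`). It is an illustration assuming a recent JAX release; on a CPU-only install the device buffer already lives in host memory, so the cost there is small, whereas on GPU/TPU each call waits for the computation and copies the result back.

```python
import jax
import jax.numpy as jnp
import numpy as np

z = jnp.array(1.0) < jnp.array(2.0)

# Implicit: bool() needs a concrete Python value, so JAX fetches z to the host.
implicit = bool(z)

# Explicit equivalents: both copy z from the device into a NumPy array on the host.
host_a = np.asarray(z)
host_b = jax.device_get(z)
explicit = bool(host_a)

print(implicit, explicit, type(host_a).__name__)
```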

I don't think there's any way around this: Python control flow happens on the CPU, because Python runs on…
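When the branch itself does not need to run in Python, it can be moved into the compiled computation instead. The sketch below assumes a recent JAX release; `pick_smaller_python` and `pick_smaller_device` are hypothetical helper names. It shows that a Python `if` on an array forces a host sync when run eagerly and cannot be evaluated at all under `jax.jit`, while `lax.cond` keeps the predicate on the device.

```python
import jax
import jax.numpy as jnp
from jax import lax


def pick_smaller_python(x, y):
    # Python `if` needs a concrete bool: eagerly this copies (x < y) to the
    # host; under jit, x < y is an abstract tracer and bool() on it fails.
    if x < y:
        return x
    return y


def pick_smaller_device(x, y):
    # lax.cond stages the branch into the compiled computation, so the
    # predicate never has to be turned into a Python bool.
    return lax.cond(x < y, lambda a, b: a, lambda a, b: b, x, y)


x, y = jnp.array(1.0), jnp.array(2.0)

eager = pick_smaller_python(x, y)              # works, but syncs with the host
on_device = jax.jit(pick_smaller_device)(x, y)

try:
    jax.jit(pick_smaller_python)(x, y)
    traced_if_failed = False
except jax.errors.ConcretizationTypeError:
    # Raised because the traced predicate has no concrete value to branch on.
    traced_if_failed = True

print(float(eager), float(on_device), traced_if_failed)
```

For elementwise selection, `jnp.where(x < y, x, y)` is a simpler device-side alternative; `lax.cond` is shown here because it mirrors the structure of a Python `if`.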
