Control the lifetime of host values of the jax.Array. #26216
Comments
In addition to this bug: if there is an API that explicitly controls the lifetime of the NumPy array created from a jax.Array, could you please also make that NumPy array writable?
Interestingly, this repro works for

```python
import jax
import numpy as np
from jax import numpy as jnp


def main():
    d_tensor = jnp.array(0, dtype=jnp.bfloat16)
    d_sharding = d_tensor.sharding
    h_sharding = d_sharding.with_memory_kind("pinned_host")
    h_tensor = jax.device_put(d_tensor, h_sharding)

    def d2h_copy(src, dst):
        dst = jax.device_put(src, h_sharding)
        return src, dst

    d2h_copy_jit = jax.jit(
        d2h_copy,
        in_shardings=(h_sharding, d_sharding),
        out_shardings=(h_sharding, d_sharding),
        donate_argnums=(0, 1),
        keep_unused=True,
    )

    np.array(h_tensor)
    # This line fixes it (or not using bfloat16 above).
    # h_tensor._npy_value = np.array(-1, dtype=h_tensor.dtype)
    h_tensor, d_tensor = d2h_copy_jit(h_tensor, d_tensor)


if __name__ == "__main__":
    main()
```
I believe that if it raises here, the refcounts are off somehow: https://cs.opensource.google/tensorflow/tensorflow/+/master:third_party/xla/xla/python/py_array.cc;l=1530 So the
This will not be possible in general. JAX arrays are immutable, and conversion to NumPy is done in a zero-copy manner whenever possible. Thus the resulting NumPy arrays will also be read-only: if you want to write into the buffer, you'll need to make a copy.
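A minimal sketch of the behavior described in the reply above, assuming a recent jax/NumPy running on the CPU backend (other backends or versions may differ): `np.asarray` hands back a read-only view of the host value, while `np.array`, which copies by default, produces an independent, writable buffer.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(4, dtype=jnp.float32)

# np.asarray reuses the host value JAX holds for `x` when it can,
# so the result is marked read-only.
view = np.asarray(x)

# np.array copies by default, giving a buffer we are free to mutate.
writable = np.array(x)
writable[0] = 42.0

write_error = None
try:
    view[0] = 42.0  # rejected: the view is read-only
except ValueError as err:
    write_error = str(err)
```

Mutating `writable` leaves `x` untouched, which is exactly the copy the reply says is needed.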
Thanks for this insight! I think it's definitely possible that there's some fishy behavior around donation and caching of the host-side array. Assigning @dfm, but Dan, feel free to bump to someone else if needed.
It is technically possible. For instance, I can do

```python
__array_interface__ = np.array(jax_array, copy=False).__array_interface__
ptr, _ = __array_interface__["data"]
__array_interface__["data"] = (ptr, False)
```

and then give

It would be better if JAX gave me a way to do this without that hack, though. Why would I want to do this? For instance, to write into the pinned host memory at checkpoint loading time (for instance, using
I mean, sure. It's Python, you can do whatever you want. What I'm saying is that we deliberately make the buffer read-only because to do otherwise would lead to bugs, because the JAX runtime assumes the buffers that back its arrays are immutable. I would not suggest using code that works around that in any important context.
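The `__array_interface__` workaround from the earlier comment can be sketched in plain NumPy. This assumes a read-only NumPy array as a stand-in for `np.array(jax_array, copy=False)` so the sketch runs without JAX, and `WritableInterface` is a hypothetical name, not part of any library. As the reply above explains, applying this to JAX-owned buffers breaks the runtime's assumption that they are immutable.

```python
import numpy as np


class WritableInterface:
    """Hypothetical wrapper: re-exports another array's buffer through
    __array_interface__ with the read-only flag cleared.

    It only bypasses NumPy's read-only guard; it does not make mutating
    memory owned by someone else (e.g. the JAX runtime) safe.
    """

    def __init__(self, arr):
        iface = dict(arr.__array_interface__)
        ptr, _readonly = iface["data"]
        iface["data"] = (ptr, False)  # (pointer, read_only): mark writable
        self.__array_interface__ = iface
        self._owner = arr  # keep the source buffer alive while views exist


src = np.arange(4, dtype=np.float32)
src.flags.writeable = False  # stand-in for a read-only view of a jax.Array

alias = np.asarray(WritableInterface(src))
alias[0] = 42.0  # writes through to the memory that `src` also views
```

After the write, `src[0]` reads back `42.0` even though `src` itself is still flagged read-only, which is the kind of hidden aliasing the maintainers want to avoid.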
Actually, making the array mutable is not a bad idea if it saves large copies.
If a buffer is OnCpu, the NumPy array is created through the ZeroCopy code path, so the resulting ndarray is essentially a shared_ptr to the buffer. In that case, there is no need to cache it.
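The shared_ptr analogy can be illustrated in plain NumPy: an ndarray built from an object that exposes `__array_interface__` keeps that object alive through its `base`, so an external reference's lifetime can follow the NumPy array rather than being cached on the `jax.Array`. `ExternalReference` and its `release` callback are hypothetical stand-ins; this is a sketch of the idea, not how jaxlib implements it.

```python
import numpy as np


class ExternalReference:
    """Hypothetical stand-in for an external reference to a device buffer.

    Exposes a host buffer via __array_interface__ and calls `release`
    once the last Python reference to it is dropped.
    """

    def __init__(self, host_buffer, release):
        self._buffer = host_buffer
        self._release = release
        self.__array_interface__ = host_buffer.__array_interface__

    def __del__(self):
        self._release()


released = []
ref = ExternalReference(np.zeros(4, dtype=np.float32),
                        lambda: released.append(True))
arr = np.asarray(ref)  # no copy: arr holds a reference to `ref` as its base

del ref
still_held = not released  # arr alone keeps the external reference alive
del arr                    # last owner gone, so release() runs (CPython refcounting)
```

Once `arr` is dropped, `released` becomes `[True]`: the reference lives exactly as long as the NumPy array, with no cache on the owning object.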
Description
Here is a sample repro of this problem; reading the code is probably clearer than my description.
The issue here is that `_npy_value` keeps the external reference to the device buffer alive even after the array returned by `np.asarray` goes out of scope. The lifetime of each external reference should be bound to the NumPy array that `np.asarray` creates, not to the `jax.Array`.
System info (python version, jaxlib version, accelerator, etc.)
HEAD