
Control the lifetime of host values of the jax.Array. #26216

Open
yliu120 opened this issue Jan 30, 2025 · 10 comments
Labels: bug (Something isn't working)

Comments

yliu120 (Contributor) commented Jan 30, 2025

Description

A sample repro of this problem is below; the code probably explains it better than my description.

import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

import jax
import jax.numpy as jnp
import numpy as np

from jax.sharding import SingleDeviceSharding
from ml_dtypes import bfloat16


def main():
    s = SingleDeviceSharding(jax.local_devices()[0], memory_kind="pinned_host")
    s_dev = s.with_memory_kind("device")

    # Allocate pinned memory from a numpy array.
    jax.profiler.start_trace("/root/tensorboard/pinned_host")
    np_inp = np.ones((8192, 32768), dtype=bfloat16)
    weight_on_pinned_host = jax.device_put(np_inp, s)

    weight_on_device = jax.random.normal(jax.random.PRNGKey(0), (8192, 32768), dtype=jnp.bfloat16)
    weight = jax.device_put(weight_on_device, s_dev)

    def copy_weight_to_pinned_host(w, w_on_host_donated):
        w_on_host = jax.device_put(w, s)
        return w, w_on_host

    copy_jitted = jax.jit(
        copy_weight_to_pinned_host,
        in_shardings=(s_dev, s),
        out_shardings=(s_dev, s),
        donate_argnums=(0, 1),
        keep_unused=True,
    )

    def save_host_array(arr):
        a = np.asarray(arr)
        np.save("test.npy", a)

    for i in range(1000):
        print(f"iteration i: {i}")
        weight, weight_on_pinned_host = copy_jitted(weight, weight_on_pinned_host)
        save_host_array(weight_on_pinned_host)
        # Without the following workaround, the program crashes:
        # weight_on_pinned_host._npy_value = np.array(0)
        # np.testing.assert_allclose(weight, weight_on_pinned_host)
    weight_on_pinned_host.block_until_ready()
    jax.profiler.stop_trace()


if __name__ == "__main__":
    main()

The issue here is that _npy_value holds the external reference to the device buffer even after the array created by np.asarray goes out of scope. The lifetime of each external reference should be bound to the np.asarray result it backs, not to the jax.Array itself.
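To make the proposal concrete, here is a NumPy-only sketch. FakeDeviceBuffer, CachingArray, and ScopedArray are hypothetical stand-ins for the runtime, not JAX internals; the sketch only contrasts caching the host view on the array object with binding the external reference to the view's own lifetime:

```python
import weakref
import numpy as np

class FakeDeviceBuffer:
    """Hypothetical stand-in for a device buffer that counts external references."""
    def __init__(self, data):
        self.data = data
        self.external_refs = 0

class CachingArray:
    """Mimics the current behavior: the host view is cached as _npy_value,
    so the external reference lives as long as the array object itself."""
    def __init__(self, buf):
        self._buf = buf
        self._npy_value = None
    def asarray(self):
        if self._npy_value is None:
            self._buf.external_refs += 1  # never dropped until this object dies
            self._npy_value = self._buf.data.view()
        return self._npy_value

class ScopedArray:
    """Mimics the proposed behavior: the external reference is dropped
    when the returned NumPy view is garbage collected."""
    def __init__(self, buf):
        self._buf = buf
    def asarray(self):
        buf = self._buf
        buf.external_refs += 1
        view = buf.data.view()
        # Tie the reference's lifetime to the returned view, not to self.
        weakref.finalize(view, lambda: setattr(buf, "external_refs", buf.external_refs - 1))
        return view
```

With CachingArray, deleting the view returned by asarray() leaves external_refs at 1 for the lifetime of the array object; with ScopedArray it drops back to 0 as soon as the view is collected, which is the behavior this issue asks for.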

System info (python version, jaxlib version, accelerator, etc.)

HEAD

yliu120 added the bug label Jan 30, 2025
yliu120 (Contributor, Author) commented Jan 30, 2025

@hawkinsp

yliu120 (Contributor, Author) commented Jan 30, 2025

In addition to this bug:

If there is an API that explicitly controls the lifetime of the np array created from a jax.Array, can you please make the np array writable as well? This is a natural request: if users have control over the lifetime, they can make sure writes to the buffer are done before the next execute call.

heiner commented Jan 30, 2025

Interestingly, this repro works for float16 but shows the same issue for bfloat16:

import jax
import numpy as np
from jax import numpy as jnp


def main():
    d_tensor = jnp.array(0, dtype=jnp.bfloat16)

    d_sharding = d_tensor.sharding
    h_sharding = d_sharding.with_memory_kind("pinned_host")

    h_tensor = jax.device_put(d_tensor, h_sharding)

    def d2h_copy(src, dst):
        dst = jax.device_put(src, h_sharding)
        return src, dst

    d2h_copy_jit = jax.jit(
        d2h_copy,
        in_shardings=(h_sharding, d_sharding),
        out_shardings=(h_sharding, d_sharding),
        donate_argnums=(0, 1),
        keep_unused=True,
    )

    np.array(h_tensor)

    # This line fixes it (or not using bfloat16 above).
    # h_tensor._npy_value = np.array(-1, dtype=h_tensor.dtype)

    h_tensor, d_tensor = d2h_copy_jit(h_tensor, d_tensor)


if __name__ == "__main__":
    main()

heiner commented Jan 30, 2025

I believe if it raises here, the refcounts are off somehow: https://cs.opensource.google/tensorflow/tensorflow/+/master:third_party/xla/xla/python/py_array.cc;l=1530

So the __array__ codepath causes trouble when donating.

jakevdp (Collaborator) commented Jan 30, 2025

can you please make the np array writable as well?

This will not be possible in general. JAX arrays are immutable, and conversion to NumPy is done in a zero-copy manner whenever possible. Thus the resulting NumPy arrays will also be read-only: if you want to write into the buffer, you'll need to make a copy.
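As a NumPy-only illustration of what this read-only zero-copy behavior implies in practice (no JAX involved; a plain view stands in for the conversion result):

```python
import numpy as np

# Stand-in for the buffer backing a jax.Array.
backing = np.arange(4)

# A zero-copy conversion hands back a read-only view of that buffer.
view = backing.view()
view.flags.writeable = False

write_failed = False
try:
    view[0] = 99  # writing through the read-only view raises ValueError
except ValueError:
    write_failed = True

# To write, make an explicit copy; the copy is writable and detached
# from the original buffer.
writable = np.array(view)
writable[0] = 99
```

Note that modifying the copy leaves the original buffer untouched, which is exactly the cost the comment above describes.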

jakevdp (Collaborator) commented Jan 30, 2025

So the __array__ codepath causes trouble when donating.

Thanks for this insight! I think it's definitely possible that there's some fishy behavior around donation and caching of the host-side array.

Assigning @dfm, but Dan feel free to bump to someone else if needed.

heiner commented Jan 30, 2025

This will not be possible in general. JAX arrays are immutable, and conversion to NumPy is done in a zero-copy manner whenever possible. Thus the resulting NumPy arrays will also be read-only: if you want to write into the buffer, you'll need to make a copy.

It is technically possible; I can, for instance, do

__array_interface__ = np.array(jax_array, copy=False).__array_interface__
ptr, _ = __array_interface__["data"]
__array_interface__["data"] = (ptr, False)

and then give __array_interface__ back to NumPy.

It would be better if JAX offered a way that doesn't require that hack. Why would I want to do this? For instance, in order to write into pinned host memory at checkpoint-loading time (for instance, using ts.array(a).write(ts_data)). Without something like that, pinned host memory is far less useful.
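For completeness, the hack above can be finished in pure NumPy with a hypothetical holder object that hands the modified interface back to np.asarray. This is shown only to illustrate the mechanism (and the danger jakevdp warns about below); a plain read-only ndarray stands in for the jax.Array:

```python
import numpy as np

class _InterfaceHolder:
    """Hypothetical object exposing a (modified) __array_interface__."""
    def __init__(self, interface, owner):
        self.__array_interface__ = interface
        self._owner = owner  # keep the memory's owner alive

readonly = np.arange(4)
readonly.flags.writeable = False  # stands in for the zero-copy view of a jax.Array

iface = dict(readonly.__array_interface__)
ptr, _ = iface["data"]
iface["data"] = (ptr, False)  # flip the read-only bit in the interface

writable = np.asarray(_InterfaceHolder(iface, readonly))
writable[0] = 7  # writes into the same memory the "read-only" array sees
```

The two arrays share memory, so the write is visible through the read-only array as well, which is precisely why this pattern is unsafe against a runtime that assumes immutability.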

jakevdp (Collaborator) commented Jan 30, 2025

It is technically possible, I can for instance do

I mean, sure. It's Python, you can do whatever you want. What I'm saying is that we deliberately make the buffer read-only because to do otherwise would lead to bugs, because the JAX runtime assumes the buffers that back its arrays are immutable. I would not suggest using code that works around that in any important context.

yliu120 (Contributor, Author) commented Jan 30, 2025

can you please make the np array writable as well?

This will not be possible in general. JAX arrays are immutable, and conversion to NumPy is done in a zero-copy manner whenever possible. Thus the resulting NumPy arrays will also be read-only: if you want to write into the buffer, you'll need to make a copy.

Actually, making the array mutable is not a bad idea when it avoids large copies, and JAX does have some mutable-array proposals.

yliu120 (Contributor, Author) commented Jan 30, 2025

So the __array__ codepath causes trouble when donating.

Thanks for this insight! I think it's definitely possible that there's some fishy behavior around donation and caching of the host-side array.

Assigning @dfm, but Dan feel free to bump to someone else if needed.

If a buffer is on CPU (OnCpu), it is created via the zero-copy code path, so the resulting ndarray is essentially a shared_ptr-like view of the buffer; in that case there is no need to cache it at all.
I think the fix should take this consideration into account.
