ref_get mutates the sharding of the underlying buffer of a mutable array #26350

Open
ayaka14732 opened this issue Feb 6, 2025 · 0 comments
Labels
bug Something isn't working

Description

import os
os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + ' --xla_force_host_platform_device_count=4'

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax._src.core import mutable_array
from jax._src.state.primitives import ref_get

devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices, axis_names=('i', 'j'))
sharding = NamedSharding(mesh, P('i', 'j'))

a = jnp.zeros_like(mesh.device_ids, dtype=jnp.int32)
a = jax.make_array_from_callback(a.shape, sharding, lambda idx: a[idx])

a_ref = mutable_array(a)
# ref_get(a_ref, ...)  # uncommenting this read changes the sharding printed below
print(a_ref._buf.sharding)

This prints:

NamedSharding(mesh=Mesh('i': 2, 'j': 2), spec=PartitionSpec('i', 'j'), memory_kind=unpinned_host)

But if we uncomment the line ref_get(a_ref, ...), it prints:

SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)

Note that ref_get is purely a read operation that should not have any side effects.
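
One way to turn this expectation into a regression check is to snapshot the sharding before and after the read and compare the two. The sketch below is only illustrative: `read_preserves_sharding` is a hypothetical helper name, not a JAX API, and `fake_ref` is a stand-in that mimics just the `_buf.sharding` attribute from the reproducer so the snippet runs without JAX or multiple devices.

```python
from types import SimpleNamespace


def read_preserves_sharding(get_sharding, read_op):
    """Run read_op() and report whether the sharding is equal before and after.

    Both arguments are zero-argument callables. This helper is hypothetical
    (not part of JAX) and relies only on `==` between the two snapshots.
    """
    before = get_sharding()
    read_op()
    after = get_sharding()
    return before == after, before, after


# Stand-in for a mutable array; it only mimics the `_buf.sharding` attribute.
fake_ref = SimpleNamespace(_buf=SimpleNamespace(sharding="NamedSharding(i, j)"))


def well_behaved_read():
    # Reads the buffer without touching its sharding.
    return fake_ref._buf


def buggy_read():
    # Mimics the reported behaviour: the read replaces the buffer's sharding.
    fake_ref._buf.sharding = "SingleDeviceSharding(0)"
    return fake_ref._buf


ok, _, _ = read_preserves_sharding(lambda: fake_ref._buf.sharding, well_behaved_read)
print(ok)  # True

ok, before, after = read_preserves_sharding(lambda: fake_ref._buf.sharding, buggy_read)
print(ok, before, after)  # False NamedSharding(i, j) SingleDeviceSharding(0)
```

Against the reproducer above, the analogous call would presumably be `read_preserves_sharding(lambda: a_ref._buf.sharding, lambda: ref_get(a_ref, ...))`; given the output reported in this issue, its first return value would be expected to be `False` on the affected versions.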

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.5.0
jaxlib: 0.5.0
numpy:  2.2.2
python: 3.13.0rc3 (main, Oct  2 2024, 17:18:08) [Clang 18.1.8 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='ayx1', release='6.10.11-1rodete2-amd64', version='#1 SMP PREEMPT_DYNAMIC Debian 6.10.11-1rodete2 (2024-10-16)', machine='x86_64')

ayaka14732 added the bug label on Feb 6, 2025