Skip to content

Commit

Permalink
Deprecate arr.device_buffer and arr.device_buffers
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Nov 29, 2023
1 parent 0fce77a commit d6aa9e3
Show file tree
Hide file tree
Showing 10 changed files with 131 additions and 111 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ Remember to align the itemized text with the first line of an item within a list
that cannot be converted to a JAX array is deprecated and now raises a
{obj}`DeprecationWaning`. Currently the functions return False, in the future this
will raise an exception.
* The `device_buffer` and `device_buffers` properties of JAX arrays are deprecated.
Explicit buffers have been replaced by the more flexible array sharding interface,
but the previous outputs can be recovered this way:
* `arr.device_buffer` becomes `arr.addressable_data(0)`
* `arr.device_buffers` becomes `[x.data for x in arr.addressable_shards]`


## jaxlib 0.4.21
Expand Down
8 changes: 4 additions & 4 deletions docs/faq.rst
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,8 @@ or the absl flag ``--jax_platforms`` to "cpu", "gpu", or "tpu"
platforms are available in priority order).

>>> from jax import numpy as jnp
>>> print(jnp.ones(3).device_buffer.device()) # doctest: +SKIP
gpu:0
>>> print(jnp.ones(3).device()) # doctest: +SKIP
cuda:0

Computations involving uncommitted data are performed on the default
device and the results are uncommitted on the default device.
Expand All @@ -355,8 +355,8 @@ with a ``device`` parameter, in which case the data becomes **committed** to the

>>> import jax
>>> from jax import device_put
>>> print(device_put(1, jax.devices()[2]).device_buffer.device()) # doctest: +SKIP
gpu:2
>>> print(device_put(1, jax.devices()[2]).device()) # doctest: +SKIP
cuda:2

Computations involving some committed inputs will happen on the
committed device and the result will be committed on the
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ def _nan_check_posthook(fun, args, kwargs, output):
"""Hook function called by the C++ jit/pmap to perform NaN checking."""
buffers = []
for leaf in tree_leaves(output):
if hasattr(leaf, "device_buffers"):
buffers.extend(leaf.device_buffers)
if hasattr(leaf, "addressable_shards"):
buffers.extend([shard.data for shard in leaf.addressable_shards])

try:
dispatch.check_special(pjit.pjit_p.name, buffers)
Expand Down
9 changes: 9 additions & 0 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import numpy as np
import functools
from typing import Any, Callable, cast, TYPE_CHECKING
import warnings
from collections.abc import Sequence

from jax._src import abstract_arrays
Expand Down Expand Up @@ -466,6 +467,10 @@ def devices(self) -> set[Device]:
# deleted.
@property
def device_buffer(self) -> ArrayImpl:
# Added 2023 Nov 29
warnings.warn(
"arr.device_buffer is deprecated. Use arr.addressable_data(0)",
DeprecationWarning, stacklevel=2)
self._check_if_deleted()
if len(self._arrays) == 1:
return self._arrays[0]
Expand All @@ -476,6 +481,10 @@ def device_buffer(self) -> ArrayImpl:
# deleted.
@property
def device_buffers(self) -> Sequence[ArrayImpl]:
# Added 2023 Nov 29
warnings.warn(
"arr.device_buffers is deprecated. Use [x.data for x in arr.addressable_shards]",
DeprecationWarning, stacklevel=2)
self._check_if_deleted()
return cast(Sequence[ArrayImpl], self._arrays)

Expand Down
8 changes: 4 additions & 4 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,11 +1118,11 @@ class BufferDonationTestCase(JaxTestCase):
def _assertDeleted(self, x, deleted):
if hasattr(x, "_arrays"):
self.assertEqual(x.is_deleted(), deleted)
elif hasattr(x, "device_buffer"):
self.assertEqual(x.device_buffer.is_deleted(), deleted)
elif hasattr(x, "is_deleted"):
self.assertEqual(x.is_deleted(), deleted)
else:
for buffer in x.device_buffers:
self.assertEqual(buffer.is_deleted(), deleted)
for shard in x.addressable_shards:
self.assertEqual(shard.data.is_deleted(), deleted)


@contextmanager
Expand Down
2 changes: 2 additions & 0 deletions tests/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,8 @@ def test_array_jnp_array_copy_multi_device(self):
self.assertNotEqual(a.data.unsafe_buffer_pointer(),
c.data.unsafe_buffer_pointer())

@jtu.ignore_warning(category=DeprecationWarning,
message="arr.device_buffers? is deprecated")
def test_array_device_buffer(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
Expand Down
Loading

0 comments on commit d6aa9e3

Please sign in to comment.