Skip to content

Commit

Permalink
Deprecate arr.device_buffer and arr.device_buffers
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Dec 6, 2023
1 parent 4bdcb11 commit 35b8440
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 16 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ Remember to align the itemized text with the first line of an item within a list

## jax 0.4.22

* Deprecations
* The `device_buffer` and `device_buffers` properties of JAX arrays are deprecated.
Explicit buffers have been replaced by the more flexible array sharding interface,
but the previous outputs can be recovered this way:
* `arr.device_buffer` becomes `arr.addressable_data(0)`
* `arr.device_buffers` becomes `[x.data for x in arr.addressable_shards]`

## jaxlib 0.4.22

## jax 0.4.21 (Dec 4 2023)
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ def _nan_check_posthook(fun, args, kwargs, output):
"""Hook function called by the C++ jit/pmap to perform NaN checking."""
buffers = []
for leaf in tree_leaves(output):
if hasattr(leaf, "device_buffers"):
buffers.extend(leaf.device_buffers)
if hasattr(leaf, "addressable_shards"):
buffers.extend([shard.data for shard in leaf.addressable_shards])

try:
dispatch.check_special(pjit.pjit_p.name, buffers)
Expand Down
8 changes: 8 additions & 0 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,10 @@ def devices(self) -> set[Device]:
# deleted.
@property
def device_buffer(self) -> ArrayImpl:
# Added 2023 Dec 6
warnings.warn(
"arr.device_buffer is deprecated. Use arr.addressable_data(0)",
DeprecationWarning, stacklevel=2)
self._check_if_deleted()
if len(self._arrays) == 1:
return self._arrays[0]
Expand All @@ -484,6 +488,10 @@ def device_buffer(self) -> ArrayImpl:
# deleted.
@property
def device_buffers(self) -> Sequence[ArrayImpl]:
# Added 2023 Dec 6
warnings.warn(
"arr.device_buffers is deprecated. Use [x.data for x in arr.addressable_shards]",
DeprecationWarning, stacklevel=2)
self._check_if_deleted()
return cast(Sequence[ArrayImpl], self._arrays)

Expand Down
16 changes: 5 additions & 11 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,17 +1112,11 @@ def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker,
)

class BufferDonationTestCase(JaxTestCase):
assertDeleted = lambda self, x: self._assertDeleted(x, True)
assertNotDeleted = lambda self, x: self._assertDeleted(x, False)

def _assertDeleted(self, x, deleted):
if hasattr(x, "_arrays"):
self.assertEqual(x.is_deleted(), deleted)
elif hasattr(x, "device_buffer"):
self.assertEqual(x.device_buffer.is_deleted(), deleted)
else:
for buffer in x.device_buffers:
self.assertEqual(buffer.is_deleted(), deleted)
def assertDeleted(self, x):
self.assertTrue(x.is_deleted())

def assertNotDeleted(self, x):
self.assertFalse(x.is_deleted())


@contextmanager
Expand Down
2 changes: 1 addition & 1 deletion tests/notebooks/colab_cpu.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
"import jax\n",
"key = jax.random.PRNGKey(1701)\n",
"arr = jax.random.normal(key, (1000,))\n",
"device = arr.device_buffer.device()\n",
"device = arr.device()\n",
"print(f\"JAX device type: {device}\")\n",
"assert device.platform == \"cpu\", f\"unexpected JAX device type: {device.platform}\""
],
Expand Down
4 changes: 2 additions & 2 deletions tests/xmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,7 +929,7 @@ def testAllGather(self, mesh):
in_axes=['i', None], out_axes=[None],
axis_resources={'i': 'x'})
h = pjit(f, in_shardings=P('x', None), out_shardings=P(None))(x)
assert (h.addressable_data(0) == x.reshape(8)).all()
self.assertArraysEqual(h.addressable_data(0), x.reshape(8))

@parameterized.named_parameters(
{'testcase_name': name, 'mesh': mesh}
Expand All @@ -949,7 +949,7 @@ def testReduceScatter(self, mesh):
out_shardings=P('x', None),
)(x)

assert (h.addressable_data(0).reshape(4) == x[0, :]*2).all()
self.assertArraysEqual(h.addressable_data(0).reshape(4), x[0, :] * 2)

@jtu.with_mesh([('x', 2)])
def testBareXmapCollective(self):
Expand Down

0 comments on commit 35b8440

Please sign in to comment.