diff --git a/CHANGELOG.md b/CHANGELOG.md index a50a66aeb906..a1c7df9df882 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,13 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.22 +* Deprecations + * The `device_buffer` and `device_buffers` properties of JAX arrays are deprecated. + Explicit buffers have been replaced by the more flexible array sharding interface, + but the previous outputs can be recovered this way: + * `arr.device_buffer` becomes `arr.addressable_data(0)` + * `arr.device_buffers` becomes `[x.data for x in arr.addressable_shards]` + ## jaxlib 0.4.22 ## jax 0.4.21 (Dec 4 2023) diff --git a/jax/_src/api.py b/jax/_src/api.py index 3ec38795ab63..616036d2188a 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -102,8 +102,8 @@ def _nan_check_posthook(fun, args, kwargs, output): """Hook function called by the C++ jit/pmap to perform NaN checking.""" buffers = [] for leaf in tree_leaves(output): - if hasattr(leaf, "device_buffers"): - buffers.extend(leaf.device_buffers) + if hasattr(leaf, "addressable_shards"): + buffers.extend([shard.data for shard in leaf.addressable_shards]) try: dispatch.check_special(pjit.pjit_p.name, buffers) diff --git a/jax/_src/array.py b/jax/_src/array.py index 2cf89e5a6560..95a69c5f3c15 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -474,6 +474,10 @@ def devices(self) -> set[Device]: # deleted. @property def device_buffer(self) -> ArrayImpl: + # Added 2023 Dec 6 + warnings.warn( + "arr.device_buffer is deprecated. Use arr.addressable_data(0)", + DeprecationWarning, stacklevel=2) self._check_if_deleted() if len(self._arrays) == 1: return self._arrays[0] @@ -484,6 +488,10 @@ def device_buffer(self) -> ArrayImpl: # deleted. @property def device_buffers(self) -> Sequence[ArrayImpl]: + # Added 2023 Dec 6 + warnings.warn( + "arr.device_buffers is deprecated. Use [x.data for x in arr.addressable_shards]", + DeprecationWarning, stacklevel=2) self._check_if_deleted() return cast(Sequence[ArrayImpl], self._arrays) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 8bcf49f56ca6..821d61e575e3 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1112,17 +1112,11 @@ def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker, ) class BufferDonationTestCase(JaxTestCase): - assertDeleted = lambda self, x: self._assertDeleted(x, True) - assertNotDeleted = lambda self, x: self._assertDeleted(x, False) - - def _assertDeleted(self, x, deleted): - if hasattr(x, "_arrays"): - self.assertEqual(x.is_deleted(), deleted) - elif hasattr(x, "device_buffer"): - self.assertEqual(x.device_buffer.is_deleted(), deleted) - else: - for buffer in x.device_buffers: - self.assertEqual(buffer.is_deleted(), deleted) + def assertDeleted(self, x): + self.assertTrue(x.is_deleted()) + + def assertNotDeleted(self, x): + self.assertFalse(x.is_deleted()) @contextmanager diff --git a/tests/notebooks/colab_cpu.ipynb b/tests/notebooks/colab_cpu.ipynb index dbf6085be0de..1540b3d20892 100644 --- a/tests/notebooks/colab_cpu.ipynb +++ b/tests/notebooks/colab_cpu.ipynb @@ -93,7 +93,7 @@ "import jax\n", "key = jax.random.PRNGKey(1701)\n", "arr = jax.random.normal(key, (1000,))\n", - "device = arr.device_buffer.device()\n", + "device = arr.device()\n", "print(f\"JAX device type: {device}\")\n", "assert device.platform == \"cpu\", f\"unexpected JAX device type: {device.platform}\"" ], diff --git a/tests/xmap_test.py b/tests/xmap_test.py index d91d05a230a2..d93f2c5a59f9 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -929,7 +929,7 @@ def testAllGather(self, mesh): in_axes=['i', None], out_axes=[None], axis_resources={'i': 'x'}) h = pjit(f, in_shardings=P('x', None), out_shardings=P(None))(x) - assert (h.addressable_data(0) == x.reshape(8)).all() + self.assertArraysEqual(h.addressable_data(0), x.reshape(8)) @parameterized.named_parameters( {'testcase_name': name, 'mesh': mesh} @@ -949,7 +949,7 @@ def testReduceScatter(self, mesh): out_shardings=P('x', None), )(x) - assert (h.addressable_data(0).reshape(4) == x[0, :]*2).all() + self.assertArraysEqual(h.addressable_data(0).reshape(4), x[0, :] * 2) @jtu.with_mesh([('x', 2)]) def testBareXmapCollective(self):