From d6aa9e30bdee904057d59577f534483e8c0076fc Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Wed, 29 Nov 2023 10:22:10 -0800
Subject: [PATCH] Deprecate arr.device_buffer and arr.device_buffers

---
 CHANGELOG.md                    |   5 +
 docs/faq.rst                    |   8 +-
 jax/_src/api.py                 |   4 +-
 jax/_src/array.py               |   9 ++
 jax/_src/test_util.py           |   8 +-
 tests/array_test.py             |   2 +
 tests/notebooks/colab_cpu.ipynb | 198 ++++++++++++++++----------------
 tests/pjit_test.py              |   2 +
 tests/shard_map_test.py         |   2 +
 tests/xmap_test.py              |   4 +-
 10 files changed, 131 insertions(+), 111 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0a1cdfd4c532..b32215415d54 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -27,6 +27,11 @@ Remember to align the itemized text with the first line of an item within a list
     that cannot be converted to a JAX array is deprecated and now raises a
     {obj}`DeprecationWarning`. Currently the functions return False; in the
     future this will raise an exception.
+  * The `device_buffer` and `device_buffers` properties of JAX arrays are deprecated.
+    Explicit buffers have been replaced by the more flexible array sharding interface,
+    but the previous outputs can be recovered this way:
+    * `arr.device_buffer` becomes `arr.addressable_data(0)`
+    * `arr.device_buffers` becomes `[x.data for x in arr.addressable_shards]`
 
 ## jaxlib 0.4.21
 
diff --git a/docs/faq.rst b/docs/faq.rst
index 441af2436852..a2db177439ab 100644
--- a/docs/faq.rst
+++ b/docs/faq.rst
@@ -344,8 +344,8 @@ or the absl flag ``--jax_platforms`` to "cpu", "gpu", or "tpu"
 platforms are available in priority order).
 
 >>> from jax import numpy as jnp
->>> print(jnp.ones(3).device_buffer.device())  # doctest: +SKIP
-gpu:0
+>>> print(jnp.ones(3).device())  # doctest: +SKIP
+cuda:0
 
 Computations involving uncommitted data are performed on the default
 device and the results are uncommitted on the default device.
@@ -355,8 +355,8 @@ with a ``device`` parameter, in which case the data becomes **committed** to the
 
 >>> import jax
 >>> from jax import device_put
->>> print(device_put(1, jax.devices()[2]).device_buffer.device())  # doctest: +SKIP
-gpu:2
+>>> print(device_put(1, jax.devices()[2]).device())  # doctest: +SKIP
+cuda:2
 
 Computations involving some committed inputs will happen on the
 committed device and the result will be committed on the
diff --git a/jax/_src/api.py b/jax/_src/api.py
index cb2761955dff..bcabf0b36d21 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -102,8 +102,8 @@ def _nan_check_posthook(fun, args, kwargs, output):
   """Hook function called by the C++ jit/pmap to perform NaN checking."""
   buffers = []
   for leaf in tree_leaves(output):
-    if hasattr(leaf, "device_buffers"):
-      buffers.extend(leaf.device_buffers)
+    if hasattr(leaf, "addressable_shards"):
+      buffers.extend([shard.data for shard in leaf.addressable_shards])
 
   try:
     dispatch.check_special(pjit.pjit_p.name, buffers)
diff --git a/jax/_src/array.py b/jax/_src/array.py
index f0c77c4f7104..bd12a7f9fb45 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -21,6 +21,7 @@
 import numpy as np
 import functools
 from typing import Any, Callable, cast, TYPE_CHECKING
+import warnings
 from collections.abc import Sequence
 
 from jax._src import abstract_arrays
@@ -466,6 +467,10 @@ def devices(self) -> set[Device]:
   # deleted.
   @property
   def device_buffer(self) -> ArrayImpl:
+    # Added 2023 Nov 29
+    warnings.warn(
+        "arr.device_buffer is deprecated. Use arr.addressable_data(0)",
+        DeprecationWarning, stacklevel=2)
     self._check_if_deleted()
     if len(self._arrays) == 1:
       return self._arrays[0]
@@ -476,6 +481,10 @@ def device_buffer(self) -> ArrayImpl:
   # deleted.
   @property
   def device_buffers(self) -> Sequence[ArrayImpl]:
+    # Added 2023 Nov 29
+    warnings.warn(
+        "arr.device_buffers is deprecated. Use [x.data for x in arr.addressable_shards]",
+        DeprecationWarning, stacklevel=2)
     self._check_if_deleted()
     return cast(Sequence[ArrayImpl], self._arrays)
 
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index 860b8c781e25..ab67ddd7b726 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -1118,11 +1118,11 @@ class BufferDonationTestCase(JaxTestCase):
   def _assertDeleted(self, x, deleted):
     if hasattr(x, "_arrays"):
       self.assertEqual(x.is_deleted(), deleted)
-    elif hasattr(x, "device_buffer"):
-      self.assertEqual(x.device_buffer.is_deleted(), deleted)
+    elif hasattr(x, "is_deleted"):
+      self.assertEqual(x.is_deleted(), deleted)
     else:
-      for buffer in x.device_buffers:
-        self.assertEqual(buffer.is_deleted(), deleted)
+      for shard in x.addressable_shards:
+        self.assertEqual(shard.data.is_deleted(), deleted)
 
 
 @contextmanager
diff --git a/tests/array_test.py b/tests/array_test.py
index dee62266bd79..05dfb392be11 100644
--- a/tests/array_test.py
+++ b/tests/array_test.py
@@ -583,6 +583,8 @@ def test_array_jnp_array_copy_multi_device(self):
       self.assertNotEqual(a.data.unsafe_buffer_pointer(),
                           c.data.unsafe_buffer_pointer())
 
+  @jtu.ignore_warning(category=DeprecationWarning,
+                      message="arr.device_buffers? is deprecated")
   def test_array_device_buffer(self):
     global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
     input_shape = (8, 2)
diff --git a/tests/notebooks/colab_cpu.ipynb b/tests/notebooks/colab_cpu.ipynb
index 089d5667f326..7d484b964da4 100644
--- a/tests/notebooks/colab_cpu.ipynb
+++ b/tests/notebooks/colab_cpu.ipynb
@@ -1,23 +1,10 @@
 {
-  "nbformat": 4,
-  "nbformat_minor": 0,
-  "metadata": {
-    "colab": {
-      "name": "JAX Colab CPU Test",
-      "provenance": [],
-      "collapsed_sections": []
-    },
-    "kernelspec": {
-      "name": "python3",
-      "display_name": "Python 3"
-    }
-  },
   "cells": [
     {
       "cell_type": "markdown",
       "metadata": {
-        "id": "view-in-github",
-        "colab_type": "text"
+        "colab_type": "text",
+        "id": "view-in-github"
       },
       "source": [
         "\"Open"
       ]
     },
     {
       "cell_type": "markdown",
       "metadata": {
-        "id": "WkadOyTDCAWD",
-        "colab_type": "text"
+        "colab_type": "text",
+        "id": "WkadOyTDCAWD"
       },
       "source": [
         "# JAX Colab CPU Test\n",
       ]
     },
     {
       "cell_type": "code",
+      "execution_count": 6,
       "metadata": {
-        "id": "_tKNrbqqBHwu",
-        "colab_type": "code",
         "colab": {
           "base_uri": "https://localhost:8080/",
           "height": 68
         },
+        "colab_type": "code",
+        "id": "_tKNrbqqBHwu",
         "outputId": "071fb360-ddf5-41ae-d772-acc08ec71d9b"
       },
-      "source": [
-        "import jax\n",
-        "import jaxlib\n",
-        "\n",
-        "!cat /var/colab/hostname\n",
-        "print(jax.__version__)\n",
-        "print(jaxlib.__version__)"
-      ],
-      "execution_count": 6,
       "outputs": [
         {
+          "name": "stdout",
           "output_type": "stream",
           "text": [
             "m-s-1p12yf76kgzz\n",
             "0.1.64\n",
             "0.1.45\n"
-          ],
-          "name": "stdout"
+          ]
         }
+      ],
+      "source": [
+        "import jax\n",
+        "import jaxlib\n",
+        "\n",
+        "!cat /var/colab/hostname\n",
+        "print(jax.__version__)\n",
+        "print(jaxlib.__version__)"
       ]
     },
     {
       "cell_type": "markdown",
       "metadata": {
-        "id": "oqEG21rADO1F",
-        "colab_type": "text"
+        "colab_type": "text",
+        "id": "oqEG21rADO1F"
       },
       "source": [
         "## Confirm Device"
       ]
     },
     {
       "cell_type": "code",
+      "execution_count": 2,
"metadata": { - "colab_type": "code", - "id": "8BwzMYhKGQj6", - "outputId": "f79a44e3-4303-494c-9288-a4e582bb34cb", "colab": { "base_uri": "https://localhost:8080/", "height": 68 - } + }, + "colab_type": "code", + "id": "8BwzMYhKGQj6", + "outputId": "f79a44e3-4303-494c-9288-a4e582bb34cb" }, - "source": [ - "from jaxlib import xla_extension\n", - "import jax\n", - "key = jax.random.PRNGKey(1701)\n", - "arr = jax.random.normal(key, (1000,))\n", - "device = arr.device_buffer.device()\n", - "print(f\"JAX device type: {device}\")\n", - "assert device.platform == \"cpu\", f\"unexpected JAX device type: {device.platform}\"" - ], - "execution_count": 2, "outputs": [ { + "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:123: UserWarning: No GPU/TPU found, falling back to CPU.\n", " warnings.warn('No GPU/TPU found, falling back to CPU.')\n" - ], - "name": "stderr" + ] }, { + "name": "stdout", "output_type": "stream", "text": [ "JAX device type: cpu:0\n" - ], - "name": "stdout" + ] } + ], + "source": [ + "from jaxlib import xla_extension\n", + "import jax\n", + "key = jax.random.PRNGKey(1701)\n", + "arr = jax.random.normal(key, (1000,))\n", + "device = arr.device()\n", + "print(f\"JAX device type: {device}\")\n", + "assert device.platform == \"cpu\", f\"unexpected JAX device type: {device.platform}\"" ] }, { "cell_type": "markdown", "metadata": { - "id": "z0FUY9yUC4k1", - "colab_type": "text" + "colab_type": "text", + "id": "z0FUY9yUC4k1" }, "source": [ "## Matrix Multiplication" @@ -128,15 +115,25 @@ }, { "cell_type": "code", + "execution_count": 3, "metadata": { - "colab_type": "code", - "id": "eXn8GUl6CG5N", - "outputId": "307aa669-76f1-4117-b62a-7acb2aee2c16", "colab": { "base_uri": "https://localhost:8080/", "height": 34 - } + }, + "colab_type": "code", + "id": "eXn8GUl6CG5N", + "outputId": "307aa669-76f1-4117-b62a-7acb2aee2c16" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.0216691\n" + ] + } + ], "source": [ "import jax\n", "import numpy as np\n", @@ -146,23 +143,13 @@ "x = jax.random.normal(key, (3000, 3000))\n", "result = jax.numpy.dot(x, x.T).mean()\n", "print(result)" - ], - "execution_count": 3, - "outputs": [ - { - "output_type": "stream", - "text": [ - "1.0216691\n" - ], - "name": "stdout" - } ] }, { "cell_type": "markdown", "metadata": { - "id": "0zTA2Q19DW4G", - "colab_type": "text" + "colab_type": "text", + "id": "0zTA2Q19DW4G" }, "source": [ "## Linear Algebra" @@ -170,15 +157,26 @@ }, { "cell_type": "code", + "execution_count": 4, "metadata": { - "id": "uW9j84_UDYof", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, + "colab_type": "code", + "id": "uW9j84_UDYof", "outputId": "3dd5d7c0-9d47-4be1-c6f7-068b432b69f7" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[6.9178133 5.9580317 5.581113 4.506963 4.111582 3.973543 3.3307292\n", + " 2.8664916 1.8229378 1.5478933]\n" + ] + } + ], "source": [ "import jax.numpy as jnp\n", "import jax.random as rand\n", @@ -192,24 +190,13 @@ "assert u.shape == (N, N)\n", "assert vt.shape == (M, M)\n", "print(s)" - ], - "execution_count": 4, - "outputs": [ - { - "output_type": "stream", - "text": [ - "[6.9178133 5.9580317 5.581113 4.506963 4.111582 3.973543 3.3307292\n", - " 2.8664916 1.8229378 1.5478933]\n" - ], - "name": "stdout" - } ] }, { "cell_type": "markdown", "metadata": { - "id": "jCyKUn4-DCXn", - "colab_type": "text" + "colab_type": "text", + "id": "jCyKUn4-DCXn" 
}, "source": [ "## XLA Compilation" @@ -217,34 +204,47 @@ }, { "cell_type": "code", + "execution_count": 5, "metadata": { - "colab_type": "code", - "id": "2GOn_HhDPuEn", - "outputId": "41a40dd9-3680-458d-cedd-81ebcc2ab06f", "colab": { "base_uri": "https://localhost:8080/", "height": 51 - } + }, + "colab_type": "code", + "id": "2GOn_HhDPuEn", + "outputId": "41a40dd9-3680-458d-cedd-81ebcc2ab06f" }, - "source": [ - "@jax.jit\n", - "def selu(x, alpha=1.67, lmbda=1.05):\n", - " return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)\n", - "x = jax.random.normal(key, (5000,))\n", - "result = selu(x).block_until_ready()\n", - "print(result)" - ], - "execution_count": 5, "outputs": [ { + "name": "stdout", "output_type": "stream", "text": [ "[ 0.34676832 -0.7532232 1.7060695 ... 2.1208048 -0.42621925\n", " 0.13093236]\n" - ], - "name": "stdout" + ] } + ], + "source": [ + "@jax.jit\n", + "def selu(x, alpha=1.67, lmbda=1.05):\n", + " return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)\n", + "x = jax.random.normal(key, (5000,))\n", + "result = selu(x).block_until_ready()\n", + "print(result)" ] } - ] + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "JAX Colab CPU Test", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 602eb6a5f784..08be6d3ec3a8 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -125,6 +125,8 @@ def check_1d_2d_mesh(f, set_mesh): # TODO(skye): make the buffer donation utils part of JaxTestCase +@jtu.ignore_warning(category=DeprecationWarning, + message="arr.device_buffers? is deprecated") @jtu.pytest_mark_if_available('multiaccelerator') class PJitTest(jtu.BufferDonationTestCase): diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 9e9f10469422..166ca55bd6f9 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -90,6 +90,8 @@ def tearDownModule(): xla_bridge.get_backend.cache_clear() +@jtu.ignore_warning(category=DeprecationWarning, + message="arr.device_buffers? is deprecated") class ShardMapTest(jtu.JaxTestCase): def test_identity(self): diff --git a/tests/xmap_test.py b/tests/xmap_test.py index 851ed2ef6849..d91d05a230a2 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -929,7 +929,7 @@ def testAllGather(self, mesh): in_axes=['i', None], out_axes=[None], axis_resources={'i': 'x'}) h = pjit(f, in_shardings=P('x', None), out_shardings=P(None))(x) - assert (h.device_buffers[0] == x.reshape(8)).all() + assert (h.addressable_data(0) == x.reshape(8)).all() @parameterized.named_parameters( {'testcase_name': name, 'mesh': mesh} @@ -949,7 +949,7 @@ def testReduceScatter(self, mesh): out_shardings=P('x', None), )(x) - assert (h.device_buffers[0].reshape(4) == x[0, :]*2).all() + assert (h.addressable_data(0).reshape(4) == x[0, :]*2).all() @jtu.with_mesh([('x', 2)]) def testBareXmapCollective(self):
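
For callers migrating off the deprecated properties, the replacement is mechanical.
Below is a minimal sketch of the new pattern; the array `arr` is illustrative and
not part of this patch, while `addressable_data` and `addressable_shards` are the
public `jax.Array` API that this change points users toward:

    import jax.numpy as jnp

    arr = jnp.arange(8.0)  # any jax.Array, single- or multi-device

    # Previously: buf = arr.device_buffer  (now emits DeprecationWarning)
    buf = arr.addressable_data(0)  # data of the first addressable shard

    # Previously: bufs = arr.device_buffers  (now emits DeprecationWarning)
    bufs = [shard.data for shard in arr.addressable_shards]  # one entry per shard

On a single-device array both forms reduce to the whole array; on a sharded array
they expose each device-local piece, which is why the sharding interface subsumes
the old buffer properties.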