From e987062f8bfc7df1eb53494617bb73b2ed26c14e Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Thu, 12 May 2022 20:49:46 -0700 Subject: [PATCH 1/7] Use `bytearray`s in `ensure_memoryview` --- distributed/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/utils.py b/distributed/utils.py index 4e7fe43dc43..3f4f66f7bf3 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1021,13 +1021,13 @@ def ensure_memoryview(obj): if not mv.nbytes: # Drop `obj` reference to permit freeing underlying data - return memoryview(b"") + return memoryview(bytearray()) elif mv.contiguous: # Perform zero-copy reshape & cast return mv.cast("B") else: # Copy to contiguous form of expected shape & type - return memoryview(mv.tobytes()) + return memoryview(bytearray(mv)) def open_port(host=""): From 969e8b48dda07147f58d10d03c1a188ba90af6ad Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Thu, 12 May 2022 20:49:47 -0700 Subject: [PATCH 2/7] First check case that can't `cast` --- distributed/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/distributed/utils.py b/distributed/utils.py index 3f4f66f7bf3..d10625cea84 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1022,12 +1022,12 @@ def ensure_memoryview(obj): if not mv.nbytes: # Drop `obj` reference to permit freeing underlying data return memoryview(bytearray()) - elif mv.contiguous: - # Perform zero-copy reshape & cast - return mv.cast("B") - else: + elif not mv.contiguous: # Copy to contiguous form of expected shape & type return memoryview(bytearray(mv)) + else: + # Perform zero-copy reshape & cast + return mv.cast("B") def open_port(host=""): From ffb885ce45afee25ba012aacb4b0f3161ebce26a Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Thu, 12 May 2022 20:49:48 -0700 Subject: [PATCH 3/7] Fastpath `ensure_memoryview` when no change needed --- distributed/tests/test_utils.py | 2 ++ distributed/utils.py | 5 ++++- 2 files changed, 6 
insertions(+), 1 deletion(-) diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index 7e076c0ecb6..66b53f24fd3 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -260,6 +260,8 @@ def test_ensure_memoryview(): result = ensure_memoryview(d) assert isinstance(result, memoryview) assert result == memoryview(b"1") + if isinstance(d, memoryview): + assert id(d) == id(result) def test_ensure_memoryview_ndarray(): diff --git a/distributed/utils.py b/distributed/utils.py index d10625cea84..8b3c933b829 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1025,9 +1025,12 @@ def ensure_memoryview(obj): elif not mv.contiguous: # Copy to contiguous form of expected shape & type return memoryview(bytearray(mv)) - else: + elif mv.ndim != 1 or mv.format != "B": # Perform zero-copy reshape & cast return mv.cast("B") + else: + # Return `memoryview` as it already meets requirements + return mv def open_port(host=""): From 24e5e5f21246422619da03357af0e8972ddc221e Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Thu, 12 May 2022 20:49:48 -0700 Subject: [PATCH 4/7] Use `PickleBuffer.raw(...)` in `ensure_memoryview` As `memoryview.cast(...)` doesn't know how to handle F-order, go through `PickleBuffer.raw(...)`, which can handle this case. Also pick out the underlying `.obj` of the `memoryview` to ensure the `memoryview` that `PickleBuffer.raw(...)` creates points back at the original object and not the `memoryview` we created. 
--- distributed/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/distributed/utils.py b/distributed/utils.py index 8b3c933b829..e24e45b0c86 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -27,6 +27,7 @@ from functools import wraps from hashlib import md5 from importlib.util import cache_from_source +from pickle import PickleBuffer from time import sleep from types import ModuleType from typing import TYPE_CHECKING @@ -1027,7 +1028,10 @@ def ensure_memoryview(obj): return memoryview(bytearray(mv)) elif mv.ndim != 1 or mv.format != "B": # Perform zero-copy reshape & cast - return mv.cast("B") + # Use `PickleBuffer.raw()` as `memoryview.cast()` fails with F-order + # Pass `mv.obj` so the created `memoryview` has that as its `obj` + # xref: https://github.com/python/cpython/issues/91484 + return PickleBuffer(mv.obj).raw() else: # Return `memoryview` as it already meets requirements return mv From 4d856791da3a146d7d5540e8cb98ee741aa015a7 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Thu, 12 May 2022 20:49:49 -0700 Subject: [PATCH 5/7] Cover more cases in `ensure_memoryview` tests Handle different types, shapes, striding, etc. 
--- distributed/tests/test_utils.py | 47 ++++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index 66b53f24fd3..59860538ef4 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -248,20 +248,41 @@ def test_seek_delimiter_endline(): assert f.tell() == 7 -def test_ensure_memoryview_empty(): - result = ensure_memoryview(b"") +@pytest.mark.parametrize( + "data", + [ + b"", + bytearray(), + b"1", + bytearray(b"1"), + memoryview(b"1"), + memoryview(bytearray(b"1")), + array("B", b"1"), + array("I", range(5)), + memoryview(b"123456")[::2], + memoryview(b"123456").cast("B", (2, 3)), + memoryview(b"0123456789").cast("B", (5, 2))[::2], + ], +) +def test_ensure_memoryview(data): + data_mv = memoryview(data) + result = ensure_memoryview(data) assert isinstance(result, memoryview) - assert result == memoryview(b"") - - -def test_ensure_memoryview(): - data = [b"1", memoryview(b"1"), bytearray(b"1"), array("B", b"1")] - for d in data: - result = ensure_memoryview(d) - assert isinstance(result, memoryview) - assert result == memoryview(b"1") - if isinstance(d, memoryview): - assert id(d) == id(result) + assert result.contiguous + assert result.ndim == 1 + assert result.format == "B" + assert result == bytes(data_mv) + if data_mv.nbytes and data_mv.contiguous: + assert id(result.obj) == id(data_mv.obj) + assert result.readonly == data_mv.readonly + if isinstance(data, memoryview): + if data.ndim == 1 and data.format == "B": + assert id(result) == id(data) + else: + assert id(data) != id(result) + else: + assert id(result.obj) != id(data_mv.obj) + assert not result.readonly def test_ensure_memoryview_ndarray(): From 90c1c2692b3230d4bfd1e73a5fcaaa9a685cc834 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Thu, 12 May 2022 20:49:49 -0700 Subject: [PATCH 6/7] Improve `ensure_memoryview` NumPy testing Try using `ensure_memoryview` with a 
variety of different NumPy `ndarray` types to improve test coverage. --- distributed/tests/test_utils.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index 59860538ef4..961bd6a9164 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -285,9 +285,22 @@ def test_ensure_memoryview(data): assert not result.readonly -def test_ensure_memoryview_ndarray(): +@pytest.mark.parametrize( + "dt, nitems, shape, strides", + [ + ("i8", 12, (12,), (8,)), + ("i8", 12, (3, 4), (32, 8)), + ("i8", 12, (4, 3), (8, 32)), + ("i8", 12, (3, 2), (32, 16)), + ("i8", 12, (2, 3), (16, 32)), + ], +) +def test_ensure_memoryview_ndarray(dt, nitems, shape, strides): np = pytest.importorskip("numpy") - result = ensure_memoryview(np.arange(12).reshape(3, 4)[:, ::2].T) + data = np.ndarray( + shape, dtype=dt, buffer=np.arange(nitems, dtype=dt), strides=strides + ) + result = ensure_memoryview(data) assert isinstance(result, memoryview) assert result.ndim == 1 assert result.format == "B" From 96e23187b6c6aac38cd88ebafab8785ca730095d Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Thu, 12 May 2022 20:49:50 -0700 Subject: [PATCH 7/7] Use `ensure_memoryview` in a few more places --- distributed/comm/asyncio_tcp.py | 15 +++++++-------- distributed/comm/tcp.py | 4 ++-- distributed/protocol/serialize.py | 4 ++-- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/distributed/comm/asyncio_tcp.py b/distributed/comm/asyncio_tcp.py index 1dbac5fef6f..526533bda9f 100644 --- a/distributed/comm/asyncio_tcp.py +++ b/distributed/comm/asyncio_tcp.py @@ -26,7 +26,7 @@ host_array, to_frames, ) -from distributed.utils import ensure_ip, get_ip, get_ipv6 +from distributed.utils import ensure_ip, ensure_memoryview, get_ip, get_ipv6 logger = logging.getLogger(__name__) @@ -379,7 +379,9 @@ async def write(self, frames: list[bytes]) -> int: await drain_waiter # Ensure all 
memoryviews are in single-byte format - frames = [f.cast("B") if isinstance(f, memoryview) else f for f in frames] + frames = [ + ensure_memoryview(f) if isinstance(f, memoryview) else f for f in frames + ] nframes = len(frames) frames_nbytes = [len(f) for f in frames] @@ -847,12 +849,9 @@ def _buffer_clear(self): def _buffer_append(self, data: bytes) -> None: """Append new data to the send buffer""" - if not isinstance(data, memoryview): - data = memoryview(data) - if data.format != "B": - data = data.cast("B") - self._size += len(data) - self._buffers.append(data) + mv = ensure_memoryview(data) + self._size += len(mv) + self._buffers.append(mv) def _buffer_peek(self) -> list[memoryview]: """Get one or more buffers to write to the socket""" diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index e5cd004cca7..d88582bc67c 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -46,7 +46,7 @@ ) from distributed.protocol.utils import pack_frames_prelude, unpack_frames from distributed.system import MEMORY_LIMIT -from distributed.utils import ensure_ip, get_ip, get_ipv6, nbytes +from distributed.utils import ensure_ip, ensure_memoryview, get_ip, get_ipv6, nbytes logger = logging.getLogger(__name__) @@ -305,7 +305,7 @@ async def write(self, msg, serializers=None, on_error="message"): if isinstance(each_frame, memoryview): # Make sure that `len(data) == data.nbytes` # See - each_frame = memoryview(each_frame).cast("B") + each_frame = ensure_memoryview(each_frame) stream._write_buffer.append(each_frame) stream._total_write_index += each_frame_nbytes diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 2a7ebdbf23d..77544a3f27f 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -88,7 +88,7 @@ def pickle_loads(header, frames): writeable = len(buffers) * (None,) new = [] - memoryviews = map(memoryview, buffers) + memoryviews = map(ensure_memoryview, buffers) for w, mv in 
zip(writeable, memoryviews): if w == mv.readonly: if w: @@ -785,7 +785,7 @@ def _serialize_memoryview(obj): @dask_deserialize.register(memoryview) def _deserialize_memoryview(header, frames): if len(frames) == 1: - out = memoryview(frames[0]).cast("B") + out = ensure_memoryview(frames[0]) else: out = memoryview(b"".join(frames)) out = out.cast(header["format"], header["shape"])