diff --git a/distributed/comm/asyncio_tcp.py b/distributed/comm/asyncio_tcp.py index 1dbac5fef6f..526533bda9f 100644 --- a/distributed/comm/asyncio_tcp.py +++ b/distributed/comm/asyncio_tcp.py @@ -26,7 +26,7 @@ host_array, to_frames, ) -from distributed.utils import ensure_ip, get_ip, get_ipv6 +from distributed.utils import ensure_ip, ensure_memoryview, get_ip, get_ipv6 logger = logging.getLogger(__name__) @@ -379,7 +379,9 @@ async def write(self, frames: list[bytes]) -> int: await drain_waiter # Ensure all memoryviews are in single-byte format - frames = [f.cast("B") if isinstance(f, memoryview) else f for f in frames] + frames = [ + ensure_memoryview(f) if isinstance(f, memoryview) else f for f in frames + ] nframes = len(frames) frames_nbytes = [len(f) for f in frames] @@ -847,12 +849,9 @@ def _buffer_clear(self): def _buffer_append(self, data: bytes) -> None: """Append new data to the send buffer""" - if not isinstance(data, memoryview): - data = memoryview(data) - if data.format != "B": - data = data.cast("B") - self._size += len(data) - self._buffers.append(data) + mv = ensure_memoryview(data) + self._size += len(mv) + self._buffers.append(mv) def _buffer_peek(self) -> list[memoryview]: """Get one or more buffers to write to the socket""" diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index e5cd004cca7..d88582bc67c 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -46,7 +46,7 @@ ) from distributed.protocol.utils import pack_frames_prelude, unpack_frames from distributed.system import MEMORY_LIMIT -from distributed.utils import ensure_ip, get_ip, get_ipv6, nbytes +from distributed.utils import ensure_ip, ensure_memoryview, get_ip, get_ipv6, nbytes logger = logging.getLogger(__name__) @@ -305,7 +305,7 @@ async def write(self, msg, serializers=None, on_error="message"): if isinstance(each_frame, memoryview): # Make sure that `len(data) == data.nbytes` # See - each_frame = memoryview(each_frame).cast("B") + each_frame = ensure_memoryview(each_frame) stream._write_buffer.append(each_frame) stream._total_write_index += each_frame_nbytes diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 2a7ebdbf23d..77544a3f27f 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -88,7 +88,7 @@ def pickle_loads(header, frames): writeable = len(buffers) * (None,) new = [] - memoryviews = map(memoryview, buffers) + memoryviews = map(ensure_memoryview, buffers) for w, mv in zip(writeable, memoryviews): if w == mv.readonly: if w: @@ -785,7 +785,7 @@ def _serialize_memoryview(obj): @dask_deserialize.register(memoryview) def _deserialize_memoryview(header, frames): if len(frames) == 1: - out = memoryview(frames[0]).cast("B") + out = ensure_memoryview(frames[0]) else: out = memoryview(b"".join(frames)) out = out.cast(header["format"], header["shape"]) diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index 7e076c0ecb6..961bd6a9164 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -248,23 +248,59 @@ def test_seek_delimiter_endline(): assert f.tell() == 7 -def test_ensure_memoryview_empty(): - result = ensure_memoryview(b"") +@pytest.mark.parametrize( + "data", + [ + b"", + bytearray(), + b"1", + bytearray(b"1"), + memoryview(b"1"), + memoryview(bytearray(b"1")), + array("B", b"1"), + array("I", range(5)), + memoryview(b"123456")[::2], + memoryview(b"123456").cast("B", (2, 3)), + memoryview(b"0123456789").cast("B", (5, 2))[::2], + ], +) +def test_ensure_memoryview(data): + data_mv = memoryview(data) + result = ensure_memoryview(data) assert isinstance(result, memoryview) - assert result == memoryview(b"") - - -def test_ensure_memoryview(): - data = [b"1", memoryview(b"1"), bytearray(b"1"), array("B", b"1")] - for d in data: - result = ensure_memoryview(d) - assert isinstance(result, memoryview) - assert result == memoryview(b"1") - - -def test_ensure_memoryview_ndarray(): + assert result.contiguous + assert result.ndim == 1 + assert result.format == "B" + assert result == bytes(data_mv) + if data_mv.nbytes and data_mv.contiguous: + assert id(result.obj) == id(data_mv.obj) + assert result.readonly == data_mv.readonly + if isinstance(data, memoryview): + if data.ndim == 1 and data.format == "B": + assert id(result) == id(data) + else: + assert id(data) != id(result) + else: + assert id(result.obj) != id(data_mv.obj) + assert not result.readonly + + +@pytest.mark.parametrize( + "dt, nitems, shape, strides", + [ + ("i8", 12, (12,), (8,)), + ("i8", 12, (3, 4), (32, 8)), + ("i8", 12, (4, 3), (8, 32)), + ("i8", 12, (3, 2), (32, 16)), + ("i8", 12, (2, 3), (16, 32)), + ], +) +def test_ensure_memoryview_ndarray(dt, nitems, shape, strides): np = pytest.importorskip("numpy") - result = ensure_memoryview(np.arange(12).reshape(3, 4)[:, ::2].T) + data = np.ndarray( + shape, dtype=dt, buffer=np.arange(nitems, dtype=dt), strides=strides + ) + result = ensure_memoryview(data) assert isinstance(result, memoryview) assert result.ndim == 1 assert result.format == "B" diff --git a/distributed/utils.py b/distributed/utils.py index 4e7fe43dc43..e24e45b0c86 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -27,6 +27,7 @@ from functools import wraps from hashlib import md5 from importlib.util import cache_from_source +from pickle import PickleBuffer from time import sleep from types import ModuleType from typing import TYPE_CHECKING @@ -1021,13 +1022,19 @@ def ensure_memoryview(obj): if not mv.nbytes: # Drop `obj` reference to permit freeing underlying data - return memoryview(b"") - elif mv.contiguous: + return memoryview(bytearray()) + elif not mv.contiguous: + # Copy to contiguous form of expected shape & type + return memoryview(bytearray(mv)) + elif mv.ndim != 1 or mv.format != "B": # Perform zero-copy reshape & cast - return mv.cast("B") + # Use `PickleBuffer.raw()` as `memoryview.cast()` fails with F-order + # Pass `mv.obj` so the created `memoryview` has that as its `obj` + # xref: https://github.com/python/cpython/issues/91484 + return PickleBuffer(mv.obj).raw() else: - # Copy to contiguous form of expected shape & type - return memoryview(mv.tobytes()) + # Return `memoryview` as it already meets requirements + return mv def open_port(host=""):