diff --git a/distributed/comm/asyncio_tcp.py b/distributed/comm/asyncio_tcp.py index 1dbac5fef6f..526533bda9f 100644 --- a/distributed/comm/asyncio_tcp.py +++ b/distributed/comm/asyncio_tcp.py @@ -26,7 +26,7 @@ host_array, to_frames, ) -from distributed.utils import ensure_ip, get_ip, get_ipv6 +from distributed.utils import ensure_ip, ensure_memoryview, get_ip, get_ipv6 logger = logging.getLogger(__name__) @@ -379,7 +379,9 @@ async def write(self, frames: list[bytes]) -> int: await drain_waiter # Ensure all memoryviews are in single-byte format - frames = [f.cast("B") if isinstance(f, memoryview) else f for f in frames] + frames = [ + ensure_memoryview(f) if isinstance(f, memoryview) else f for f in frames + ] nframes = len(frames) frames_nbytes = [len(f) for f in frames] @@ -847,12 +849,9 @@ def _buffer_clear(self): def _buffer_append(self, data: bytes) -> None: """Append new data to the send buffer""" - if not isinstance(data, memoryview): - data = memoryview(data) - if data.format != "B": - data = data.cast("B") - self._size += len(data) - self._buffers.append(data) + mv = ensure_memoryview(data) + self._size += len(mv) + self._buffers.append(mv) def _buffer_peek(self) -> list[memoryview]: """Get one or more buffers to write to the socket""" diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index e5cd004cca7..d88582bc67c 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -46,7 +46,7 @@ ) from distributed.protocol.utils import pack_frames_prelude, unpack_frames from distributed.system import MEMORY_LIMIT -from distributed.utils import ensure_ip, get_ip, get_ipv6, nbytes +from distributed.utils import ensure_ip, ensure_memoryview, get_ip, get_ipv6, nbytes logger = logging.getLogger(__name__) @@ -305,7 +305,7 @@ async def write(self, msg, serializers=None, on_error="message"): if isinstance(each_frame, memoryview): # Make sure that `len(data) == data.nbytes` # See - each_frame = memoryview(each_frame).cast("B") + each_frame = ensure_memoryview(each_frame) stream._write_buffer.append(each_frame) stream._total_write_index += each_frame_nbytes diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 2a7ebdbf23d..6dfa32244f0 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -785,7 +785,7 @@ def _serialize_memoryview(obj): @dask_deserialize.register(memoryview) def _deserialize_memoryview(header, frames): if len(frames) == 1: - out = memoryview(frames[0]).cast("B") + out = ensure_memoryview(frames[0]) else: out = memoryview(b"".join(frames)) out = out.cast(header["format"], header["shape"])