Skip to content

Commit

Permalink
Use ensure_memoryview in a few more places
Browse files Browse the repository at this point in the history
  • Loading branch information
jakirkham committed May 13, 2022
1 parent 90c1c26 commit 96e2318
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 12 deletions.
15 changes: 7 additions & 8 deletions distributed/comm/asyncio_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
host_array,
to_frames,
)
from distributed.utils import ensure_ip, get_ip, get_ipv6
from distributed.utils import ensure_ip, ensure_memoryview, get_ip, get_ipv6

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -379,7 +379,9 @@ async def write(self, frames: list[bytes]) -> int:
await drain_waiter

# Ensure all memoryviews are in single-byte format
frames = [f.cast("B") if isinstance(f, memoryview) else f for f in frames]
frames = [
ensure_memoryview(f) if isinstance(f, memoryview) else f for f in frames
]

nframes = len(frames)
frames_nbytes = [len(f) for f in frames]
Expand Down Expand Up @@ -847,12 +849,9 @@ def _buffer_clear(self):

def _buffer_append(self, data: bytes) -> None:
"""Append new data to the send buffer"""
if not isinstance(data, memoryview):
data = memoryview(data)
if data.format != "B":
data = data.cast("B")
self._size += len(data)
self._buffers.append(data)
mv = ensure_memoryview(data)
self._size += len(mv)
self._buffers.append(mv)

def _buffer_peek(self) -> list[memoryview]:
"""Get one or more buffers to write to the socket"""
Expand Down
4 changes: 2 additions & 2 deletions distributed/comm/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
)
from distributed.protocol.utils import pack_frames_prelude, unpack_frames
from distributed.system import MEMORY_LIMIT
from distributed.utils import ensure_ip, get_ip, get_ipv6, nbytes
from distributed.utils import ensure_ip, ensure_memoryview, get_ip, get_ipv6, nbytes

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -305,7 +305,7 @@ async def write(self, msg, serializers=None, on_error="message"):
if isinstance(each_frame, memoryview):
# Make sure that `len(data) == data.nbytes`
# See <https://github.com/tornadoweb/tornado/pull/2996>
each_frame = memoryview(each_frame).cast("B")
each_frame = ensure_memoryview(each_frame)

stream._write_buffer.append(each_frame)
stream._total_write_index += each_frame_nbytes
Expand Down
4 changes: 2 additions & 2 deletions distributed/protocol/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def pickle_loads(header, frames):
writeable = len(buffers) * (None,)

new = []
memoryviews = map(memoryview, buffers)
memoryviews = map(ensure_memoryview, buffers)
for w, mv in zip(writeable, memoryviews):
if w == mv.readonly:
if w:
Expand Down Expand Up @@ -785,7 +785,7 @@ def _serialize_memoryview(obj):
@dask_deserialize.register(memoryview)
def _deserialize_memoryview(header, frames):
if len(frames) == 1:
out = memoryview(frames[0]).cast("B")
out = ensure_memoryview(frames[0])
else:
out = memoryview(b"".join(frames))
out = out.cast(header["format"], header["shape"])
Expand Down

0 comments on commit 96e2318

Please sign in to comment.