Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve ensure_memoryview test coverage & make minor fixes #6333

Merged
7 commits merged on May 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions distributed/comm/asyncio_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
host_array,
to_frames,
)
from distributed.utils import ensure_ip, get_ip, get_ipv6
from distributed.utils import ensure_ip, ensure_memoryview, get_ip, get_ipv6

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -379,7 +379,9 @@ async def write(self, frames: list[bytes]) -> int:
await drain_waiter

# Ensure all memoryviews are in single-byte format
frames = [f.cast("B") if isinstance(f, memoryview) else f for f in frames]
frames = [
ensure_memoryview(f) if isinstance(f, memoryview) else f for f in frames
]

nframes = len(frames)
frames_nbytes = [len(f) for f in frames]
Expand Down Expand Up @@ -847,12 +849,9 @@ def _buffer_clear(self):

def _buffer_append(self, data: bytes) -> None:
"""Append new data to the send buffer"""
if not isinstance(data, memoryview):
data = memoryview(data)
if data.format != "B":
data = data.cast("B")
self._size += len(data)
self._buffers.append(data)
mv = ensure_memoryview(data)
self._size += len(mv)
self._buffers.append(mv)

def _buffer_peek(self) -> list[memoryview]:
"""Get one or more buffers to write to the socket"""
Expand Down
4 changes: 2 additions & 2 deletions distributed/comm/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
)
from distributed.protocol.utils import pack_frames_prelude, unpack_frames
from distributed.system import MEMORY_LIMIT
from distributed.utils import ensure_ip, get_ip, get_ipv6, nbytes
from distributed.utils import ensure_ip, ensure_memoryview, get_ip, get_ipv6, nbytes

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -305,7 +305,7 @@ async def write(self, msg, serializers=None, on_error="message"):
if isinstance(each_frame, memoryview):
# Make sure that `len(data) == data.nbytes`
# See <https://github.com/tornadoweb/tornado/pull/2996>
each_frame = memoryview(each_frame).cast("B")
each_frame = ensure_memoryview(each_frame)

stream._write_buffer.append(each_frame)
stream._total_write_index += each_frame_nbytes
Expand Down
4 changes: 2 additions & 2 deletions distributed/protocol/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def pickle_loads(header, frames):
writeable = len(buffers) * (None,)

new = []
memoryviews = map(memoryview, buffers)
memoryviews = map(ensure_memoryview, buffers)
for w, mv in zip(writeable, memoryviews):
if w == mv.readonly:
if w:
Expand Down Expand Up @@ -785,7 +785,7 @@ def _serialize_memoryview(obj):
@dask_deserialize.register(memoryview)
def _deserialize_memoryview(header, frames):
if len(frames) == 1:
out = memoryview(frames[0]).cast("B")
out = ensure_memoryview(frames[0])
else:
out = memoryview(b"".join(frames))
out = out.cast(header["format"], header["shape"])
Expand Down
66 changes: 51 additions & 15 deletions distributed/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,23 +248,59 @@ def test_seek_delimiter_endline():
assert f.tell() == 7


def test_ensure_memoryview_empty():
result = ensure_memoryview(b"")
@pytest.mark.parametrize(
"data",
[
b"",
bytearray(),
b"1",
bytearray(b"1"),
memoryview(b"1"),
memoryview(bytearray(b"1")),
array("B", b"1"),
array("I", range(5)),
memoryview(b"123456")[::2],
memoryview(b"123456").cast("B", (2, 3)),
memoryview(b"0123456789").cast("B", (5, 2))[::2],
],
)
def test_ensure_memoryview(data):
data_mv = memoryview(data)
result = ensure_memoryview(data)
assert isinstance(result, memoryview)
assert result == memoryview(b"")


def test_ensure_memoryview():
data = [b"1", memoryview(b"1"), bytearray(b"1"), array("B", b"1")]
for d in data:
result = ensure_memoryview(d)
assert isinstance(result, memoryview)
assert result == memoryview(b"1")


def test_ensure_memoryview_ndarray():
assert result.contiguous
assert result.ndim == 1
assert result.format == "B"
assert result == bytes(data_mv)
if data_mv.nbytes and data_mv.contiguous:
assert id(result.obj) == id(data_mv.obj)
assert result.readonly == data_mv.readonly
if isinstance(data, memoryview):
if data.ndim == 1 and data.format == "B":
assert id(result) == id(data)
else:
assert id(data) != id(result)
else:
assert id(result.obj) != id(data_mv.obj)
assert not result.readonly


@pytest.mark.parametrize(
"dt, nitems, shape, strides",
[
("i8", 12, (12,), (8,)),
("i8", 12, (3, 4), (32, 8)),
("i8", 12, (4, 3), (8, 32)),
("i8", 12, (3, 2), (32, 16)),
("i8", 12, (2, 3), (16, 32)),
],
)
def test_ensure_memoryview_ndarray(dt, nitems, shape, strides):
np = pytest.importorskip("numpy")
result = ensure_memoryview(np.arange(12).reshape(3, 4)[:, ::2].T)
data = np.ndarray(
shape, dtype=dt, buffer=np.arange(nitems, dtype=dt), strides=strides
)
result = ensure_memoryview(data)
assert isinstance(result, memoryview)
assert result.ndim == 1
assert result.format == "B"
Expand Down
17 changes: 12 additions & 5 deletions distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from functools import wraps
from hashlib import md5
from importlib.util import cache_from_source
from pickle import PickleBuffer
from time import sleep
from types import ModuleType
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -1021,13 +1022,19 @@ def ensure_memoryview(obj):

if not mv.nbytes:
# Drop `obj` reference to permit freeing underlying data
return memoryview(b"")
elif mv.contiguous:
return memoryview(bytearray())
elif not mv.contiguous:
# Copy to contiguous form of expected shape & type
return memoryview(bytearray(mv))
elif mv.ndim != 1 or mv.format != "B":
# Perform zero-copy reshape & cast
return mv.cast("B")
# Use `PickleBuffer.raw()` as `memoryview.cast()` fails with F-order
# Pass `mv.obj` so the created `memoryview` has that as its `obj`
# xref: https://github.com/python/cpython/issues/91484
return PickleBuffer(mv.obj).raw()
else:
# Copy to contiguous form of expected shape & type
return memoryview(mv.tobytes())
# Return `memoryview` as it already meets requirements
return mv


def open_port(host=""):
Expand Down