Skip to content

Commit

Permalink
Revert "Improve ensure_memoryview test coverage & make minor fixes (dask#6333)"
Browse files Browse the repository at this point in the history

This reverts commit 6e0fe58.
  • Loading branch information
mrocklin committed May 25, 2022
1 parent dea9ef2 commit f4bd8eb
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 73 deletions.
15 changes: 8 additions & 7 deletions distributed/comm/asyncio_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
host_array,
to_frames,
)
from distributed.utils import ensure_ip, ensure_memoryview, get_ip, get_ipv6
from distributed.utils import ensure_ip, get_ip, get_ipv6

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -380,9 +380,7 @@ async def write(self, frames: list[bytes]) -> int:
await drain_waiter

# Ensure all memoryviews are in single-byte format
frames = [
ensure_memoryview(f) if isinstance(f, memoryview) else f for f in frames
]
frames = [f.cast("B") if isinstance(f, memoryview) else f for f in frames]

nframes = len(frames)
frames_nbytes = [len(f) for f in frames]
Expand Down Expand Up @@ -854,9 +852,12 @@ def _buffer_clear(self):

def _buffer_append(self, data: bytes) -> None:
"""Append new data to the send buffer"""
mv = ensure_memoryview(data)
self._size += len(mv)
self._buffers.append(mv)
if not isinstance(data, memoryview):
data = memoryview(data)
if data.format != "B":
data = data.cast("B")
self._size += len(data)
self._buffers.append(data)

def _buffer_peek(self) -> list[memoryview]:
"""Get one or more buffers to write to the socket"""
Expand Down
4 changes: 2 additions & 2 deletions distributed/comm/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
)
from distributed.protocol.utils import pack_frames_prelude, unpack_frames
from distributed.system import MEMORY_LIMIT
from distributed.utils import ensure_ip, ensure_memoryview, get_ip, get_ipv6, nbytes
from distributed.utils import ensure_ip, get_ip, get_ipv6, nbytes

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -305,7 +305,7 @@ async def write(self, msg, serializers=None, on_error="message"):
if isinstance(each_frame, memoryview):
# Make sure that `len(data) == data.nbytes`
# See <https://github.com/tornadoweb/tornado/pull/2996>
each_frame = ensure_memoryview(each_frame)
each_frame = memoryview(each_frame).cast("B")

stream._write_buffer.append(each_frame)
stream._total_write_index += each_frame_nbytes
Expand Down
2 changes: 1 addition & 1 deletion distributed/protocol/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ def _serialize_memoryview(obj):
@dask_deserialize.register(memoryview)
def _deserialize_memoryview(header, frames):
if len(frames) == 1:
out = ensure_memoryview(frames[0])
out = memoryview(frames[0]).cast("B")
else:
out = memoryview(b"".join(frames))
out = out.cast(header["format"], header["shape"])
Expand Down
66 changes: 15 additions & 51 deletions distributed/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,59 +248,23 @@ def test_seek_delimiter_endline():
assert f.tell() == 7


@pytest.mark.parametrize(
"data",
[
b"",
bytearray(),
b"1",
bytearray(b"1"),
memoryview(b"1"),
memoryview(bytearray(b"1")),
array("B", b"1"),
array("I", range(5)),
memoryview(b"123456")[::2],
memoryview(b"123456").cast("B", (2, 3)),
memoryview(b"0123456789").cast("B", (5, 2))[::2],
],
)
def test_ensure_memoryview(data):
data_mv = memoryview(data)
result = ensure_memoryview(data)
def test_ensure_memoryview_empty():
result = ensure_memoryview(b"")
assert isinstance(result, memoryview)
assert result.contiguous
assert result.ndim == 1
assert result.format == "B"
assert result == bytes(data_mv)
if data_mv.nbytes and data_mv.contiguous:
assert id(result.obj) == id(data_mv.obj)
assert result.readonly == data_mv.readonly
if isinstance(data, memoryview):
if data.ndim == 1 and data.format == "B":
assert id(result) == id(data)
else:
assert id(data) != id(result)
else:
assert id(result.obj) != id(data_mv.obj)
assert not result.readonly


@pytest.mark.parametrize(
"dt, nitems, shape, strides",
[
("i8", 12, (12,), (8,)),
("i8", 12, (3, 4), (32, 8)),
("i8", 12, (4, 3), (8, 32)),
("i8", 12, (3, 2), (32, 16)),
("i8", 12, (2, 3), (16, 32)),
],
)
def test_ensure_memoryview_ndarray(dt, nitems, shape, strides):
assert result == memoryview(b"")


def test_ensure_memoryview():
data = [b"1", memoryview(b"1"), bytearray(b"1"), array("B", b"1")]
for d in data:
result = ensure_memoryview(d)
assert isinstance(result, memoryview)
assert result == memoryview(b"1")


def test_ensure_memoryview_ndarray():
np = pytest.importorskip("numpy")
data = np.ndarray(
shape, dtype=dt, buffer=np.arange(nitems, dtype=dt), strides=strides
)
result = ensure_memoryview(data)
result = ensure_memoryview(np.arange(12).reshape(3, 4)[:, ::2].T)
assert isinstance(result, memoryview)
assert result.ndim == 1
assert result.format == "B"
Expand Down
17 changes: 5 additions & 12 deletions distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from functools import wraps
from hashlib import md5
from importlib.util import cache_from_source
from pickle import PickleBuffer
from time import sleep
from types import ModuleType
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -1022,19 +1021,13 @@ def ensure_memoryview(obj):

if not mv.nbytes:
# Drop `obj` reference to permit freeing underlying data
return memoryview(bytearray())
elif not mv.contiguous:
# Copy to contiguous form of expected shape & type
return memoryview(bytearray(mv))
elif mv.ndim != 1 or mv.format != "B":
return memoryview(b"")
elif mv.contiguous:
# Perform zero-copy reshape & cast
# Use `PickleBuffer.raw()` as `memoryview.cast()` fails with F-order
# Pass `mv.obj` so the created `memoryview` has that as its `obj`
# xref: https://github.com/python/cpython/issues/91484
return PickleBuffer(mv.obj).raw()
return mv.cast("B")
else:
# Return `memoryview` as it already meets requirements
return mv
# Copy to contiguous form of expected shape & type
return memoryview(mv.tobytes())


def open_port(host=""):
Expand Down

0 comments on commit f4bd8eb

Please sign in to comment.