Skip to content

Commit

Permalink
Test Pickle with out-of-band buffers
Browse files Browse the repository at this point in the history
  • Loading branch information
jakirkham committed May 7, 2020
1 parent 1eaa8cc commit e30341e
Showing 1 changed file with 38 additions and 1 deletion.
39 changes: 38 additions & 1 deletion distributed/protocol/tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pytest

from distributed.protocol.pickle import dumps, loads
from distributed.protocol.pickle import HIGHEST_PROTOCOL, dumps, loads


def test_pickle_data():
Expand All @@ -15,6 +15,37 @@ def test_pickle_data():
assert loads(dumps(d)) == d


def test_pickle_out_of_band():
try:
from pickle import PickleBuffer
except ImportError:
pass

class MemoryviewHolder:
def __init__(self, mv):
self.mv = memoryview(mv)

def __reduce_ex__(self, protocol):
if protocol >= 5:
return MemoryviewHolder, (PickleBuffer(self.mv),)
else:
return MemoryviewHolder, (self.mv.tobytes(),)

mv = memoryview(b"123")
mvh = MemoryviewHolder(mv)

if HIGHEST_PROTOCOL >= 5:
l = []
d = dumps(mvh, buffer_callback=l.append)
mvh2 = loads(d, buffers=l)
else:
mvh2 = loads(dumps(mvh))

assert isinstance(mvh2, MemoryviewHolder)
assert isinstance(mvh2.mv, memoryview)
assert mvh2.mv == mv


def test_pickle_numpy():
np = pytest.importorskip("numpy")
x = np.ones(5)
Expand All @@ -23,6 +54,12 @@ def test_pickle_numpy():
x = np.ones(5000)
assert (loads(dumps(x)) == x).all()

if HIGHEST_PROTOCOL >= 5:
x = np.ones(5000)
l = []
d = dumps(x, buffer_callback=l.append)
assert (loads(d, buffers=l) == x).all()


@pytest.mark.xfail(
sys.version_info[:2] == (3, 8),
Expand Down

0 comments on commit e30341e

Please sign in to comment.