Skip to content

Commit

Permalink
Support Pickle's protocol 5 (#3784)
Browse files Browse the repository at this point in the history
* Assign `header` and `frames` before returning

* Import `HIGHEST_PROTOCOL` at top-level

* Collect keyword arguments to `*dumps`

* Assign `result` and `return` once at end

* Support out-of-band buffer serialization

* Require `cloudpickle` version `1.3.0`

Needed for out-of-band buffer handling.

* Test Pickle with out-of-band buffers

* Import `PickleBuffer` (if available)

* Test `serialize`/`deserialize` with `pickle`

* Check serialized header + frames

* Check out-of-band buffers' content

* Take `memoryview` of `PickleBuffer` for testing

* Collect buffers internally first

Before calling the user provided buffer callback, collect buffers in an
internal list. That way if the mechanism of pickling needs to be
changed, the internal list can be purged before handing these to the
user. At the end of pickling, make sure the user's buffer callback is
called on each buffer in order.

* Only collect buffers if `buffer_callback` exists

* Use `elif` instead for simplicity

* Use De Morgan's law to simplify logic

* Check `buffer_callback` before calling it

Co-authored-by: Jim Crist-Harif <[email protected]>

* Use `buffer.clear()` instead of `del buffer[:]`

Co-authored-by: Jim Crist-Harif <[email protected]>
  • Loading branch information
jakirkham and jcrist authored May 21, 2020
1 parent cc8140a commit ddc6377
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 18 deletions.
37 changes: 24 additions & 13 deletions distributed/protocol/pickle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import pickle
from pickle import HIGHEST_PROTOCOL

import cloudpickle

Expand All @@ -23,36 +24,46 @@ def _always_use_pickle_for(x):
return False


def dumps(x):
def dumps(x, *, buffer_callback=None):
""" Manage between cloudpickle and pickle
1. Try pickle
2. If it is short then check if it contains __main__
3. If it is long, then first check type, then check __main__
"""
buffers = []
dump_kwargs = {"protocol": HIGHEST_PROTOCOL}
if HIGHEST_PROTOCOL >= 5 and buffer_callback is not None:
dump_kwargs["buffer_callback"] = buffers.append
try:
result = pickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
buffers.clear()
result = pickle.dumps(x, **dump_kwargs)
if len(result) < 1000:
if b"__main__" in result:
return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
else:
return result
else:
if _always_use_pickle_for(x) or b"__main__" not in result:
return result
else:
return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
buffers.clear()
result = cloudpickle.dumps(x, **dump_kwargs)
elif not _always_use_pickle_for(x) and b"__main__" in result:
buffers.clear()
result = cloudpickle.dumps(x, **dump_kwargs)
except Exception:
try:
return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
buffers.clear()
result = cloudpickle.dumps(x, **dump_kwargs)
except Exception as e:
logger.info("Failed to serialize %s. Exception: %s", x, e)
raise
if buffer_callback is not None:
for b in buffers:
buffer_callback(b)
return result


def loads(x):
def loads(x, *, buffers=()):
try:
return pickle.loads(x)
if buffers:
return pickle.loads(x, buffers=buffers)
else:
return pickle.loads(x)
except Exception:
logger.info("Failed to deserialize %s", x[:10000], exc_info=True)
raise
9 changes: 7 additions & 2 deletions distributed/protocol/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,16 @@ def dask_loads(header, frames):


def pickle_dumps(x):
return {"serializer": "pickle"}, [pickle.dumps(x)]
header = {"serializer": "pickle"}
frames = [None]
buffer_callback = lambda f: frames.append(memoryview(f))
frames[0] = pickle.dumps(x, buffer_callback=buffer_callback)
return header, frames


def pickle_loads(header, frames):
return pickle.loads(b"".join(frames))
x, buffers = frames[0], frames[1:]
return pickle.loads(x, buffers=buffers)


def msgpack_dumps(x):
Expand Down
82 changes: 80 additions & 2 deletions distributed/protocol/tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,93 @@

import pytest

from distributed.protocol.pickle import dumps, loads
from distributed.protocol import deserialize, serialize
from distributed.protocol.pickle import HIGHEST_PROTOCOL, dumps, loads

try:
from pickle import PickleBuffer
except ImportError:
pass


def test_pickle_data():
data = [1, b"123", "123", [123], {}, set()]
for d in data:
assert loads(dumps(d)) == d
assert deserialize(*serialize(d, serializers=("pickle",))) == d


def test_pickle_out_of_band():
class MemoryviewHolder:
def __init__(self, mv):
self.mv = memoryview(mv)

def __reduce_ex__(self, protocol):
if protocol >= 5:
return MemoryviewHolder, (PickleBuffer(self.mv),)
else:
return MemoryviewHolder, (self.mv.tobytes(),)

mv = memoryview(b"123")
mvh = MemoryviewHolder(mv)

if HIGHEST_PROTOCOL >= 5:
l = []
d = dumps(mvh, buffer_callback=l.append)
mvh2 = loads(d, buffers=l)

assert len(l) == 1
assert isinstance(l[0], PickleBuffer)
assert memoryview(l[0]) == mv
else:
mvh2 = loads(dumps(mvh))

assert isinstance(mvh2, MemoryviewHolder)
assert isinstance(mvh2.mv, memoryview)
assert mvh2.mv == mv

h, f = serialize(mvh, serializers=("pickle",))
mvh3 = deserialize(h, f)

assert isinstance(mvh3, MemoryviewHolder)
assert isinstance(mvh3.mv, memoryview)
assert mvh3.mv == mv

if HIGHEST_PROTOCOL >= 5:
assert len(f) == 2
assert isinstance(f[0], bytes)
assert isinstance(f[1], memoryview)
assert f[1] == mv
else:
assert len(f) == 1
assert isinstance(f[0], bytes)


def test_pickle_numpy():
np = pytest.importorskip("numpy")
x = np.ones(5)
assert (loads(dumps(x)) == x).all()
assert (deserialize(*serialize(x, serializers=("pickle",))) == x).all()

x = np.ones(5000)
assert (loads(dumps(x)) == x).all()
assert (deserialize(*serialize(x, serializers=("pickle",))) == x).all()

if HIGHEST_PROTOCOL >= 5:
x = np.ones(5000)

l = []
d = dumps(x, buffer_callback=l.append)
assert len(l) == 1
assert isinstance(l[0], PickleBuffer)
assert memoryview(l[0]) == memoryview(x)
assert (loads(d, buffers=l) == x).all()

h, f = serialize(x, serializers=("pickle",))
assert len(f) == 2
assert isinstance(f[0], bytes)
assert isinstance(f[1], memoryview)
assert (deserialize(h, f) == x).all()


@pytest.mark.xfail(
Expand All @@ -45,10 +116,17 @@ def funcs():

for func in funcs():
wr = weakref.ref(func)

func2 = loads(dumps(func))
wr2 = weakref.ref(func2)
assert func2(1) == func(1)
del func, func2

func3 = deserialize(*serialize(func, serializers=("pickle",)))
wr3 = weakref.ref(func3)
assert func3(1) == func(1)

del func, func2, func3
gc.collect()
assert wr() is None
assert wr2() is None
assert wr3() is None
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
click >= 6.6
cloudpickle >= 0.2.2
cloudpickle >= 1.3.0
contextvars;python_version<'3.7'
dask >= 2.9.0
msgpack >= 0.6.0
Expand Down

0 comments on commit ddc6377

Please sign in to comment.