From 072756dbb24f17935f271ae2af72a37069610d89 Mon Sep 17 00:00:00 2001
From: Matthew Rocklin
Date: Thu, 1 Apr 2021 19:16:22 -0500
Subject: [PATCH] Fix un-merged frames (#4666)

* Add test for un-merged frames

* Don't double-split/compress Serialized frames

Previously we would re-serialize an object, even if it was a Serialized
object. Instead we should just unpack its header and frames and be done.

* specify num_sub_frames in all cases
---
 distributed/protocol/core.py                 | 11 +++++++----
 distributed/protocol/tests/test_serialize.py | 16 +++++++++++++++-
 2 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py
index 05a804d3b52..fb85d32fee2 100644
--- a/distributed/protocol/core.py
+++ b/distributed/protocol/core.py
@@ -51,11 +51,14 @@ def _encode_default(obj):
                 if typ is Serialize:
                     obj = obj.data
                 offset = len(frames)
-                sub_header, sub_frames = serialize_and_split(
-                    obj, serializers=serializers, on_error=on_error, context=context
-                )
+                if typ is Serialized:
+                    sub_header, sub_frames = obj.header, obj.frames
+                else:
+                    sub_header, sub_frames = serialize_and_split(
+                        obj, serializers=serializers, on_error=on_error, context=context
+                    )
+                    _inplace_compress_frames(sub_header, sub_frames)
                 sub_header["num-sub-frames"] = len(sub_frames)
-                _inplace_compress_frames(sub_header, sub_frames)
                 frames.append(
                     msgpack.dumps(
                         sub_header, default=msgpack_encode_default, use_bin_type=True
diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py
index cd2efa3bb28..cfb767bc0cb 100644
--- a/distributed/protocol/tests/test_serialize.py
+++ b/distributed/protocol/tests/test_serialize.py
@@ -13,7 +13,7 @@
 
 from dask.utils_test import inc
 
-from distributed import wait
+from distributed import Nanny, wait
 from distributed.comm.utils import from_frames, to_frames
 from distributed.protocol import (
     Serialize,
@@ -499,3 +499,17 @@ def test_ser_memoryview_object():
     data_in = memoryview(np.array(["hello"], dtype=object))
     with pytest.raises(TypeError):
         serialize(data_in, on_error="raise")
+
+
+@gen_cluster(client=True, Worker=Nanny)
+async def test_large_pickled_object(c, s, a, b):
+    np = pytest.importorskip("numpy")
+
+    class Data:
+        def __init__(self, n):
+            self.data = np.empty(n, dtype="u1")
+
+    x = Data(100_000_000)
+    y = await c.scatter(x, workers=[a.worker_address])
+    z = c.submit(lambda x: x, y, workers=[b.worker_address])
+    await z
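
For context, a minimal sketch of the round-trip this patch repairs. It is not
part of the commit; Serialized, serialize, dumps, and loads are real names
from distributed.protocol, while the payload and assertion are illustrative
only.

    # A message may carry an object that is already serialized into a
    # (header, frames) pair, e.g. data being forwarded without ever having
    # been deserialized. Before this patch, dumps() re-serialized and
    # re-compressed such Serialized objects instead of reusing their frames.
    from distributed.protocol import Serialized, dumps, loads, serialize

    header, frames = serialize({"x": 1})        # pre-serialized payload
    msg = {"data": Serialized(header, frames)}  # wrap without re-serializing

    # With the fix, dumps() unpacks the existing header and frames directly
    # and records "num-sub-frames" in every case, so the receiver knows how
    # many sub-frames to merge back together before deserializing.
    out = loads(dumps(msg))
    assert out == {"data": {"x": 1}}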