Skip to content

Commit

Permalink
[WIP] Add compression, pickle protocol to comm contexts
Browse files Browse the repository at this point in the history
Currently we expect a great degree of uniformity among servers in a Dask
cluster.  This is especially apparent with pickle protocol and
compression formats.  This PR includes a handshake when connecting that
exchanges a bit of information.  Then this information is passed to
serialization functions through the existing context= mechanism.

This allows serialization functions to have more information about the
other side of the connection, and to make choices accordingly.

This isn't yet smooth, but I thought I'd throw it up early so that
people can see what I'm thinking and comment early.
  • Loading branch information
mrocklin committed Aug 5, 2020
1 parent ae5ecd5 commit a997410
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 11 deletions.
60 changes: 58 additions & 2 deletions distributed/comm/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ssl = None

import dask
import msgpack
from tornado import netutil
from tornado.iostream import StreamClosedError
from tornado.tcpclient import TCPClient
Expand All @@ -25,6 +26,8 @@
from .addressing import parse_host_port, unparse_host_port
from .core import Comm, Connector, Listener, CommClosedError, FatalCommClosedError
from .utils import to_frames, from_frames, get_tcp_server_address, ensure_concrete_host
from ..protocol.compression import default_compression
from ..protocol.utils import msgpack_opts


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -222,7 +225,11 @@ async def write(self, msg, serializers=None, on_error="message"):
allow_offload=self.allow_offload,
serializers=serializers,
on_error=on_error,
context={"sender": self._local_addr, "recipient": self._peer_addr},
context={
"sender": self.local_info,
"recipient": self.remote_info,
**self.handshake_options,
},
)

try:
Expand Down Expand Up @@ -355,11 +362,24 @@ async def connect(self, address, deserialize=True, **connection_args):
# The socket connect() call failed
convert_stream_closed_error(self, e)

write_future = stream.write(msgpack.dumps(handshake_info()) + b"\n")
handshake = await stream.read_until(b"\n")
handshake = msgpack.loads(handshake.strip(), **msgpack_opts)
local_address = self.prefix + get_stream_address(stream)
return self.comm_class(
comm = self.comm_class(
stream, local_address, self.prefix + address, deserialize
)

comm.remote_info = handshake
comm.remote_info["address"] = self.prefix + address
comm.local_info = handshake_info()
comm.local_info["address"] = local_address
comm.handshake_options = handshake_configuration(
comm.local_info, comm.remote_info
)

return comm


class TCPConnector(BaseTCPConnector):
prefix = "tcp://"
Expand Down Expand Up @@ -444,6 +464,20 @@ async def _handle_stream(self, stream, address):
local_address = self.prefix + get_stream_address(stream)
comm = self.comm_class(stream, local_address, address, self.deserialize)
comm.allow_offload = self.allow_offload
write_future = comm.stream.write(msgpack.dumps(handshake_info()) + b"\n")
handshake = await comm.stream.read_until(b"\n")
handshake = msgpack.loads(handshake.strip(), **msgpack_opts)

comm.remote_info = handshake
comm.remote_info["address"] = address
comm.local_info = handshake_info()
comm.local_info["address"] = local_address

comm.handshake_options = handshake_configuration(
comm.local_info, comm.remote_info
)

await write_future
await self.comm_handler(comm)

def get_host_port(self):
Expand Down Expand Up @@ -552,5 +586,27 @@ class TLSBackend(BaseTCPBackend):
_listener_class = TLSListener


def handshake_info():
return {
"compression": default_compression,
"python": tuple(sys.version_info)[:3],
}


def handshake_configuration(local, remote):
out = {}
if local["python"] >= (3, 8) and remote["python"] >= (3, 8):
out["pickle_protocol"] = 5
else:
out["pickle_protocol"] = 4

if local["compression"] == remote["compression"]:
out["compression"] = local["compression"]
else:
out["compression"] = None

return out


backends["tcp"] = TCPBackend()
backends["tls"] = TLSBackend()
4 changes: 2 additions & 2 deletions distributed/comm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


async def to_frames(
msg, serializers=None, on_error="message", context=None, allow_offload=True
msg, serializers=None, on_error="message", context=None, allow_offload=True,
):
"""
Serialize a message into a list of Distributed protocol frames.
Expand All @@ -32,7 +32,7 @@ def _to_frames():
try:
return list(
protocol.dumps(
msg, serializers=serializers, on_error=on_error, context=context
msg, serializers=serializers, on_error=on_error, context=context,
)
)
except Exception as e:
Expand Down
9 changes: 7 additions & 2 deletions distributed/protocol/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,13 @@ def byte_sample(b, size, n):
return b"".join(map(ensure_bytes, parts))


def maybe_compress(payload, min_size=1e4, sample_size=1e4, nsamples=5):
def maybe_compress(
payload,
min_size=1e4,
sample_size=1e4,
nsamples=5,
compression=dask.config.get("distributed.comm.compression"),
):
"""
Maybe compress payload
Expand All @@ -168,7 +174,6 @@ def maybe_compress(payload, min_size=1e4, sample_size=1e4, nsamples=5):
return the original
4. We return the compressed result
"""
compression = dask.config.get("distributed.comm.compression")
if compression == "auto":
compression = default_compression

Expand Down
9 changes: 8 additions & 1 deletion distributed/protocol/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ def dumps(msg, serializers=None, on_error="message", context=None):

out_frames = []

if context and "compression" in context:
compress_opts = {"compression": context["compression"]}
else:
compress_opts = {}

for key, (head, frames) in data.items():
if "writeable" not in head:
head["writeable"] = tuple(map(is_writeable, frames))
Expand All @@ -67,7 +72,9 @@ def dumps(msg, serializers=None, on_error="message", context=None):
):
if compression is None: # default behavior
_frames = frame_split_size(frame)
_compression, _frames = zip(*map(maybe_compress, _frames))
_compression, _frames = zip(
*[maybe_compress(frame, **compress_opts) for frame in _frames]
)
out_compression.extend(_compression)
_out_frames.extend(_frames)
else: # already specified, so pass
Expand Down
4 changes: 2 additions & 2 deletions distributed/protocol/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ def _always_use_pickle_for(x):
return False


def dumps(x, *, buffer_callback=None):
def dumps(x, *, buffer_callback=None, protocol=HIGHEST_PROTOCOL):
""" Manage between cloudpickle and pickle
1. Try pickle
2. If it is short then check if it contains __main__
3. If it is long, then first check type, then check __main__
"""
buffers = []
dump_kwargs = {"protocol": HIGHEST_PROTOCOL}
dump_kwargs = {"protocol": protocol or HIGHEST_PROTOCOL}
if HIGHEST_PROTOCOL >= 5 and buffer_callback is not None:
dump_kwargs["buffer_callback"] = buffers.append
try:
Expand Down
8 changes: 6 additions & 2 deletions distributed/protocol/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,15 @@ def dask_loads(header, frames):
return loads(header, frames)


def pickle_dumps(x):
def pickle_dumps(x, context=None):
header = {"serializer": "pickle"}
frames = [None]
buffer_callback = lambda f: frames.append(memoryview(f))
frames[0] = pickle.dumps(x, buffer_callback=buffer_callback)
frames[0] = pickle.dumps(
x,
buffer_callback=buffer_callback,
protocol=context.get("pickle-protocol", None) if context else None,
)
return header, frames


Expand Down

0 comments on commit a997410

Please sign in to comment.