diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index cd7490be21a..b214d61c5d9 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -12,6 +12,7 @@ ssl = None import dask +import msgpack from tornado import netutil from tornado.iostream import StreamClosedError from tornado.tcpclient import TCPClient @@ -25,6 +26,8 @@ from .addressing import parse_host_port, unparse_host_port from .core import Comm, Connector, Listener, CommClosedError, FatalCommClosedError from .utils import to_frames, from_frames, get_tcp_server_address, ensure_concrete_host +from ..protocol.compression import default_compression +from ..protocol.utils import msgpack_opts logger = logging.getLogger(__name__) @@ -222,7 +225,11 @@ async def write(self, msg, serializers=None, on_error="message"): allow_offload=self.allow_offload, serializers=serializers, on_error=on_error, - context={"sender": self._local_addr, "recipient": self._peer_addr}, + context={ + "sender": self.local_info, + "recipient": self.remote_info, + **self.handshake_options, + }, ) try: @@ -355,11 +362,24 @@ async def connect(self, address, deserialize=True, **connection_args): # The socket connect() call failed convert_stream_closed_error(self, e) + write_future = stream.write(msgpack.dumps(handshake_info()) + b"\n") + handshake = await stream.read_until(b"\n") + handshake = msgpack.loads(handshake.strip(), **msgpack_opts) local_address = self.prefix + get_stream_address(stream) - return self.comm_class( + comm = self.comm_class( stream, local_address, self.prefix + address, deserialize ) + comm.remote_info = handshake + comm.remote_info["address"] = self.prefix + address + comm.local_info = handshake_info() + comm.local_info["address"] = local_address + comm.handshake_options = handshake_configuration( + comm.local_info, comm.remote_info + ) + + return comm + class TCPConnector(BaseTCPConnector): prefix = "tcp://" @@ -444,6 +464,20 @@ async def _handle_stream(self, stream, address): local_address = self.prefix + get_stream_address(stream) comm = self.comm_class(stream, local_address, address, self.deserialize) comm.allow_offload = self.allow_offload + write_future = comm.stream.write(msgpack.dumps(handshake_info()) + b"\n") + handshake = await comm.stream.read_until(b"\n") + handshake = msgpack.loads(handshake.strip(), **msgpack_opts) + + comm.remote_info = handshake + comm.remote_info["address"] = address + comm.local_info = handshake_info() + comm.local_info["address"] = local_address + + comm.handshake_options = handshake_configuration( + comm.local_info, comm.remote_info + ) + + await write_future await self.comm_handler(comm) def get_host_port(self): @@ -552,5 +586,27 @@ class TLSBackend(BaseTCPBackend): _listener_class = TLSListener +def handshake_info(): + return { + "compression": default_compression, + "python": tuple(sys.version_info)[:3], + } + + +def handshake_configuration(local, remote): + out = {} + if local["python"] >= (3, 8) and remote["python"] >= (3, 8): + out["pickle_protocol"] = 5 + else: + out["pickle_protocol"] = 4 + + if local["compression"] == remote["compression"]: + out["compression"] = local["compression"] + else: + out["compression"] = None + + return out + + backends["tcp"] = TCPBackend() backends["tls"] = TLSBackend() diff --git a/distributed/comm/utils.py b/distributed/comm/utils.py index d1a1a97e63c..eda370eed6f 100644 --- a/distributed/comm/utils.py +++ b/distributed/comm/utils.py @@ -22,7 +22,7 @@ async def to_frames( - msg, serializers=None, on_error="message", context=None, allow_offload=True + msg, serializers=None, on_error="message", context=None, allow_offload=True, ): """ Serialize a message into a list of Distributed protocol frames. @@ -32,7 +32,7 @@ def _to_frames(): try: return list( protocol.dumps( - msg, serializers=serializers, on_error=on_error, context=context + msg, serializers=serializers, on_error=on_error, context=context, ) ) except Exception as e: diff --git a/distributed/protocol/compression.py b/distributed/protocol/compression.py index 03ebf9d5662..142be7fd868 100644 --- a/distributed/protocol/compression.py +++ b/distributed/protocol/compression.py @@ -157,7 +157,13 @@ def byte_sample(b, size, n): return b"".join(map(ensure_bytes, parts)) -def maybe_compress(payload, min_size=1e4, sample_size=1e4, nsamples=5): +def maybe_compress( + payload, + min_size=1e4, + sample_size=1e4, + nsamples=5, + compression=dask.config.get("distributed.comm.compression"), +): """ Maybe compress payload @@ -168,7 +174,6 @@ def maybe_compress(payload, min_size=1e4, sample_size=1e4, nsamples=5): return the original 4. We return the compressed result """ - compression = dask.config.get("distributed.comm.compression") if compression == "auto": compression = default_compression diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 2e67039b208..c2554dd1ef7 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -53,6 +53,11 @@ def dumps(msg, serializers=None, on_error="message", context=None): out_frames = [] + if context and "compression" in context: + compress_opts = {"compression": context["compression"]} + else: + compress_opts = {} + for key, (head, frames) in data.items(): if "writeable" not in head: head["writeable"] = tuple(map(is_writeable, frames)) @@ -67,7 +72,9 @@ def dumps(msg, serializers=None, on_error="message", context=None): ): if compression is None: # default behavior _frames = frame_split_size(frame) - _compression, _frames = zip(*map(maybe_compress, _frames)) + _compression, _frames = zip( + *[maybe_compress(frame, **compress_opts) for frame in _frames] + ) out_compression.extend(_compression) _out_frames.extend(_frames) else: # already specified, so pass diff --git a/distributed/protocol/pickle.py b/distributed/protocol/pickle.py index fd2343756a4..144a3438dab 100644 --- a/distributed/protocol/pickle.py +++ b/distributed/protocol/pickle.py @@ -33,7 +33,7 @@ def _always_use_pickle_for(x): return False -def dumps(x, *, buffer_callback=None): +def dumps(x, *, buffer_callback=None, protocol=HIGHEST_PROTOCOL): """ Manage between cloudpickle and pickle 1. Try pickle @@ -41,7 +41,7 @@ def dumps(x, *, buffer_callback=None): 3. If it is long, then first check type, then check __main__ """ buffers = [] - dump_kwargs = {"protocol": HIGHEST_PROTOCOL} + dump_kwargs = {"protocol": protocol or HIGHEST_PROTOCOL} if HIGHEST_PROTOCOL >= 5 and buffer_callback is not None: dump_kwargs["buffer_callback"] = buffers.append try: diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 7aa728950d9..e3351709a62 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -54,11 +54,15 @@ def dask_loads(header, frames): return loads(header, frames) -def pickle_dumps(x): +def pickle_dumps(x, context=None): header = {"serializer": "pickle"} frames = [None] buffer_callback = lambda f: frames.append(memoryview(f)) - frames[0] = pickle.dumps(x, buffer_callback=buffer_callback) + frames[0] = pickle.dumps( + x, + buffer_callback=buffer_callback, + protocol=context.get("pickle-protocol", None) if context else None, + ) return header, frames