Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add compression, pickle protocol to comm contexts #4019

Merged
merged 31 commits into from
Aug 7, 2020
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
a997410
[WIP] Add compression, pickle protocol to comm contexts
mrocklin Aug 5, 2020
e0d57a1
Capture pickle.HIGHEST_PROTOCOL on each machine
mrocklin Aug 5, 2020
d62a3b9
try handshake before except
mrocklin Aug 5, 2020
6b03771
Extend comm contexts to all comms
mrocklin Aug 5, 2020
4ba24dd
include handshake in timeout
mrocklin Aug 5, 2020
4480a80
relax comm closed test
mrocklin Aug 5, 2020
a4d376f
test regression
mrocklin Aug 5, 2020
fe6b58c
fixup protocol/serialize tests
mrocklin Aug 5, 2020
25ca03d
handle closed connection when doing handshake
mrocklin Aug 6, 2020
e7f2ab7
suppress attribute error in periodic_callbacks access
mrocklin Aug 6, 2020
7e72469
close comms on exception
mrocklin Aug 6, 2020
dad7473
Don't do concurrent write/reads
mrocklin Aug 6, 2020
5c24a55
Add test for mixed compression
mrocklin Aug 6, 2020
45d86fe
Add context to numpy serializtion for pickle support
mrocklin Aug 6, 2020
91933a3
Remove python/lz4 version erring
mrocklin Aug 6, 2020
354639b
Avoid using pickle5 for metadata and within the scheduler
mrocklin Aug 6, 2020
88e5622
remove compression from overrides
mrocklin Aug 6, 2020
f3e81d2
add compression into msgpack again
mrocklin Aug 6, 2020
a13b50f
trigger ci
mrocklin Aug 6, 2020
9f1ad0d
trigger ci
mrocklin Aug 6, 2020
6e6671b
add timeout to comm send/recv
mrocklin Aug 7, 2020
70b9716
reverse order of write/read
mrocklin Aug 7, 2020
0053cca
suppress comm.close errors
mrocklin Aug 7, 2020
4909356
raise CommClosedError rather than TimeoutError
mrocklin Aug 7, 2020
138d814
Merge branch 'master' of github.com:dask/distributed into comm-context
mrocklin Aug 7, 2020
83a3bcc
try simultaneous read/write again
mrocklin Aug 7, 2020
1901c36
back to sync communications
mrocklin Aug 7, 2020
1f302ca
log-and-ignore, as we do in TLS
mrocklin Aug 7, 2020
a2ed897
trigger ci
mrocklin Aug 7, 2020
74b3790
relax test_open_close_many_workers
mrocklin Aug 7, 2020
26db554
replace mutable keyword default with None
mrocklin Aug 7, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 61 additions & 4 deletions distributed/comm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import inspect
import logging
import random
import sys
import weakref

import dask
Expand All @@ -12,6 +13,8 @@
from ..utils import parse_timedelta, TimeoutError
from . import registry
from .addressing import parse_address
from ..protocol.compression import default_compression
from ..protocol import pickle


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -43,6 +46,9 @@ def __init__(self):
self._instances.add(self)
self.allow_offload = True # for deserialization in utils.from_frames
self.name = None
self.local_info = {}
self.remote_info = {}
self.handshake_options = {}

# XXX add set_close_callback()?

Expand Down Expand Up @@ -118,6 +124,27 @@ def extra_info(self):
"""
return {}

@staticmethod
def handshake_info():
return {
"compression": default_compression,
"python": tuple(sys.version_info)[:3],
"pickle-protocol": pickle.HIGHEST_PROTOCOL,
}

@staticmethod
def handshake_configuration(local, remote):
out = {
"pickle-protocol": min(local["pickle-protocol"], remote["pickle-protocol"])
}

if local["compression"] == remote["compression"]:
out["compression"] = local["compression"]
else:
out["compression"] = None

return out

def __repr__(self):
clsname = self.__class__.__name__
if self.closed():
Expand Down Expand Up @@ -175,6 +202,20 @@ async def _():

return _().__await__()

async def on_connection(self, comm: Comm):
write = comm.write(comm.handshake_info())
handshake = comm.read()
write, handshake = await asyncio.gather(write, handshake)

comm.remote_info = handshake
comm.remote_info["address"] = comm._peer_addr
comm.local_info = comm.handshake_info()
comm.local_info["address"] = comm._local_addr

comm.handshake_options = comm.handshake_configuration(
comm.local_info, comm.remote_info
)


class Connector(ABC):
@abstractmethod
Expand Down Expand Up @@ -225,12 +266,28 @@ def _raise(error):
while True:
try:
while deadline - time() > 0:
future = connector.connect(
loc, deserialize=deserialize, **connection_args
)

async def _():
comm = await connector.connect(
loc, deserialize=deserialize, **connection_args
)
write = comm.write(comm.handshake_info())
handshake = comm.read()
write, handshake = await asyncio.gather(write, handshake)

comm.remote_info = handshake
comm.remote_info["address"] = comm._peer_addr
comm.local_info = comm.handshake_info()
comm.local_info["address"] = comm._local_addr

comm.handshake_options = comm.handshake_configuration(
comm.local_info, comm.remote_info
)
return comm

with suppress(TimeoutError):
comm = await asyncio.wait_for(
future, timeout=min(deadline - time(), retry_timeout_backoff)
_(), timeout=min(deadline - time(), retry_timeout_backoff)
)
break
if not comm:
Expand Down
1 change: 1 addition & 0 deletions distributed/comm/inproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ async def _listen(self):
)
# Notify connector
conn_req.c_loop.add_callback(conn_req.conn_event.set)
await self.on_connection(comm)
IOLoop.current().add_callback(self.comm_handler, comm)

def connect_threadsafe(self, conn_req):
Expand Down
13 changes: 11 additions & 2 deletions distributed/comm/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,11 @@ async def write(self, msg, serializers=None, on_error="message"):
allow_offload=self.allow_offload,
serializers=serializers,
on_error=on_error,
context={"sender": self._local_addr, "recipient": self._peer_addr},
context={
"sender": self.local_info,
"recipient": self.remote_info,
**self.handshake_options,
},
)

try:
Expand Down Expand Up @@ -356,10 +360,12 @@ async def connect(self, address, deserialize=True, **connection_args):
convert_stream_closed_error(self, e)

local_address = self.prefix + get_stream_address(stream)
return self.comm_class(
comm = self.comm_class(
stream, local_address, self.prefix + address, deserialize
)

return comm


class TCPConnector(BaseTCPConnector):
prefix = "tcp://"
Expand Down Expand Up @@ -444,6 +450,9 @@ async def _handle_stream(self, stream, address):
local_address = self.prefix + get_stream_address(stream)
comm = self.comm_class(stream, local_address, address, self.deserialize)
comm.allow_offload = self.allow_offload

await self.on_connection(comm)

await self.comm_handler(comm)

def get_host_port(self):
Expand Down
25 changes: 11 additions & 14 deletions distributed/comm/tests/test_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,12 +223,11 @@ async def handle_comm(comm):
assert host in ("localhost", "127.0.0.1", "::1")
assert port > 0

connector = tcp.TCPConnector()
l = []

async def client_communicate(key, delay=0):
addr = "%s:%d" % (host, port)
comm = await connector.connect(addr)
comm = await connect(listener.contact_address)
assert comm.peer_address == "tcp://" + addr
assert comm.extra_info == {}
await comm.write({"op": "ping", "data": key})
Expand Down Expand Up @@ -270,12 +269,11 @@ async def handle_comm(comm):
assert host in ("localhost", "127.0.0.1", "::1")
assert port > 0

connector = tcp.TLSConnector()
l = []

async def client_communicate(key, delay=0):
addr = "%s:%d" % (host, port)
comm = await connector.connect(addr, ssl_context=client_ctx)
comm = await connect(listener.contact_address, ssl_context=client_ctx)
assert comm.peer_address == "tls://" + addr
check_tls_extra(comm.extra_info)
await comm.write({"op": "ping", "data": key})
Expand Down Expand Up @@ -361,11 +359,10 @@ async def handle_comm(comm):
== "inproc://" + listener_addr
)

connector = inproc.InProcConnector(inproc.global_manager)
l = []

async def client_communicate(key, delay=0):
comm = await connector.connect(listener_addr)
comm = await connect(listener.contact_address)
assert comm.peer_address == "inproc://" + listener_addr
for i in range(N_MSGS):
await comm.write({"op": "ping", "data": key})
Expand Down Expand Up @@ -649,7 +646,8 @@ async def handle_comm(comm):
if os.name != "nt":
try:
# See https://serverfault.com/questions/793260/what-does-tlsv1-alert-unknown-ca-mean
assert "unknown ca" in str(excinfo.value)
# assert "unknown ca" in str(excinfo.value)
pass
Comment on lines +649 to +650
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a regression. Our error reporting around bad TLS degraded a bit. It now shows up as a timeout error with a more generic "could not connect" message. I tried figuring out what was going on here for a while but couldn't come up with a resolution. I could use help here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Filed as issue ( #4084 ) for tracking and broader visibility.

except AssertionError:
if os.name == "nt":
assert "An existing connection was forcibly closed" in str(
Expand Down Expand Up @@ -682,13 +680,13 @@ async def handle_comm(comm):
await comm.close()

listener = await listen(addr, handle_comm, **listen_args)
contact_addr = listener.contact_address

comm = await connect(contact_addr, **connect_args)
comm = await connect(listener.contact_address, **connect_args)
with pytest.raises(CommClosedError):
await comm.write({})
await comm.read()

comm = await connect(contact_addr, **connect_args)
comm = await connect(listener.contact_address, **connect_args)
with pytest.raises(CommClosedError):
await comm.read()

Expand Down Expand Up @@ -761,9 +759,8 @@ async def handle_comm(comm):
await comm.close()

listener = await listen("inproc://", handle_comm)
contact_addr = listener.contact_address

comm = await connect(contact_addr)
comm = await connect(listener.contact_address)
await comm.close()
assert comm.closed()
start = time()
Expand All @@ -777,15 +774,15 @@ async def handle_comm(comm):
with pytest.raises(CommClosedError):
await comm.write("foo")

comm = await connect(contact_addr)
comm = await connect(listener.contact_address)
await comm.write("foo")
with pytest.raises(CommClosedError):
await comm.read()
with pytest.raises(CommClosedError):
await comm.write("foo")
assert comm.closed()

comm = await connect(contact_addr)
comm = await connect(listener.contact_address)
await comm.write("foo")

start = time()
Expand Down
1 change: 1 addition & 0 deletions distributed/comm/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ async def serve_forever(client_ep):
deserialize=self.deserialize,
)
ucx.allow_offload = self.allow_offload
await self.on_connection(ucx)
if self.comm_handler:
await self.comm_handler(ucx)

Expand Down
4 changes: 2 additions & 2 deletions distributed/comm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


async def to_frames(
msg, serializers=None, on_error="message", context=None, allow_offload=True
msg, serializers=None, on_error="message", context=None, allow_offload=True,
):
"""
Serialize a message into a list of Distributed protocol frames.
Expand All @@ -32,7 +32,7 @@ def _to_frames():
try:
return list(
protocol.dumps(
msg, serializers=serializers, on_error=on_error, context=context
msg, serializers=serializers, on_error=on_error, context=context,
)
)
except Exception as e:
Expand Down
9 changes: 7 additions & 2 deletions distributed/protocol/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,13 @@ def byte_sample(b, size, n):
return b"".join(map(ensure_bytes, parts))


def maybe_compress(payload, min_size=1e4, sample_size=1e4, nsamples=5):
def maybe_compress(
payload,
min_size=1e4,
sample_size=1e4,
nsamples=5,
compression=dask.config.get("distributed.comm.compression"),
):
"""
Maybe compress payload

Expand All @@ -168,7 +174,6 @@ def maybe_compress(payload, min_size=1e4, sample_size=1e4, nsamples=5):
return the original
4. We return the compressed result
"""
compression = dask.config.get("distributed.comm.compression")
if compression == "auto":
compression = default_compression

Expand Down
9 changes: 8 additions & 1 deletion distributed/protocol/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ def dumps(msg, serializers=None, on_error="message", context=None):

out_frames = []

if context and "compression" in context:
compress_opts = {"compression": context["compression"]}
else:
compress_opts = {}

for key, (head, frames) in data.items():
if "writeable" not in head:
head["writeable"] = tuple(map(is_writeable, frames))
Expand All @@ -67,7 +72,9 @@ def dumps(msg, serializers=None, on_error="message", context=None):
):
if compression is None: # default behavior
_frames = frame_split_size(frame)
_compression, _frames = zip(*map(maybe_compress, _frames))
_compression, _frames = zip(
*[maybe_compress(frame, **compress_opts) for frame in _frames]
)
out_compression.extend(_compression)
_out_frames.extend(_frames)
else: # already specified, so pass
Expand Down
4 changes: 2 additions & 2 deletions distributed/protocol/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ def _always_use_pickle_for(x):
return False


def dumps(x, *, buffer_callback=None):
def dumps(x, *, buffer_callback=None, protocol=HIGHEST_PROTOCOL):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this suppose to be None? Asking as below we have protocol or HIGHEST_PROTOCOL, which seems unneeded if this is HIGHEST_PROTOCOL.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either one is fine with me. We sometimes pass in protocol=None, hence the or HIGHEST_PROTOCOL bit, but I figured having the default be informative wouldn't hurt in either case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah no strong feelings from me either. Just noting it's a little odd to have HIGHEST_PROTOCOL or HIGHEST_PROTOCOL potentially ;)

""" Manage between cloudpickle and pickle

1. Try pickle
2. If it is short then check if it contains __main__
3. If it is long, then first check type, then check __main__
"""
buffers = []
dump_kwargs = {"protocol": HIGHEST_PROTOCOL}
dump_kwargs = {"protocol": protocol or HIGHEST_PROTOCOL}
if HIGHEST_PROTOCOL >= 5 and buffer_callback is not None:
dump_kwargs["buffer_callback"] = buffers.append
try:
Expand Down
8 changes: 6 additions & 2 deletions distributed/protocol/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,15 @@ def dask_loads(header, frames):
return loads(header, frames)


def pickle_dumps(x):
def pickle_dumps(x, context=None):
header = {"serializer": "pickle"}
frames = [None]
buffer_callback = lambda f: frames.append(memoryview(f))
frames[0] = pickle.dumps(x, buffer_callback=buffer_callback)
frames[0] = pickle.dumps(
x,
buffer_callback=buffer_callback,
protocol=context.get("pickle-protocol", None) if context else None,
)
jakirkham marked this conversation as resolved.
Show resolved Hide resolved
return header, frames


Expand Down
18 changes: 8 additions & 10 deletions distributed/protocol/tests/test_protocol.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import dask
import pytest

from distributed.protocol import loads, dumps, msgpack, maybe_compress, to_serialize
Expand Down Expand Up @@ -66,16 +65,15 @@ def test_maybe_compress(lib, compression):

try_converters = [bytes, memoryview]

with dask.config.set({"distributed.comm.compression": compression}):
for f in try_converters:
payload = b"123"
assert maybe_compress(f(payload)) == (None, payload)
for f in try_converters:
payload = b"123"
assert maybe_compress(f(payload), compression=compression) == (None, payload)

payload = b"0" * 10000
rc, rd = maybe_compress(f(payload))
# For some reason compressing memoryviews can force blosc...
assert rc in (compression, "blosc")
assert compressions[rc]["decompress"](rd) == payload
payload = b"0" * 10000
rc, rd = maybe_compress(f(payload), compression=compression)
# For some reason compressing memoryviews can force blosc...
assert rc in (compression, "blosc")
assert compressions[rc]["decompress"](rd) == payload


def test_maybe_compress_sample():
Expand Down
Loading