some sleep in tcp
fjetter committed Nov 4, 2021
1 parent 7adc0c0 commit 5795c5b
Showing 4 changed files with 116 additions and 29 deletions.
58 changes: 46 additions & 12 deletions distributed/batched.py
@@ -1,12 +1,14 @@
import logging
from collections import deque
from uuid import uuid4

from tornado import gen, locks
from tornado.ioloop import IOLoop

import dask
from dask.utils import parse_timedelta

from .comm import Comm
from .core import CommClosedError

logger = logging.getLogger(__name__)
@@ -43,6 +45,7 @@ def __init__(self, interval, loop=None, serializers=None):
self.interval = parse_timedelta(interval, default="ms")
self.waker = locks.Event()
self.stopped = locks.Event()
self.stopped.set()
self.please_stop = False
self.buffer = []
self.comm = None
@@ -56,13 +59,33 @@ def __init__(self, interval, loop=None, serializers=None):
self.serializers = serializers
self._consecutive_failures = 0

def start(self, comm):
self.comm = comm
self.please_stop = False
self.loop.add_callback(self._background_send)
def start(self, comm: Comm):
"""
Start the BatchedSend by providing an open Comm object.
Calling this again on an already started BatchedSend will raise a
`RuntimeError` if the provided Comm is different from the current one. If
the provided Comm is identical, this is a no-op.
If the BatchedSend was already closed, this will use the newly
provided Comm to submit any accumulated messages in the buffer.
"""
if self.closed():
if comm.closed():
raise RuntimeError("Comm already closed.")
self.comm = comm
self.please_stop = False
self.loop.add_callback(self._background_send)
elif self.comm is not comm:
raise RuntimeError("BatchedSend already started.")

def closed(self):
return self.comm and self.comm.closed()
"""True if the BatchedSend hasn't been started or has been closed
already."""
if self.comm is None or self.comm.closed():
return True
else:
return False

def __repr__(self):
if self.closed():
@@ -99,7 +122,8 @@ def _background_send(self):
else:
self.recent_message_log.append("large-message")
self.byte_count += nbytes
payload = [] # lose ref

payload.clear() # lose ref
except CommClosedError:
logger.info("Batched Comm Closed %r", self.comm, exc_info=True)
break
@@ -121,21 +145,27 @@ def _background_send(self):
self.stopped.set()
return

self.stopped.set()
# If we've reached here, it means `break` was hit above and
# there was an exception when using `comm`.
# We can't close gracefully via `.close()` since we can't send messages.
# So we just abort.
# To propagate exceptions, we rely on subsequent `BatchedSend.send`
# calls to raise CommClosedErrors.
self.stopped.set()
self.abort()

if self.comm:
self.comm.abort()
yield self.close()

def send(self, *msgs):
"""Schedule a message for sending to the other side
This completes quickly and synchronously
This completes quickly and synchronously.
If the BatchedSend or Comm is already closed, this raises a
CommClosedError and does not accept any further messages into the buffer.
"""
if self.comm is not None and self.comm.closed():
if self.closed():
raise CommClosedError(f"Comm {self.comm!r} already closed.")

self.message_count += len(msgs)
@@ -146,7 +176,7 @@ def send(self, *msgs):

@gen.coroutine
def close(self, timeout=None):
"""Flush existing messages and then close comm
"""Flush existing messages and then close Comm
If set, raises `tornado.util.TimeoutError` after a timeout.
"""
@@ -155,8 +185,8 @@ def close(self, timeout=None):
self.please_stop = True
self.waker.set()
yield self.stopped.wait(timeout=timeout)
payload = []
if not self.comm.closed():
payload = []
try:
if self.buffer:
self.buffer, payload = [], self.buffer
@@ -170,9 +200,13 @@ def abort(self):
yield self.comm.close()

def abort(self):
"""Close the BatchedSend immediately, without waiting for any pending
operations to complete. Buffered data will be lost."""
if self.comm is None:
return
self.buffer = []
self.please_stop = True
self.waker.set()
if not self.comm.closed():
self.comm.abort()
self.comm = None
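
For reference, a rough usage sketch of the reworked start/closed semantics above: an already-closed BatchedSend keeps its buffer, and handing it a fresh Comm via start() resubmits the buffered messages, while calling start() with a different open Comm raises RuntimeError. The connect helper and address below are illustrative assumptions, not part of this diff.

from distributed.batched import BatchedSend
from distributed.comm import connect


async def demo(address):
    # Rough sketch only; ``connect`` and ``address`` are illustrative.
    bs = BatchedSend(interval="25ms")
    bs.start(await connect(address))   # first start attaches the Comm
    bs.send({"op": "ping"})            # queued, flushed on the next interval

    # If the connection dropped, closed() is True but unsent messages stay
    # in the buffer; restarting with a fresh Comm resubmits them.
    if bs.closed():
        bs.start(await connect(address))

    await bs.close()                   # flush whatever is left, then close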
14 changes: 13 additions & 1 deletion distributed/comm/tcp.py
@@ -1,3 +1,5 @@
import asyncio

import ctypes
import errno
import functools
@@ -290,7 +292,15 @@ async def write(self, msg, serializers=None, on_error="message"):
stream._total_write_index += each_frame_nbytes

# start writing frames
stream.write(b"")
await stream.write(b"")
# FIXME: How do I test this? Why is the stream closed _sometimes_?
# Diving into tornado, so far, I can only confirm that once the
# write future has been awaited, the entire buffer has been written
# to the socket. Not sure if one loop iteration is sufficient in
# general or just sufficient for the local tests I've been running
await asyncio.sleep(0)
if stream.closed():
raise StreamClosedError()
except StreamClosedError as e:
self.stream = None
self._closed = True
@@ -333,6 +343,8 @@ def abort(self):
stream.close()

def closed(self):
if self.stream and self.stream.closed():
self.abort()
return self._closed

@property
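
The awaited write plus asyncio.sleep(0) above appears intended to give the IOLoop one iteration after the buffer has been handed to the socket, so that a connection closed by the peer surfaces as StreamClosedError at the write call site rather than on a later read. A standalone sketch of that pattern against a plain tornado IOStream; the host and port are illustrative assumptions:

import asyncio
import socket

from tornado.iostream import IOStream, StreamClosedError


async def checked_write(stream, data):
    # Awaiting write() resolves once the data has been written to the socket...
    await stream.write(data)
    # ...then yield one loop iteration so a peer-side close can be observed.
    await asyncio.sleep(0)
    if stream.closed():
        raise StreamClosedError()


async def main():
    stream = IOStream(socket.socket())
    await stream.connect(("127.0.0.1", 8786))  # illustrative address
    await checked_write(stream, b"ping")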
2 changes: 1 addition & 1 deletion distributed/scheduler.py
@@ -5510,7 +5510,7 @@ async def handle_worker(self, comm=None, worker=None):
await self.handle_stream(comm=comm, extra={"worker": worker})
finally:
if worker in self.stream_comms:
worker_comm.abort()
await worker_comm.close()
await self.remove_worker(address=worker)

def add_plugin(
71 changes: 56 additions & 15 deletions distributed/tests/test_worker.py
@@ -3197,35 +3197,76 @@ async def test_deadlock_cancelled_after_inflight_before_gather_from_worker(


@gen_cluster(nthreads=[("", 1)])
async def test_dont_loose_payload_reconnect(s, w):
@pytest.mark.parametrize(
"sender",
[
"worker",
"scheduler",
],
)
async def test_dont_loose_payload_reconnect_worker_sends(s, w, sender):
"""Ensure that payload of a BatchedSend is not lost if a worker reconnects"""
s.count = 0
while w.heartbeat_active:
await asyncio.sleep(0.1)
w.heartbeat_active = True

if sender == "worker":
sender = w
sender_stream = w.batched_stream
receiver = s
receiver_stream = s.stream_comms[w.address]
else:
sender = s
sender_stream = s.stream_comms[w.address]
receiver = w
receiver_stream = w.batched_stream

receiver.count = 0

def receive(worker, msg):
s.count += 1
receiver.count += 1

receiver.stream_handlers["receive-msg"] = receive

s.stream_handlers["receive-msg"] = receive
w.batched_stream.next_deadline = w.loop.time() + 10_000
# Wait until the buffer is empty so that we can start cleanly (e.g.
# heartbeats, status updates, etc.)
while sender_stream.buffer:
await asyncio.sleep(0.01)

for x in range(100):
w.batched_stream.send({"op": "receive-msg", "msg": x})
sender_stream.send({"op": "receive-msg", "msg": x})

receiver_stream.comm.abort()

# batch_count increases with every attempt, so if it has increased we
# know the background send ran at least once
before_batch_count = sender_stream.batch_count

await s.stream_comms[w.address].comm.close()
while not w.batched_stream.comm.closed():
before = sender_stream.buffer.copy()
assert len(sender_stream.buffer) == 100

while sender_stream.batch_count == before_batch_count:
await asyncio.sleep(0.1)
before = w.batched_stream.buffer.copy()
w.batched_stream.next_deadline = w.loop.time()
assert len(w.batched_stream.buffer) == 100
with captured_logger("distributed.batched") as caplog:
await w.batched_stream._background_send()

assert "Batched Comm Closed" in caplog.getvalue()
after = w.batched_stream.buffer.copy()
# At the time of send, we already know it has failed and the caller should
# handle this exception and trigger a reconnect
# TODO: Is the transition engine robust to this??
new_message = {"op": "receive-msg", "msg": 100}
with pytest.raises(CommClosedError):
sender_stream.send(new_message)
assert new_message not in sender_stream.buffer

after = sender_stream.buffer.copy()

# Payload that couldn't be submitted is prepended
assert len(after) >= len(before)
assert after[: len(before)] == before

# No message received, yet
assert s.count == 0

# Now, reconnect and everything should stabilize again
w.heartbeat_active = False
await w.heartbeat()
while not s.count == 100:
await asyncio.sleep(0.1)
