Skip to content

Commit

Permalink
Rework BatchedSend logic
Browse files Browse the repository at this point in the history
Does away with the timeout and looking up a private attribute on IOStream.
Refs PR dask#653.
  • Loading branch information
pitrou committed Nov 14, 2016
1 parent fd64029 commit 9b75f79
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 80 deletions.
102 changes: 47 additions & 55 deletions distributed/batched.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from __future__ import print_function, division, absolute_import

from datetime import timedelta
from functools import partial
import logging
from timeit import default_timer

from tornado import gen
from tornado import gen, locks
from tornado.queues import Queue
from tornado.iostream import StreamClosedError
from tornado.ioloop import PeriodicCallback, IOLoop

from .core import read, write
from .utils import log_errors
from .utils import ignoring, log_errors


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -38,87 +39,78 @@ class BatchedSend(object):
def __init__(self, interval, loop=None):
self.loop = loop or IOLoop.current()
self.interval = interval / 1000.
self.last_transmission = 0

self.waker = locks.Event()
self.stopped = locks.Event()
self.please_stop = False
self.buffer = []
self.stream = None
self.last_payload = []
self.last_send = gen.sleep(0)
self.message_count = 0
self.batch_count = 0
self.next_deadline = None

def start(self, stream):
self.stream = stream
if self.buffer:
self.send_next()
self.loop.add_callback(self._background_send)

def __str__(self):
return '<BatchedSend: %d in buffer>' % len(self.buffer)

__repr__ = __str__

@gen.coroutine
def send_next(self, wait=True):
try:
now = default_timer()
if wait:
wait_time = min(self.last_transmission + self.interval - now,
self.interval)
yield gen.sleep(wait_time)
yield self.last_send
self.buffer, payload = [], self.buffer
self.last_payload = payload
self.last_transmission = now
def _background_send(self):
while not self.please_stop:
if self.next_deadline is None:
yield self.waker.wait()
else:
with ignoring(gen.TimeoutError):
yield self.waker.wait(self.next_deadline)
self.waker.clear()
if self.loop.time() < self.next_deadline:
# Send interval not expired yet
continue
if not self.buffer:
# Nothing to send
self.next_deadline = None
continue
payload, self.buffer = self.buffer, []
self.batch_count += 1
self.last_send = write(self.stream, payload)
except Exception as e:
logger.exception(e)
raise
try:
yield write(self.stream, payload)
except Exception:
logger.exception("Error in batched write")
break
self.next_deadline = self.loop.time() + self.interval

@gen.coroutine
def _write(self, payload):
yield gen.sleep(0)
yield write(self.stream, payload)
self.stopped.set()

def send(self, msg):
""" Send a message to the other side
""" Schedule a message for sending to the other side
This completes quickly and synchronously
"""
try:
self.message_count += 1
if self.stream is None: # not yet started
self.buffer.append(msg)
return

if self.stream._closed:
raise StreamClosedError()
if self.stream is not None and self.stream._closed:
raise StreamClosedError()

if self.buffer:
self.buffer.append(msg)
return

# If we're new and early,
now = default_timer()
if (now < self.last_transmission + self.interval
or not self.last_send._done):
self.buffer.append(msg)
self.loop.add_callback(self.send_next)
return

self.buffer.append(msg)
self.loop.add_callback(self.send_next, wait=False)
except StreamClosedError:
raise
except Exception as e:
logger.exception(e)
self.message_count += 1
self.buffer.append(msg)
self.waker.set()

@gen.coroutine
def close(self, ignore_closed=False):
""" Flush existing messages and then close stream """
if self.stream is None:
return
self.please_stop = True
self.waker.set()
yield self.stopped.wait()
try:
if self.stream._write_buffer:
yield self.last_send
if self.buffer:
if self.next_deadline is not None:
delay = self.next_deadline - self.loop.time()
if delay > 0:
yield gen.sleep(delay)
self.buffer, payload = [], self.buffer
yield write(self.stream, payload)
except StreamClosedError:
Expand Down
52 changes: 27 additions & 25 deletions distributed/tests/test_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def handle_stream(self, stream, address):
self.count += 1
yield write(stream, msg)
except StreamClosedError as e:
pass
return

def listen(self, port=0):
while True:
Expand Down Expand Up @@ -113,7 +113,6 @@ def test_BatchedSend():
assert str(len(b.buffer)) in str(b)
assert str(len(b.buffer)) in repr(b)
b.start(stream)
yield b.last_send

yield gen.sleep(0.020)

Expand All @@ -135,41 +134,30 @@ def test_send_before_start():
stream = yield client.connect('127.0.0.1', e.port)

b = BatchedSend(interval=10)
yield b.last_send

b.send('hello')
b.send('hello')
b.send('world')

b.start(stream)
result = yield read(stream); assert result == ['hello', 'hello']
result = yield read(stream); assert result == ['hello', 'world']


@gen_test()
def test_send_after_stream_start_before_stream_finish():
def test_send_after_stream_start():
with echo_server() as e:
client = TCPClient()
stream = yield client.connect('127.0.0.1', e.port)

b = BatchedSend(interval=10)
yield b.last_send

b.start(stream)
b.send('hello')
result = yield read(stream); assert result == ['hello']


@gen_test()
def test_send_after_stream_finish():
with echo_server() as e:
client = TCPClient()
stream = yield client.connect('127.0.0.1', e.port)

b = BatchedSend(interval=10)
b.start(stream)
yield b.last_send
b.send('world')
result = yield read(stream)
if len(result) < 2:
result += yield read(stream)
assert result == ['hello', 'world']

b.send('hello')
result = yield read(stream); assert result == ['hello']

@gen_test()
def test_send_before_close():
Expand All @@ -179,7 +167,6 @@ def test_send_before_close():

b = BatchedSend(interval=10)
b.start(stream)
yield b.last_send

cnt = int(e.count)
b.send('hello')
Expand All @@ -203,14 +190,31 @@ def test_close_closed():

b = BatchedSend(interval=10)
b.start(stream)
yield b.last_send

b.send(123)
stream.close() # external closing

yield b.close(ignore_closed=True)


@gen_test()
def test_close_not_started():
b = BatchedSend(interval=10)
yield b.close()


@gen_test()
def test_close_twice():
with echo_server() as e:
client = TCPClient()
stream = yield client.connect('127.0.0.1', e.port)

b = BatchedSend(interval=10)
b.start(stream)
yield b.close()
yield b.close()


@slow
@gen_test(timeout=50)
def test_stress():
Expand Down Expand Up @@ -253,14 +257,12 @@ def test_sending_traffic_jam():

b = BatchedSend(interval=0.01)
b.start(stream)
yield b.last_send

n = 50

msg = {'x': to_serialize(data)}
for i in range(n):
b.send(assoc(msg, 'i', i))
print(len(b.buffer))
yield gen.sleep(0.001)

results = []
Expand Down

0 comments on commit 9b75f79

Please sign in to comment.