Skip to content

Commit

Permalink
Use a callback to close TCP Comms, rather than check every time (#4453)
Browse files Browse the repository at this point in the history
In a recent trace this relieves about 30ms of a 3s shuffle computation
resulting in around a 1% overall speedup
  • Loading branch information
mrocklin authored Feb 5, 2021
1 parent cbb88bd commit 7d2a22f
Showing 1 changed file with 21 additions and 13 deletions.
34 changes: 21 additions & 13 deletions distributed/comm/tcp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import errno
import functools
import logging
import socket
from ssl import SSLError
Expand Down Expand Up @@ -126,8 +127,16 @@ def convert_stream_closed_error(obj, exc):
raise CommClosedError("in %s: %s" % (obj, exc)) from exc


def _do_nothing():
pass
def _close_comm(ref):
"""Callback to close Dask Comm when Tornado Stream closes
Parameters
----------
ref: weak reference to a Dask comm
"""
comm = ref()
if comm:
comm._closed = True


class TCP(Comm):
Expand All @@ -136,6 +145,7 @@ class TCP(Comm):
"""

def __init__(self, stream, local_addr, peer_addr, deserialize=True):
self._closed = False
Comm.__init__(self)
self._local_addr = local_addr
self._peer_addr = peer_addr
Expand All @@ -145,18 +155,12 @@ def __init__(self, stream, local_addr, peer_addr, deserialize=True):
self._finalizer.atexit = False
self._extra = {}

ref = weakref.ref(self)

stream.set_close_callback(functools.partial(_close_comm, ref))

stream.set_nodelay(True)
set_tcp_timeout(stream)
# set a close callback, to make `self.stream.closed()` more reliable.
# Background: if `stream` is unused (e.g. because it's in `ConnectionPool.available`),
# the underlying fd is not watched for changes. In this case, even if the
# connection is actively closed by the remote end, `self.closed()` would still return `False`.
# Registering a closed callback will make tornado register the underlying fd
# for changes, and this would be reflected in `self.closed()` even without reading/writing.
#
# Use a global method (instead of a lambda) to avoid creating a reference
# to the local scope.
stream.set_close_callback(_do_nothing)
self._read_extra()

def _read_extra(self):
Expand Down Expand Up @@ -198,6 +202,7 @@ async def read(self, deserializers=None):
frames.append(frame)
except StreamClosedError as e:
self.stream = None
self._closed = True
if not shutting_down():
convert_stream_closed_error(self, e)
except Exception:
Expand Down Expand Up @@ -261,6 +266,7 @@ async def write(self, msg, serializers=None, on_error="message"):
bytes_since_last_yield = 0
except StreamClosedError as e:
self.stream = None
self._closed = True
if not shutting_down():
convert_stream_closed_error(self, e)
except Exception:
Expand All @@ -281,6 +287,7 @@ def close(self):
# Task was destroyed but it is pending!
# Triggered by distributed.deploy.tests.test_local::test_silent_startup
stream, self.stream = self.stream, None
self._closed = True
if stream is not None and not stream.closed():
try:
# Flush the stream's write buffer by waiting for a last write.
Expand All @@ -295,12 +302,13 @@ def close(self):

def abort(self):
stream, self.stream = self.stream, None
self._closed = True
if stream is not None and not stream.closed():
self._finalizer.detach()
stream.close()

def closed(self):
return self.stream is None or self.stream.closed()
return self._closed

@property
def extra_info(self):
Expand Down

0 comments on commit 7d2a22f

Please sign in to comment.