Skip to content

Commit

Permalink
Refactor SSL shutdown process (#385)
Browse files Browse the repository at this point in the history
Co-authored-by: Yury Selivanov <[email protected]>
  • Loading branch information
fantix and 1st1 authored Feb 5, 2021
1 parent cdd2218 commit 98e113e
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 109 deletions.
197 changes: 143 additions & 54 deletions tests/test_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2609,14 +2609,18 @@ async def client(addr):

def test_remote_shutdown_receives_trailing_data(self):
if self.implementation == 'asyncio':
# this is an issue in asyncio
raise unittest.SkipTest()

CHUNK = 1024 * 128
SIZE = 32
CHUNK = 1024 * 16
SIZE = 8
count = 0

sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
client_sslctx = self._create_client_ssl_context()
future = None
filled = threading.Lock()
eof_received = threading.Lock()

def server(sock):
incoming = ssl.MemoryBIO()
Expand Down Expand Up @@ -2647,68 +2651,71 @@ def server(sock):
sslobj.write(b'pong')
sock.send(outgoing.read())

time.sleep(0.2) # wait for the peer to fill its backlog

# send close_notify but don't wait for response
with self.assertRaises(ssl.SSLWantReadError):
sslobj.unwrap()
sock.send(outgoing.read())

# should receive all data
data_len = 0
while True:
try:
chunk = len(sslobj.read(16384))
data_len += chunk
except ssl.SSLWantReadError:
incoming.write(sock.recv(16384))
except ssl.SSLZeroReturnError:
break

self.assertEqual(data_len, CHUNK * SIZE)

# verify that close_notify is received
sslobj.unwrap()

sock.close()
with filled:
# trigger peer's resume_writing()
incoming.write(sock.recv(65536 * 4))
while True:
try:
chunk = len(sslobj.read(16384))
data_len += chunk
except ssl.SSLWantReadError:
break

def eof_server(sock):
sock.starttls(sslctx, server_side=True)
self.assertEqual(sock.recv_all(4), b'ping')
sock.send(b'pong')
# send close_notify but don't wait for response
with self.assertRaises(ssl.SSLWantReadError):
sslobj.unwrap()
sock.send(outgoing.read())

time.sleep(0.2) # wait for the peer to fill its backlog
with eof_received:
# should receive all data
while True:
try:
chunk = len(sslobj.read(16384))
data_len += chunk
except ssl.SSLWantReadError:
incoming.write(sock.recv(16384))
except ssl.SSLZeroReturnError:
break

# send EOF
sock.shutdown(socket.SHUT_WR)
self.assertEqual(data_len, CHUNK * count)

# should receive all data
data = sock.recv_all(CHUNK * SIZE)
self.assertEqual(len(data), CHUNK * SIZE)
# verify that close_notify is received
sslobj.unwrap()

sock.close()

async def client(addr):
nonlocal future
nonlocal future, count
future = self.loop.create_future()

reader, writer = await asyncio.open_connection(
*addr,
ssl=client_sslctx,
server_hostname='')
writer.write(b'ping')
data = await reader.readexactly(4)
self.assertEqual(data, b'pong')

# fill write backlog in a hacky way - renegotiation won't help
for _ in range(SIZE):
writer.transport._test__append_write_backlog(b'x' * CHUNK)
with eof_received:
with filled:
reader, writer = await asyncio.open_connection(
*addr,
ssl=client_sslctx,
server_hostname='')
writer.write(b'ping')
data = await reader.readexactly(4)
self.assertEqual(data, b'pong')

count = 0
try:
while True:
writer.write(b'x' * CHUNK)
count += 1
await asyncio.wait_for(
asyncio.ensure_future(writer.drain()), 0.5)
except asyncio.TimeoutError:
# fill write backlog in a hacky way
for _ in range(SIZE):
writer.transport._test__append_write_backlog(
b'x' * CHUNK)
count += 1

try:
data = await reader.read()
self.assertEqual(data, b'')
except (BrokenPipeError, ConnectionResetError):
pass

await future

Expand All @@ -2728,9 +2735,6 @@ def wrapper(sock):
with self.tcp_server(run(server)) as srv:
self.loop.run_until_complete(client(srv.addr))

with self.tcp_server(run(eof_server)) as srv:
self.loop.run_until_complete(client(srv.addr))

def test_connect_timeout_warning(self):
s = socket.socket(socket.AF_INET)
s.bind(('127.0.0.1', 0))
Expand Down Expand Up @@ -2842,7 +2846,7 @@ def server(sock):
sock.shutdown(socket.SHUT_WR)
loop.call_soon_threadsafe(eof.set)
# make sure we have enough time to reproduce the issue
assert sock.recv(1024) == b''
self.assertEqual(sock.recv(1024), b'')
sock.close()

class Protocol(asyncio.Protocol):
Expand Down Expand Up @@ -2875,7 +2879,92 @@ async def client(addr):
tr.resume_reading()
await pr.fut
tr.close()
assert extra == b'extra bytes'
if self.implementation != 'asyncio':
# extra data received after transport.close() should be
# ignored - this is likely a bug in asyncio
self.assertIsNone(extra)

with self.tcp_server(server) as srv:
loop.run_until_complete(client(srv.addr))

def test_shutdown_while_pause_reading(self):
if self.implementation == 'asyncio':
raise unittest.SkipTest()

loop = self.loop
conn_made = loop.create_future()
eof_recvd = loop.create_future()
conn_lost = loop.create_future()
data_recv = False

def server(sock):
sslctx = self._create_server_ssl_context(self.ONLYCERT,
self.ONLYKEY)
incoming = ssl.MemoryBIO()
outgoing = ssl.MemoryBIO()
sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)

while True:
try:
sslobj.do_handshake()
sslobj.write(b'trailing data')
break
except ssl.SSLWantReadError:
if outgoing.pending:
sock.send(outgoing.read())
incoming.write(sock.recv(16384))
if outgoing.pending:
sock.send(outgoing.read())

while True:
try:
self.assertEqual(sslobj.read(), b'') # close_notify
break
except ssl.SSLWantReadError:
incoming.write(sock.recv(16384))

while True:
try:
sslobj.unwrap()
except ssl.SSLWantReadError:
if outgoing.pending:
sock.send(outgoing.read())
# incoming.write(sock.recv(16384))
else:
if outgoing.pending:
sock.send(outgoing.read())
break

self.assertEqual(sock.recv(16384), b'') # socket closed

class Protocol(asyncio.Protocol):
def connection_made(self, transport):
conn_made.set_result(None)

def data_received(self, data):
nonlocal data_recv
data_recv = True

def eof_received(self):
eof_recvd.set_result(None)

def connection_lost(self, exc):
if exc is None:
conn_lost.set_result(None)
else:
conn_lost.set_exception(exc)

async def client(addr):
ctx = self._create_client_ssl_context()
tr, _ = await loop.create_connection(Protocol, *addr, ssl=ctx)
await conn_made
self.assertFalse(data_recv)

tr.pause_reading()
tr.close()

await eof_recvd
await conn_lost

with self.tcp_server(server) as srv:
loop.run_until_complete(client(srv.addr))
Expand Down
1 change: 1 addition & 0 deletions uvloop/includes/stdlib.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ cdef ssl_MemoryBIO = ssl.MemoryBIO
cdef ssl_create_default_context = ssl.create_default_context
cdef ssl_SSLError = ssl.SSLError
cdef ssl_SSLAgainErrors = (ssl.SSLWantReadError, ssl.SSLSyscallError)
cdef ssl_SSLZeroReturnError = ssl.SSLZeroReturnError
cdef ssl_CertificateError = ssl.CertificateError
cdef int ssl_SSL_ERROR_WANT_READ = ssl.SSL_ERROR_WANT_READ
cdef int ssl_SSL_ERROR_WANT_WRITE = ssl.SSL_ERROR_WANT_WRITE
Expand Down
6 changes: 3 additions & 3 deletions uvloop/sslproto.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ cdef enum AppProtocolState:

cdef class _SSLProtocolTransport:
cdef:
object _loop
Loop _loop
SSLProtocol _ssl_protocol
bint _closed

Expand All @@ -41,7 +41,7 @@ cdef class SSLProtocol:
size_t _write_buffer_size

object _waiter
object _loop
Loop _loop
_SSLProtocolTransport _app_transport
bint _app_transport_created

Expand All @@ -65,7 +65,6 @@ cdef class SSLProtocol:

bint _ssl_writing_paused
bint _app_reading_paused
bint _eof_received

size_t _incoming_high_water
size_t _incoming_low_water
Expand Down Expand Up @@ -100,6 +99,7 @@ cdef class SSLProtocol:

cdef _start_shutdown(self)
cdef _check_shutdown_timeout(self)
cdef _do_read_into_void(self)
cdef _do_flush(self)
cdef _do_shutdown(self)
cdef _on_shutdown_complete(self, shutdown_exc)
Expand Down
Loading

0 comments on commit 98e113e

Please sign in to comment.