From 7b202ccf644396a4ae8ff8053578058f5023ac63 Mon Sep 17 00:00:00 2001 From: Versus Void Date: Wed, 18 Dec 2019 15:17:51 +0300 Subject: [PATCH] Restore context on listen in UVStreamServer. Fix #305 --- tests/test_context.py | 46 +++++++++++++++++++++++++++++++++ uvloop/handles/streamserver.pxd | 1 + uvloop/handles/streamserver.pyx | 5 +++- 3 files changed, 51 insertions(+), 1 deletion(-) diff --git a/tests/test_context.py b/tests/test_context.py index 4d3b12ce..ce0a456a 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -2,11 +2,23 @@ import contextvars import decimal import random +import socket import weakref from uvloop import _testbase as tb +class _Protocol(asyncio.Protocol): + def __init__(self, *, loop=None): + self.done = asyncio.Future(loop=loop) + + def connection_lost(self, exc): + if exc is None: + self.done.set_result(None) + else: + self.done.set_exception(exc) + + class _ContextBaseTests: def test_task_decimal_context(self): @@ -126,6 +138,40 @@ async def main(): del tracked self.assertIsNone(ref()) + def test_create_server_protocol_factory_context(self): + cvar = contextvars.ContextVar('cvar', default='outer') + factory_called_future = self.loop.create_future() + proto = _Protocol(loop=self.loop) + + def factory(): + try: + self.assertEqual(cvar.get(), 'inner') + except Exception as e: + factory_called_future.set_exception(e) + else: + factory_called_future.set_result(None) + + return proto + + async def test(): + cvar.set('inner') + port = tb.find_free_port() + srv = await self.loop.create_server(factory, '127.0.0.1', port) + + s = socket.socket(socket.AF_INET) + with s: + s.setblocking(False) + await self.loop.sock_connect(s, ('127.0.0.1', port)) + + try: + await factory_called_future + finally: + srv.close() + await proto.done + await srv.wait_closed() + + self.loop.run_until_complete(test()) + class Test_UV_Context(_ContextBaseTests, tb.UVTestCase): pass diff --git a/uvloop/handles/streamserver.pxd b/uvloop/handles/streamserver.pxd index b2ab1887..e2093316 100644 --- a/uvloop/handles/streamserver.pxd +++ b/uvloop/handles/streamserver.pxd @@ -7,6 +7,7 @@ cdef class UVStreamServer(UVSocketHandle): object protocol_factory bint opened Server _server + object listen_context # All "inline" methods are final diff --git a/uvloop/handles/streamserver.pyx b/uvloop/handles/streamserver.pyx index 7b2258dd..c1f4cd4e 100644 --- a/uvloop/handles/streamserver.pyx +++ b/uvloop/handles/streamserver.pyx @@ -8,6 +8,7 @@ cdef class UVStreamServer(UVSocketHandle): self.ssl_handshake_timeout = None self.ssl_shutdown_timeout = None self.protocol_factory = None + self.listen_context = None cdef inline _init(self, Loop loop, object protocol_factory, Server server, @@ -53,6 +54,8 @@ cdef class UVStreamServer(UVSocketHandle): if self.opened != 1: raise RuntimeError('unopened TCPServer') + self.listen_context = Context_CopyCurrent() + err = uv.uv_listen( self._handle, self.backlog, __uv_streamserver_on_listen) @@ -64,7 +67,7 @@ cdef class UVStreamServer(UVSocketHandle): cdef inline _on_listen(self): cdef UVStream client - protocol = self.protocol_factory() + protocol = self.listen_context.run(self.protocol_factory) if self.ssl is None: client = self._make_new_transport(protocol, None)