Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to asyncio streams API #869

Merged
merged 6 commits into from
May 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added uvicorn/_handlers/__init__.py
Empty file.
99 changes: 99 additions & 0 deletions uvicorn/_handlers/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import asyncio
from typing import TYPE_CHECKING

from uvicorn.config import Config

if TYPE_CHECKING: # pragma: no cover
from uvicorn.server import ServerState


async def handle_http(
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
server_state: "ServerState",
config: Config,
) -> None:
# Run transport/protocol session from streams.
#
# This is a bit fiddly, so let me explain why we do this in the first place.
#
# This was introduced to switch to the asyncio streams API while retaining our
# existing protocols-based code.
#
# The aim was to:
# * Make it easier to support alternative async libaries (all of which expose
# a streams API, rather than anything similar to asyncio's transports and
# protocols) while keeping the change footprint (and risk) at a minimum.
# * Keep a "fast track" for asyncio that's as efficient as possible, by reusing
# our asyncio-optimized protocols-based implementation.
#
# See: https://github.com/encode/uvicorn/issues/169
# See: https://github.com/encode/uvicorn/pull/869

# Use a future to coordinate between the protocol and this handler task.
# https://docs.python.org/3/library/asyncio-protocol.html#connecting-existing-sockets
loop = asyncio.get_event_loop()
connection_lost = loop.create_future()

# Switch the protocol from the stream reader to our own HTTP protocol class.
protocol = config.http_protocol_class(
config=config,
server_state=server_state,
on_connection_lost=lambda: connection_lost.set_result(True),
)
transport = writer.transport
transport.set_protocol(protocol)

# Asyncio stream servers don't `await` handler tasks (like the one we're currently
# running), so we must make sure exceptions that occur in protocols but outside the
# ASGI cycle (e.g. bugs) are properly retrieved and logged.
# Vanilla asyncio handles exceptions properly out-of-the-box, but uvloop doesn't.
# So we need to attach a callback to handle exceptions ourselves for that case.
# (It's not easy to know which loop we're effectively running on, so we attach the
# callback in all cases. In practice it won't be called on vanilla asyncio.)
task = _get_current_task()

@task.add_done_callback
def retrieve_exception(task: asyncio.Task) -> None:
exc = task.exception()

if exc is None:
return

loop.call_exception_handler(
{
"message": "Fatal error in server handler",
"exception": exc,
"transport": transport,
"protocol": protocol,
}
)
# Hang up the connection so the client doesn't wait forever.
transport.close()

# Kick off the HTTP protocol.
protocol.connection_made(transport)
euri10 marked this conversation as resolved.
Show resolved Hide resolved

# Pass any data already in the read buffer.
# The assumption here is that we haven't read any data off the stream reader
# yet: all data that the client might have already sent since the connection has
# been established is in the `_buffer`.
data = reader._buffer # type: ignore
if data:
protocol.data_received(data)

# Let the transport run in the background. When closed, this future will complete
# and we'll exit here.
await connection_lost


def _get_current_task() -> asyncio.Task:
try:
current_task = asyncio.current_task
except AttributeError: # pragma: no cover
# Python 3.6.
current_task = asyncio.Task.current_task

task = current_task()
assert task is not None
return task
13 changes: 11 additions & 2 deletions uvicorn/protocols/http/h11_impl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import http
import logging
from typing import Callable
from urllib.parse import unquote

import h11
Expand Down Expand Up @@ -34,12 +35,15 @@ def _get_status_phrase(status_code):


class H11Protocol(asyncio.Protocol):
def __init__(self, config, server_state, _loop=None):
def __init__(
self, config, server_state, on_connection_lost: Callable = None, _loop=None
):
if not config.loaded:
config.load()

self.config = config
self.app = config.loaded_app
self.on_connection_lost = on_connection_lost
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.access_logger = logging.getLogger("uvicorn.access")
Expand Down Expand Up @@ -107,6 +111,9 @@ def connection_lost(self, exc):
if self.flow is not None:
self.flow.resume_writing()

if self.on_connection_lost is not None:
self.on_connection_lost()

def eof_received(self):
pass

Expand Down Expand Up @@ -253,7 +260,9 @@ def handle_upgrade(self, event):
output += [name, b": ", value, b"\r\n"]
output.append(b"\r\n")
protocol = self.ws_protocol_class(
config=self.config, server_state=self.server_state
config=self.config,
server_state=self.server_state,
on_connection_lost=self.on_connection_lost,
)
protocol.connection_made(self.transport)
protocol.data_received(b"".join(output))
Expand Down
13 changes: 11 additions & 2 deletions uvicorn/protocols/http/httptools_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import re
import urllib
from typing import Callable

import httptools

Expand Down Expand Up @@ -39,12 +40,15 @@ def _get_status_line(status_code):


class HttpToolsProtocol(asyncio.Protocol):
def __init__(self, config, server_state, _loop=None):
def __init__(
self, config, server_state, on_connection_lost: Callable = None, _loop=None
):
if not config.loaded:
config.load()

self.config = config
self.app = config.loaded_app
self.on_connection_lost = on_connection_lost
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.access_logger = logging.getLogger("uvicorn.access")
Expand Down Expand Up @@ -107,6 +111,9 @@ def connection_lost(self, exc):
if self.flow is not None:
self.flow.resume_writing()

if self.on_connection_lost is not None:
self.on_connection_lost()

def eof_received(self):
pass

Expand Down Expand Up @@ -166,7 +173,9 @@ def handle_upgrade(self):
output += [name, b": ", value, b"\r\n"]
output.append(b"\r\n")
protocol = self.ws_protocol_class(
config=self.config, server_state=self.server_state
config=self.config,
server_state=self.server_state,
on_connection_lost=self.on_connection_lost,
)
protocol.connection_made(self.transport)
protocol.data_received(b"".join(output))
Expand Down
8 changes: 7 additions & 1 deletion uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import http
import logging
from typing import Callable
from urllib.parse import unquote

import websockets
Expand All @@ -23,12 +24,15 @@ def is_serving(self):


class WebSocketProtocol(websockets.WebSocketServerProtocol):
def __init__(self, config, server_state, _loop=None):
def __init__(
self, config, server_state, on_connection_lost: Callable = None, _loop=None
):
if not config.loaded:
config.load()

self.config = config
self.app = config.loaded_app
self.on_connection_lost = on_connection_lost
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.root_path = config.root_path
Expand Down Expand Up @@ -74,6 +78,8 @@ def connection_lost(self, exc):
self.connections.remove(self)
self.handshake_completed_event.set()
super().connection_lost(exc)
if self.on_connection_lost is not None:
self.on_connection_lost()

def shutdown(self):
self.ws_server.closing = True
Expand Down
8 changes: 7 additions & 1 deletion uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
from typing import Callable
from urllib.parse import unquote

import h11
Expand All @@ -16,12 +17,15 @@


class WSProtocol(asyncio.Protocol):
def __init__(self, config, server_state, _loop=None):
def __init__(
self, config, server_state, on_connection_lost: Callable = None, _loop=None
):
if not config.loaded:
config.load()

self.config = config
self.app = config.loaded_app
self.on_connection_lost = on_connection_lost
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.root_path = config.root_path
Expand Down Expand Up @@ -65,6 +69,8 @@ def connection_lost(self, exc):
if exc is not None:
self.queue.put_nowait({"type": "websocket.disconnect"})
self.connections.remove(self)
if self.on_connection_lost is not None:
self.on_connection_lost()

def eof_received(self):
pass
Expand Down
37 changes: 20 additions & 17 deletions uvicorn/server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import functools
import logging
import os
import platform
Expand All @@ -13,6 +12,8 @@

import click

from ._handlers.http import handle_http

HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
Expand Down Expand Up @@ -77,25 +78,26 @@ async def serve(self, sockets=None):
extra={"color_message": color_message},
)

async def startup(self, sockets=None):
async def startup(self, sockets: list = None) -> None:
await self.lifespan.startup()
if self.lifespan.should_exit:
self.should_exit = True
return

config = self.config

create_protocol = functools.partial(
config.http_protocol_class, config=config, server_state=self.server_state
)

loop = asyncio.get_event_loop()
async def handler(
reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
await handle_http(
reader, writer, server_state=self.server_state, config=config
)

if sockets is not None:
# Explicitly passed a list of open sockets.
# We use this when the server is run from a Gunicorn worker.

def _share_socket(sock: socket) -> socket:
def _share_socket(sock: socket.SocketType) -> socket.SocketType:
# Windows requires the socket be explicitly shared across
# multiple workers (processes).
from socket import fromshare # type: ignore
Expand All @@ -107,17 +109,17 @@ def _share_socket(sock: socket) -> socket:
for sock in sockets:
if config.workers > 1 and platform.system() == "Windows":
sock = _share_socket(sock)
server = await loop.create_server(
create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog
server = await asyncio.start_server(
handler, sock=sock, ssl=config.ssl, backlog=config.backlog
)
self.servers.append(server)
listeners = sockets

elif config.fd is not None:
# Use an existing socket, from a file descriptor.
sock = socket.fromfd(config.fd, socket.AF_UNIX, socket.SOCK_STREAM)
server = await loop.create_server(
create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog
server = await asyncio.start_server(
handler, sock=sock, ssl=config.ssl, backlog=config.backlog
)
assert server.sockets is not None # mypy
listeners = server.sockets
Expand All @@ -128,8 +130,8 @@ def _share_socket(sock: socket) -> socket:
uds_perms = 0o666
if os.path.exists(config.uds):
uds_perms = os.stat(config.uds).st_mode
server = await loop.create_unix_server(
create_protocol, path=config.uds, ssl=config.ssl, backlog=config.backlog
server = await asyncio.start_unix_server(
handler, path=config.uds, ssl=config.ssl, backlog=config.backlog
)
os.chmod(config.uds, uds_perms)
assert server.sockets is not None # mypy
Expand All @@ -139,8 +141,8 @@ def _share_socket(sock: socket) -> socket:
else:
# Standard case. Create a socket from a host/port pair.
try:
server = await loop.create_server(
create_protocol,
server = await asyncio.start_server(
handler,
host=config.host,
port=config.port,
ssl=config.ssl,
Expand All @@ -150,7 +152,8 @@ def _share_socket(sock: socket) -> socket:
logger.error(exc)
await self.lifespan.shutdown()
sys.exit(1)
assert server.sockets is not None # mypy

assert server.sockets is not None
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved
listeners = server.sockets
self.servers = [server]

Expand Down