Skip to content

Commit

Permalink
move ctx to conn_info
Browse files Browse the repository at this point in the history
  • Loading branch information
ahopkins committed Mar 15, 2021
1 parent 3f82c62 commit 4e13e6e
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 30 deletions.
2 changes: 2 additions & 0 deletions sanic/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class Sanic(BaseSanic):
"""

__fake_slots__ = (
"_app_registry",
"_asgi_client",
"_blueprint_order",
"_future_routes",
Expand Down Expand Up @@ -110,6 +111,7 @@ class Sanic(BaseSanic):
"signal_router",
"sock",
"strict_slashes",
"test_mode",
"websocket_enabled",
"websocket_tasks",
)
Expand Down
4 changes: 1 addition & 3 deletions sanic/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,9 @@ async def receive_body(self):
self.body = b"".join([data async for data in self.stream])

@property
def connection(self):
def protocol(self):
if not self._protocol:
self._protocol = self.transport.get_protocol()
if not hasattr(self._protocol, "ctx"):
self._protocol.ctx = SimpleNamespace()
return self._protocol

@property
Expand Down
18 changes: 12 additions & 6 deletions sanic/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from ssl import SSLContext
from types import SimpleNamespace
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -62,24 +63,28 @@ class ConnInfo:
"""

__slots__ = (
"sockname",
"client_port",
"client",
"ctx",
"peername",
"server",
"server_port",
"client",
"client_port",
"server",
"sockname",
"ssl",
)

def __init__(self, transport: TransportProtocol, unix=None):
self.ssl: bool = bool(transport.get_extra_info("sslcontext"))
self.ctx = SimpleNamespace()
self.peername = None
self.server = self.client = ""
self.server_port = self.client_port = 0
self.peername = None
self.sockname = addr = transport.get_extra_info("sockname")
self.ssl: bool = bool(transport.get_extra_info("sslcontext"))

if isinstance(addr, str): # UNIX socket
self.server = unix or addr
return

# IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid)
if isinstance(addr, tuple):
self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]"
Expand All @@ -88,6 +93,7 @@ def __init__(self, transport: TransportProtocol, unix=None):
if addr[1] != (443 if self.ssl else 80):
self.server = f"{self.server}:{addr[1]}"
self.peername = addr = transport.get_extra_info("peername")

if isinstance(addr, tuple):
self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]"
self.client_port = addr[1]
Expand Down
35 changes: 14 additions & 21 deletions tests/test_keep_alive_timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,6 @@

PORT = 42101 # test_keep_alive_timeout_reuse doesn't work with random port

from httpcore._async.base import ConnectionState
from httpcore._async.connection import AsyncHTTPConnection
from httpcore._types import Origin


class ReusableSanicConnectionPool(httpcore.AsyncConnectionPool):
last_reused_connection = None
Expand Down Expand Up @@ -210,13 +206,13 @@ async def handler3(request):

@keep_alive_app_context.post("/ctx")
def set_ctx(request):
request.connection.ctx.foo = "hello"
request.conn_info.ctx.foo = "hello"
return text("OK")


@keep_alive_app_context.get("/ctx")
def get_ctx(request):
return text(request.connection.ctx.foo)
return text(request.conn_info.ctx.foo)


@pytest.mark.skipif(
Expand Down Expand Up @@ -256,14 +252,14 @@ def test_keep_alive_client_timeout():
asyncio.set_event_loop(loop)
client = ReuseableSanicTestClient(keep_alive_app_client_timeout, loop)
headers = {"Connection": "keep-alive"}
request, response = client.get(
"/1", headers=headers, request_keepalive=1
)
_, response = client.get("/1", headers=headers, request_keepalive=1)

assert response.status == 200
assert response.text == "OK"

loop.run_until_complete(aio_sleep(2))
exception = None
request, response = client.get("/1", request_keepalive=1)
_, response = client.get("/1", request_keepalive=1)

assert ReusableSanicConnectionPool.last_reused_connection is None
finally:
client.kill_server()
Expand All @@ -283,14 +279,14 @@ def test_keep_alive_server_timeout():
asyncio.set_event_loop(loop)
client = ReuseableSanicTestClient(keep_alive_app_server_timeout, loop)
headers = {"Connection": "keep-alive"}
request, response = client.get(
"/1", headers=headers, request_keepalive=60
)
_, response = client.get("/1", headers=headers, request_keepalive=60)

assert response.status == 200
assert response.text == "OK"

loop.run_until_complete(aio_sleep(3))
exception = None
request, response = client.get("/1", request_keepalive=60)
_, response = client.get("/1", request_keepalive=60)

assert ReusableSanicConnectionPool.last_reused_connection is None
finally:
client.kill_server()
Expand All @@ -309,15 +305,12 @@ def test_keep_alive_connection_context():
request1, _ = client.post("/ctx", headers=headers)

loop.run_until_complete(aio_sleep(1))

request2, response = client.get("/ctx")

assert response.text == "hello"
assert id(request1.connection.ctx) == id(request2.connection.ctx)
assert id(request1.conn_info.ctx) == id(request2.conn_info.ctx)
assert (
request1.connection.ctx.foo
== request2.connection.ctx.foo
== "hello"
request1.conn_info.ctx.foo == request2.conn_info.ctx.foo == "hello"
)
finally:
client.kill_server()
16 changes: 16 additions & 0 deletions tests/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from sanic import Sanic, response
from sanic.request import Request, uuid
from sanic.server import HttpProtocol


def test_no_request_id_not_called(monkeypatch):
Expand Down Expand Up @@ -83,3 +84,18 @@ async def get(request):

request, _ = app.test_client.get("/")
assert request.route is list(app.router.routes.values())[0]


def test_protocol_attribute(app):
retrieved = None

@app.get("/")
async def get(request):
nonlocal retrieved
retrieved = request.protocol
return response.empty()

headers = {"Connection": "keep-alive"}
_ = app.test_client.get("/", headers=headers)

assert isinstance(retrieved, HttpProtocol)

0 comments on commit 4e13e6e

Please sign in to comment.