Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix support for connection Upgrade and CONNECT when some data in the stream has been read. #882

Merged
merged 10 commits into from
Feb 20, 2024
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## Unreleased

- Support connection Upgrade and CONNECT. (#882)
MtkN1 marked this conversation as resolved.
Show resolved Hide resolved

## 1.0.3 (February 13th, 2024)

- Fix support for async cancellations. (#880)
Expand Down
50 changes: 47 additions & 3 deletions httpcore/_async/http11.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import enum
import logging
import ssl
import time
from types import TracebackType
from typing import (
Any,
AsyncIterable,
AsyncIterator,
List,
Expand Down Expand Up @@ -107,6 +109,7 @@ async def handle_async_request(self, request: Request) -> Response:
status,
reason_phrase,
headers,
trailing_data,
) = await self._receive_response_headers(**kwargs)
trace.return_value = (
http_version,
Expand All @@ -115,14 +118,22 @@ async def handle_async_request(self, request: Request) -> Response:
headers,
)

network_stream = self._network_stream

# CONNECT or Upgrade request
if (status == 101) or (
(request.method == b"CONNECT") and (200 <= status < 300)
):
network_stream = AsyncHTTP11UpgradeStream(network_stream, trailing_data)

return Response(
status=status,
headers=headers,
content=HTTP11ConnectionByteStream(self, request),
extensions={
"http_version": http_version,
"reason_phrase": reason_phrase,
"network_stream": self._network_stream,
"network_stream": network_stream,
},
)
except BaseException as exc:
Expand Down Expand Up @@ -167,7 +178,7 @@ async def _send_event(

async def _receive_response_headers(
self, request: Request
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]:
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], bytes]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)

Expand All @@ -187,7 +198,9 @@ async def _receive_response_headers(
# raw header casing, rather than the enforced lowercase headers.
headers = event.headers.raw_items()

return http_version, event.status_code, event.reason, headers
trailing_data, _ = self._h11_state.trailing_data

return http_version, event.status_code, event.reason, headers, trailing_data

async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]:
timeouts = request.extensions.get("timeout", {})
Expand Down Expand Up @@ -340,3 +353,34 @@ async def aclose(self) -> None:
self._closed = True
async with Trace("response_closed", logger, self._request):
await self._connection._response_closed()


class AsyncHTTP11UpgradeStream(AsyncNetworkStream):
def __init__(self, stream: AsyncNetworkStream, leading_data: bytes) -> None:
self._stream = stream
self._leading_data = leading_data

async def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
if self._leading_data:
buffer = self._leading_data[:max_bytes]
self._leading_data = self._leading_data[max_bytes:]
return buffer
else:
return await self._stream.read(max_bytes, timeout)

async def write(self, buffer: bytes, timeout: Optional[float] = None) -> None:
await self._stream.write(buffer, timeout)

async def aclose(self) -> None:
await self._stream.aclose()

async def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: Optional[str] = None,
timeout: Optional[float] = None,
) -> AsyncNetworkStream:
return await self._stream.start_tls(ssl_context, server_hostname, timeout)

def get_extra_info(self, info: str) -> Any:
return self._stream.get_extra_info(info)
50 changes: 47 additions & 3 deletions httpcore/_sync/http11.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import enum
import logging
import ssl
import time
from types import TracebackType
from typing import (
Any,
Iterable,
Iterator,
List,
Expand Down Expand Up @@ -107,6 +109,7 @@ def handle_request(self, request: Request) -> Response:
status,
reason_phrase,
headers,
trailing_data,
) = self._receive_response_headers(**kwargs)
trace.return_value = (
http_version,
Expand All @@ -115,14 +118,22 @@ def handle_request(self, request: Request) -> Response:
headers,
)

network_stream = self._network_stream

# CONNECT or Upgrade request
if (status == 101) or (
(request.method == b"CONNECT") and (200 <= status < 300)
):
network_stream = HTTP11UpgradeStream(network_stream, trailing_data)

return Response(
status=status,
headers=headers,
content=HTTP11ConnectionByteStream(self, request),
extensions={
"http_version": http_version,
"reason_phrase": reason_phrase,
"network_stream": self._network_stream,
"network_stream": network_stream,
},
)
except BaseException as exc:
Expand Down Expand Up @@ -167,7 +178,7 @@ def _send_event(

def _receive_response_headers(
self, request: Request
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]:
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], bytes]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)

Expand All @@ -187,7 +198,9 @@ def _receive_response_headers(
# raw header casing, rather than the enforced lowercase headers.
headers = event.headers.raw_items()

return http_version, event.status_code, event.reason, headers
trailing_data, _ = self._h11_state.trailing_data

return http_version, event.status_code, event.reason, headers, trailing_data

def _receive_response_body(self, request: Request) -> Iterator[bytes]:
timeouts = request.extensions.get("timeout", {})
Expand Down Expand Up @@ -340,3 +353,34 @@ def close(self) -> None:
self._closed = True
with Trace("response_closed", logger, self._request):
self._connection._response_closed()


class HTTP11UpgradeStream(NetworkStream):
def __init__(self, stream: NetworkStream, leading_data: bytes) -> None:
self._stream = stream
self._leading_data = leading_data

def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
if self._leading_data:
buffer = self._leading_data[:max_bytes]
self._leading_data = self._leading_data[max_bytes:]
return buffer
else:
return self._stream.read(max_bytes, timeout)

def write(self, buffer: bytes, timeout: Optional[float] = None) -> None:
self._stream.write(buffer, timeout)

def close(self) -> None:
self._stream.close()

def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: Optional[str] = None,
timeout: Optional[float] = None,
) -> NetworkStream:
return self._stream.start_tls(ssl_context, server_hostname, timeout)

def get_extra_info(self, info: str) -> Any:
return self._stream.get_extra_info(info)
51 changes: 51 additions & 0 deletions tests/_async/test_http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,57 @@ async def test_http11_upgrade_connection():
assert content == b"..."


@pytest.mark.anyio
async def test_http11_upgrade_with_trailing_data():
"""
HTTP "101 Switching Protocols" indicates an upgraded connection.

In `CONNECT` and `Upgrade:` requests, we need to handover the trailing data
in the h11.Connection object.

https://h11.readthedocs.io/en/latest/api.html#switching-protocols
"""
origin = httpcore.Origin(b"wss", b"example.com", 443)
stream = httpcore.AsyncMockStream(
# The first element of this mock network stream buffer simulates networking
# in which response headers and data are received at once.
# This means that "foobar" becomes trailing data.
[
(
b"HTTP/1.1 101 Switching Protocols\r\n"
b"Connection: upgrade\r\n"
b"Upgrade: custom\r\n"
b"\r\n"
b"foobar"
),
b"baz",
]
)
async with httpcore.AsyncHTTP11Connection(
origin=origin, stream=stream, keepalive_expiry=5.0
) as conn:
async with conn.stream(
"GET",
"wss://example.com/",
headers={"Connection": "upgrade", "Upgrade": "custom"},
) as response:
assert response.status == 101
network_stream = response.extensions["network_stream"]

content = await network_stream.read(max_bytes=3)
assert content == b"foo"
content = await network_stream.read(max_bytes=3)
assert content == b"bar"
content = await network_stream.read(max_bytes=3)
assert content == b"baz"

# Lazy tests for HTTP11UpgradeStream
MtkN1 marked this conversation as resolved.
Show resolved Hide resolved
await network_stream.write(b"spam")
invalid = network_stream.get_extra_info("invalid")
assert invalid is None
await network_stream.aclose()


@pytest.mark.anyio
async def test_http11_early_hints():
"""
Expand Down
51 changes: 51 additions & 0 deletions tests/_sync/test_http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,57 @@ def test_http11_upgrade_connection():



def test_http11_upgrade_with_trailing_data():
"""
HTTP "101 Switching Protocols" indicates an upgraded connection.

In `CONNECT` and `Upgrade:` requests, we need to handover the trailing data
in the h11.Connection object.

https://h11.readthedocs.io/en/latest/api.html#switching-protocols
"""
origin = httpcore.Origin(b"wss", b"example.com", 443)
stream = httpcore.MockStream(
# The first element of this mock network stream buffer simulates networking
# in which response headers and data are received at once.
# This means that "foobar" becomes trailing data.
[
(
b"HTTP/1.1 101 Switching Protocols\r\n"
b"Connection: upgrade\r\n"
b"Upgrade: custom\r\n"
b"\r\n"
b"foobar"
),
b"baz",
]
)
with httpcore.HTTP11Connection(
origin=origin, stream=stream, keepalive_expiry=5.0
) as conn:
with conn.stream(
"GET",
"wss://example.com/",
headers={"Connection": "upgrade", "Upgrade": "custom"},
) as response:
assert response.status == 101
network_stream = response.extensions["network_stream"]

content = network_stream.read(max_bytes=3)
assert content == b"foo"
content = network_stream.read(max_bytes=3)
assert content == b"bar"
content = network_stream.read(max_bytes=3)
assert content == b"baz"

# Lazy tests for HTTP11UpgradeStream
network_stream.write(b"spam")
invalid = network_stream.get_extra_info("invalid")
assert invalid is None
network_stream.close()



def test_http11_early_hints():
"""
HTTP "103 Early Hints" is an interim response.
Expand Down
Loading