Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[PR #9649/f2f5b056 backport][3.11] Avoid memory copy in the WebSocket reader for small payloads #9650

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/9649.feature.rst
5 changes: 3 additions & 2 deletions aiohttp/_websocket/reader_c.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ cdef class WebSocketReader:
cdef object _opcode
cdef object _frame_fin
cdef object _frame_opcode
cdef bytearray _frame_payload
cdef object _frame_payload
cdef unsigned int _frame_payload_len

cdef bytes _tail
cdef bint _has_mask
Expand Down Expand Up @@ -74,9 +75,9 @@ cdef class WebSocketReader:
chunk_size="unsigned int",
chunk_len="unsigned int",
buf_length="unsigned int",
payload=bytearray,
first_byte="unsigned char",
second_byte="unsigned char",
end_pos="unsigned int",
has_mask=bint,
fin=bint,
)
Expand Down
50 changes: 35 additions & 15 deletions aiohttp/_websocket/reader_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def __init__(
self._opcode: Optional[int] = None
self._frame_fin = False
self._frame_opcode: Optional[int] = None
self._frame_payload = bytearray()
self._frame_payload: Union[bytes, bytearray] = b""
self._frame_payload_len = 0

self._tail: bytes = b""
self._has_mask = False
Expand Down Expand Up @@ -133,6 +134,7 @@ def _feed_data(self, data: bytes) -> None:
"to be zero, got {!r}".format(opcode),
)

assembled_payload: Union[bytes, bytearray]
if has_partial:
assembled_payload = self._partial + payload
self._partial.clear()
Expand Down Expand Up @@ -165,6 +167,8 @@ def _feed_data(self, data: bytes) -> None:
self._max_msg_size + left, self._max_msg_size
),
)
elif type(assembled_payload) is bytes:
payload_merged = assembled_payload
else:
payload_merged = bytes(assembled_payload)

Expand Down Expand Up @@ -229,9 +233,11 @@ def _feed_data(self, data: bytes) -> None:

def parse_frame(
self, buf: bytes
) -> List[Tuple[bool, Optional[int], bytearray, Optional[bool]]]:
) -> List[Tuple[bool, Optional[int], Union[bytes, bytearray], Optional[bool]]]:
"""Return the next frame from the socket."""
frames: List[Tuple[bool, Optional[int], bytearray, Optional[bool]]] = []
frames: List[
Tuple[bool, Optional[int], Union[bytes, bytearray], Optional[bool]]
] = []
if self._tail:
buf, self._tail = self._tail + buf, b""

Expand Down Expand Up @@ -333,30 +339,44 @@ def parse_frame(
self._state = READ_PAYLOAD

if self._state == READ_PAYLOAD:
length = self._payload_length
payload = self._frame_payload

chunk_len = buf_length - start_pos
if length >= chunk_len:
self._payload_length = length - chunk_len
payload += buf[start_pos:]
start_pos = buf_length
if self._payload_length >= chunk_len:
end_pos = buf_length
self._payload_length -= chunk_len
else:
end_pos = start_pos + self._payload_length
self._payload_length = 0
payload += buf[start_pos : start_pos + length]
start_pos = start_pos + length

if self._frame_payload_len:
if type(self._frame_payload) is not bytearray:
self._frame_payload = bytearray(self._frame_payload)
self._frame_payload += buf[start_pos:end_pos]
else:
# Fast path for the first frame
self._frame_payload = buf[start_pos:end_pos]

self._frame_payload_len += end_pos - start_pos
start_pos = end_pos

if self._payload_length != 0:
break

if self._has_mask:
assert self._frame_mask is not None
websocket_mask(self._frame_mask, payload)
if type(self._frame_payload) is not bytearray:
self._frame_payload = bytearray(self._frame_payload)
websocket_mask(self._frame_mask, self._frame_payload)

frames.append(
(self._frame_fin, self._frame_opcode, payload, self._compressed)
(
self._frame_fin,
self._frame_opcode,
self._frame_payload,
self._compressed,
)
)
self._frame_payload = bytearray()
self._frame_payload = b""
self._frame_payload_len = 0
self._state = READ_HEADER

self._tail = buf[start_pos:] if start_pos < buf_length else b""
Expand Down
14 changes: 14 additions & 0 deletions tests/test_websocket_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,20 @@ def test_parse_frame_length2_multi_byte(parser: WebSocketReader) -> None:
assert (0, 1, expected_payload, False) == (fin, opcode, payload, not not compress)


def test_parse_frame_length2_multi_byte_multi_packet(parser: WebSocketReader) -> None:
    """Verify a frame with a two-byte length field is reassembled across packets.

    The two header bytes, the 16-bit length and the 32768-byte payload are
    fed to the parser in separate calls. No frame may be returned until the
    final payload chunk has been supplied, at which point the complete
    payload must come back in a single frame.
    """
    total_size = 32768
    chunk = b"1" * 8192

    # Header and length arrive on their own; neither completes a frame.
    assert parser.parse_frame(struct.pack("!BB", 0b00000001, 126)) == []
    assert parser.parse_frame(struct.pack("!H", total_size)) == []

    # The first three of the four payload chunks still leave the frame open.
    for _ in range(3):
        assert parser.parse_frame(chunk) == []

    # The fourth chunk completes the payload and yields the frame.
    frames = parser.parse_frame(chunk)
    fin, opcode, payload, compress = frames[0]
    assert len(payload) == total_size
    assert (0, 1, b"1" * total_size, False) == (fin, opcode, payload, bool(compress))


def test_parse_frame_length4(parser: WebSocketReader) -> None:
parser.parse_frame(struct.pack("!BB", 0b00000001, 127))
parser.parse_frame(struct.pack("!Q", 4))
Expand Down
Loading