Skip to content

Commit

Permalink
cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Nov 3, 2024
1 parent 51e9841 commit 57649f7
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 15 deletions.
33 changes: 18 additions & 15 deletions aiohttp/_websocket/reader_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,38 +342,41 @@ def parse_frame(
self._state = READ_PAYLOAD

if self._state == READ_PAYLOAD:
length = self._payload_length
payload = self._frame_payload

chunk_len = buf_length - start_pos
if length >= chunk_len:
self._payload_length = length - chunk_len
if self._payload_length >= chunk_len:
end_pos = buf_length
self._payload_length -= chunk_len
else:
end_pos = start_pos + self._payload_length
self._payload_length = 0
end_pos = start_pos + length

if self._frame_payload_len:
if type(payload) is not bytearray:
payload = bytearray(payload)
payload += buf[start_pos:end_pos]
if type(self._frame_payload) is not bytearray:
self._frame_payload = bytearray(self._frame_payload)
self._frame_payload += buf[start_pos:end_pos]
else:
# Fast path for the first frame
payload = buf[start_pos:end_pos]
self._frame_payload = buf[start_pos:end_pos]

self._frame_payload_len += end_pos - start_pos
start_pos = end_pos

self._frame_payload_len += length
if self._payload_length != 0:
break

if self._has_mask:
assert self._frame_mask is not None
if type(payload) is not bytearray:
payload = bytearray(payload)
websocket_mask(self._frame_mask, payload)
if type(self._frame_payload) is not bytearray:
self._frame_payload = bytearray(self._frame_payload)
websocket_mask(self._frame_mask, self._frame_payload)

frames.append(
(self._frame_fin, self._frame_opcode, payload, self._compressed)
(
self._frame_fin,
self._frame_opcode,
self._frame_payload,
self._compressed,
)
)
self._frame_payload = b""
self._frame_payload_len = 0
Expand Down
14 changes: 14 additions & 0 deletions tests/test_websocket_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,20 @@ def test_parse_frame_length2_multi_byte(parser: WebSocketReader) -> None:
assert (0, 1, expected_payload, False) == (fin, opcode, payload, not not compress)


def test_parse_frame_length2_multi_byte_multi_packet(parser: WebSocketReader) -> None:
"""Ensure a multi-byte length with multiple packets is parsed correctly."""
expected_payload = b"1" * 32768
assert parser.parse_frame(struct.pack("!BB", 0b00000001, 126)) == []
assert parser.parse_frame(struct.pack("!H", 32768)) == []
assert parser.parse_frame(b"1" * 8192) == []
assert parser.parse_frame(b"1" * 8192) == []
assert parser.parse_frame(b"1" * 8192) == []
res = parser.parse_frame(b"1" * 8192)
fin, opcode, payload, compress = res[0]
assert len(payload) == 32768
assert (0, 1, expected_payload, False) == (fin, opcode, payload, not not compress)


def test_parse_frame_length4(parser: WebSocketReader) -> None:
parser.parse_frame(struct.pack("!BB", 0b00000001, 127))
parser.parse_frame(struct.pack("!Q", 4))
Expand Down

0 comments on commit 57649f7

Please sign in to comment.