diff --git a/CHANGES/9649.feature.rst b/CHANGES/9649.feature.rst new file mode 120000 index 00000000000..a93584bccd8 --- /dev/null +++ b/CHANGES/9649.feature.rst @@ -0,0 +1 @@ +9543.feature.rst \ No newline at end of file diff --git a/aiohttp/_websocket/reader_c.pxd b/aiohttp/_websocket/reader_c.pxd index af26d350db3..2a60f327061 100644 --- a/aiohttp/_websocket/reader_c.pxd +++ b/aiohttp/_websocket/reader_c.pxd @@ -45,7 +45,8 @@ cdef class WebSocketReader: cdef object _opcode cdef object _frame_fin cdef object _frame_opcode - cdef bytearray _frame_payload + cdef object _frame_payload + cdef unsigned int _frame_payload_len cdef bytes _tail cdef bint _has_mask @@ -74,9 +75,9 @@ cdef class WebSocketReader: chunk_size="unsigned int", chunk_len="unsigned int", buf_length="unsigned int", - payload=bytearray, first_byte="unsigned char", second_byte="unsigned char", + end_pos="unsigned int", has_mask=bint, fin=bint, ) diff --git a/aiohttp/_websocket/reader_py.py b/aiohttp/_websocket/reader_py.py index 2c77cde4c72..0910a340629 100644 --- a/aiohttp/_websocket/reader_py.py +++ b/aiohttp/_websocket/reader_py.py @@ -55,7 +55,8 @@ def __init__( self._opcode: Optional[int] = None self._frame_fin = False self._frame_opcode: Optional[int] = None - self._frame_payload = bytearray() + self._frame_payload: Union[bytes, bytearray] = b"" + self._frame_payload_len = 0 self._tail: bytes = b"" self._has_mask = False @@ -133,6 +134,7 @@ def _feed_data(self, data: bytes) -> None: "to be zero, got {!r}".format(opcode), ) + assembled_payload: Union[bytes, bytearray] if has_partial: assembled_payload = self._partial + payload self._partial.clear() @@ -165,6 +167,8 @@ def _feed_data(self, data: bytes) -> None: self._max_msg_size + left, self._max_msg_size ), ) + elif type(assembled_payload) is bytes: + payload_merged = assembled_payload else: payload_merged = bytes(assembled_payload) @@ -229,9 +233,11 @@ def _feed_data(self, data: bytes) -> None: def parse_frame( self, buf: bytes - ) -> List[Tuple[bool, Optional[int], bytearray, Optional[bool]]]: + ) -> List[Tuple[bool, Optional[int], Union[bytes, bytearray], Optional[bool]]]: """Return the next frame from the socket.""" - frames: List[Tuple[bool, Optional[int], bytearray, Optional[bool]]] = [] + frames: List[ + Tuple[bool, Optional[int], Union[bytes, bytearray], Optional[bool]] + ] = [] if self._tail: buf, self._tail = self._tail + buf, b"" @@ -333,30 +339,44 @@ def parse_frame( self._state = READ_PAYLOAD if self._state == READ_PAYLOAD: - length = self._payload_length - payload = self._frame_payload - chunk_len = buf_length - start_pos - if length >= chunk_len: - self._payload_length = length - chunk_len - payload += buf[start_pos:] - start_pos = buf_length + if self._payload_length >= chunk_len: + end_pos = buf_length + self._payload_length -= chunk_len else: + end_pos = start_pos + self._payload_length self._payload_length = 0 - payload += buf[start_pos : start_pos + length] - start_pos = start_pos + length + + if self._frame_payload_len: + if type(self._frame_payload) is not bytearray: + self._frame_payload = bytearray(self._frame_payload) + self._frame_payload += buf[start_pos:end_pos] + else: + # Fast path for the first frame + self._frame_payload = buf[start_pos:end_pos] + + self._frame_payload_len += end_pos - start_pos + start_pos = end_pos if self._payload_length != 0: break if self._has_mask: assert self._frame_mask is not None - websocket_mask(self._frame_mask, payload) + if type(self._frame_payload) is not bytearray: + self._frame_payload = bytearray(self._frame_payload) + websocket_mask(self._frame_mask, self._frame_payload) frames.append( - (self._frame_fin, self._frame_opcode, payload, self._compressed) + ( + self._frame_fin, + self._frame_opcode, + self._frame_payload, + self._compressed, + ) ) - self._frame_payload = bytearray() + self._frame_payload = b"" + self._frame_payload_len = 0 self._state = READ_HEADER self._tail = buf[start_pos:] if start_pos < buf_length else b"" diff --git a/tests/test_websocket_parser.py b/tests/test_websocket_parser.py index abddeadf5a1..d034245af7c 100644 --- a/tests/test_websocket_parser.py +++ b/tests/test_websocket_parser.py @@ -139,6 +139,20 @@ def test_parse_frame_length2_multi_byte(parser: WebSocketReader) -> None: assert (0, 1, expected_payload, False) == (fin, opcode, payload, not not compress) +def test_parse_frame_length2_multi_byte_multi_packet(parser: WebSocketReader) -> None: + """Ensure a multi-byte length with multiple packets is parsed correctly.""" + expected_payload = b"1" * 32768 + assert parser.parse_frame(struct.pack("!BB", 0b00000001, 126)) == [] + assert parser.parse_frame(struct.pack("!H", 32768)) == [] + assert parser.parse_frame(b"1" * 8192) == [] + assert parser.parse_frame(b"1" * 8192) == [] + assert parser.parse_frame(b"1" * 8192) == [] + res = parser.parse_frame(b"1" * 8192) + fin, opcode, payload, compress = res[0] + assert len(payload) == 32768 + assert (0, 1, expected_payload, False) == (fin, opcode, payload, not not compress) + + def test_parse_frame_length4(parser: WebSocketReader) -> None: parser.parse_frame(struct.pack("!BB", 0b00000001, 127)) parser.parse_frame(struct.pack("!Q", 4))