Skip to content

Commit

Permalink
Merge pull request #1962 from pjknkda/master
Browse files Browse the repository at this point in the history
fix issues caused by websocket frame fragmentation
  • Loading branch information
fafhrd91 authored Jun 8, 2017
2 parents 6a902ff + 03984c2 commit 6b85bcd
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 39 deletions.
2 changes: 1 addition & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Changes

-

-
- Fix websocket issues caused by frame fragmentation. #1962

-

Expand Down
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ Joongi Kim
Josep Cugat
Julia Tsemusheva
Julien Duponchelle
Jungkook Park
Junjie Tao
Justas Trimailovas
Justin Turner Arthur
Expand Down
33 changes: 12 additions & 21 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,17 +224,17 @@ def _feed_data(self, data):
WSMessage(WSMsgType.PONG, payload, ''), len(payload))

elif opcode not in (
WSMsgType.TEXT, WSMsgType.BINARY) and not self._opcode:
WSMsgType.TEXT, WSMsgType.BINARY) and self._opcode is None:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Unexpected opcode={!r}".format(opcode))
else:
# load text/binary

if not fin:
# got partial frame payload
if opcode != WSMsgType.CONTINUATION:
self._opcode = opcode

self._partial.append(payload)
else:
# previous frame was non finished
Expand All @@ -248,39 +248,37 @@ def _feed_data(self, data):

if opcode == WSMsgType.CONTINUATION:
opcode = self._opcode
self._opcode = None

self._partial.append(payload)
payload_merged = b''.join(self._partial) + payload
self._partial.clear()

if opcode == WSMsgType.TEXT:
try:
text = b''.join(self._partial).decode('utf-8')
text = payload_merged.decode('utf-8')
self.queue.feed_data(
WSMessage(WSMsgType.TEXT, text, ''), len(text))
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT,
'Invalid UTF-8 text message') from exc
else:
data = b''.join(self._partial)
self.queue.feed_data(
WSMessage(WSMsgType.BINARY, data, ''), len(data))

self._start_opcode = None
self._partial.clear()
WSMessage(WSMsgType.BINARY, payload_merged, ''),
len(payload_merged))

return False, b''

def parse_frame(self, buf, continuation=False, EMPTY=b''):
def parse_frame(self, buf):
"""Return the next frame from the socket."""
frames = []
if self._tail:
buf, self._tail = self._tail + buf, EMPTY
buf, self._tail = self._tail + buf, b''

start_pos = 0
buf_length = len(buf)

while True:

# read header
if self._state == WSParserState.READ_HEADER:
if buf_length - start_pos >= 2:
Expand Down Expand Up @@ -312,15 +310,6 @@ def parse_frame(self, buf, continuation=False, EMPTY=b''):
WSCloseCode.PROTOCOL_ERROR,
'Received fragmented control frame')

continuation = not self._frame_fin
if (fin == 0 and
opcode == WSMsgType.CONTINUATION and
not continuation):
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'Received new fragment frame with non-zero '
'opcode {!r}'.format(opcode))

has_mask = (second_byte >> 7) & 1
length = (second_byte) & 0x7f

Expand Down Expand Up @@ -409,6 +398,8 @@ def parse_frame(self, buf, continuation=False, EMPTY=b''):
else:
break

self._tail = buf[start_pos:]

return frames


Expand Down
48 changes: 31 additions & 17 deletions tests/test_websocket_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,28 @@
_websocket_mask)


def build_frame(message, opcode, use_mask=False, noheader=False):
def build_frame(message, opcode, use_mask=False, noheader=False, is_fin=True):
"""Send a frame over the websocket with message as its payload."""
msg_length = len(message)
if use_mask: # pragma: no cover
mask_bit = 0x80
else:
mask_bit = 0

if is_fin:
header_first_byte = 0x80 | opcode
else:
header_first_byte = opcode

if msg_length < 126:
header = PACK_LEN1(
0x80 | opcode, msg_length | mask_bit)
header_first_byte, msg_length | mask_bit)
elif msg_length < (1 << 16): # pragma: no cover
header = PACK_LEN2(
0x80 | opcode, 126 | mask_bit, msg_length)
header_first_byte, 126 | mask_bit, msg_length)
else:
header = PACK_LEN3(
0x80 | opcode, 127 | mask_bit, msg_length)
header_first_byte, 127 | mask_bit, msg_length)

if use_mask: # pragma: no cover
mask = random.randrange(0, 0xffffffff)
Expand Down Expand Up @@ -117,13 +122,6 @@ def test_parse_frame_header_control_frame(out, parser):
raise out.exception()


def test_parse_frame_header_continuation(out, parser):
with pytest.raises(WebSocketError):
parser._frame_fin = True
parser.parse_frame(struct.pack('!BB', 0b00000000, 0b00000000))
raise out.exception()


def _test_parse_frame_header_new_data_err(out, parser):
with pytest.raises(WebSocketError):
parser.parse_frame(struct.pack('!BB', 0b000000000, 0b00000000))
Expand Down Expand Up @@ -234,13 +232,21 @@ def test_simple_binary(out, parser):
assert res == ((WSMsgType.BINARY, b'binary', ''), 6)


def test_fragmentation_header(out, parser):
data = build_frame(b'a', WSMsgType.TEXT)
parser._feed_data(data[:1])
parser._feed_data(data[1:])

res = out._buffer[0]
assert res == (WSMessage(WSMsgType.TEXT, 'a', ''), 1)


def test_continuation(out, parser):
parser.parse_frame = mock.Mock()
parser.parse_frame.return_value = [
(0, WSMsgType.TEXT, b'line1'),
(1, WSMsgType.CONTINUATION, b'line2')]
data1 = build_frame(b'line1', WSMsgType.TEXT, is_fin=False)
parser._feed_data(data1)

parser._feed_data(b'')
data2 = build_frame(b'line2', WSMsgType.CONTINUATION)
parser._feed_data(data2)

res = out._buffer[0]
assert res == (WSMessage(WSMsgType.TEXT, 'line1line2', ''), 10)
Expand All @@ -254,7 +260,15 @@ def test_continuation_with_ping(out, parser):
(1, WSMsgType.CONTINUATION, b'line2'),
]

parser.feed_data(b'')
data1 = build_frame(b'line1', WSMsgType.TEXT, is_fin=False)
parser._feed_data(data1)

data2 = build_frame(b'', WSMsgType.PING)
parser._feed_data(data2)

data3 = build_frame(b'line2', WSMsgType.CONTINUATION)
parser._feed_data(data3)

res = out._buffer[0]
assert res == (WSMessage(WSMsgType.PING, b'', ''), 0)
res = out._buffer[1]
Expand Down

0 comments on commit 6b85bcd

Please sign in to comment.