Skip to content

Commit

Permalink
refactor client protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
Nikolay Kim committed Feb 13, 2017
1 parent 9189924 commit ee234db
Show file tree
Hide file tree
Showing 13 changed files with 309 additions and 121 deletions.
11 changes: 11 additions & 0 deletions aiohttp/_ws_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ class WebSocketReader:
def __init__(self, queue):
self.queue = queue

self._exc = None
self._partial = []
self._state = WSParserState.READ_HEADER

Expand All @@ -333,7 +334,17 @@ def feed_eof(self):
self.queue.feed_eof()

def feed_data(self, data):
    """Feed raw websocket bytes into the parser.

    Returns an ``(eof, tail)`` pair: ``eof`` is True once the parser can
    accept no more input, ``tail`` is any bytes it did not consume.
    """
    # A previous call already failed: refuse further input and hand the
    # bytes back to the caller untouched.
    if self._exc:
        return True, data

    try:
        return self._feed_data(data)
    except Exception as err:
        # Remember the failure, surface it to readers of the queue, and
        # report EOF with no leftover data.
        self._exc = err
        self.queue.set_exception(err)
        return True, b''

def _feed_data(self, data):
for fin, opcode, payload in self.parse_frame(data):

if opcode == WSMsgType.CLOSE:
Expand Down
10 changes: 7 additions & 3 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
import aiohttp

from . import hdrs, helpers
from ._ws_impl import WS_KEY, WebSocketParser, WebSocketWriter
from ._ws_impl import WS_KEY, WebSocketReader, WebSocketWriter
from .client_reqrep import ClientRequest, ClientResponse
from .client_ws import ClientWebSocketResponse
from .cookiejar import CookieJar
from .errors import WSServerHandshakeError
from .helpers import TimeService
from .streams import FlowControlDataQueue

__all__ = ('ClientSession', 'request')

Expand Down Expand Up @@ -218,7 +219,7 @@ def _request(self, method, url, *,
conn = yield from self._connector.connect(req)
conn.writer.set_tcp_nodelay(True)
try:
resp = req.send(conn.writer, conn.reader)
resp = req.send(conn)
try:
yield from resp.start(conn, read_until_eof)
except:
Expand Down Expand Up @@ -391,7 +392,10 @@ def _ws_connect(self, url, *,
protocol = proto
break

reader = resp.connection.reader.set_parser(WebSocketParser)
proto = resp.connection.protocol
reader = FlowControlDataQueue(
proto, limit=2 ** 16, loop=self._loop)
proto.set_parser(WebSocketReader(reader), reader)
resp.connection.writer.set_tcp_nodelay(True)
writer = WebSocketWriter(resp.connection.writer, use_mask=True)
except Exception:
Expand Down
188 changes: 188 additions & 0 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import asyncio
import asyncio.streams
import socket

from . import errors, hdrs, streams
from .errors import ServerDisconnectedError
from .streams import DataQueue, FlowControlStreamReader, EmptyStreamReader
from .parsers import StreamParser, StreamWriter
from .protocol import HttpResponseParser, HttpPayloadParser

# Shared sentinel payload: a single always-at-EOF stream reader reused for
# every response that carries no body (see HttpClientProtocol.data_received).
EMPTY_PAYLOAD = EmptyStreamReader()


class HttpClientProtocol(DataQueue, asyncio.streams.FlowControlMixin):
    """Helper class to adapt between Protocol and StreamReader.

    Client-side connection protocol: splits incoming bytes into the HTTP
    response head (status line + headers, CRLF-separated lines terminated
    by a blank line), parses it with HttpResponseParser, and feeds
    ``(message, payload)`` pairs into the DataQueue for the caller
    (``ClientResponse.start``) to consume.  Payload bytes are routed to a
    per-response payload parser installed here or via ``set_parser()``.
    """

    def __init__(self, *, loop=None, **kwargs):
        asyncio.streams.FlowControlMixin.__init__(self, loop=loop)
        DataQueue.__init__(self, loop=loop)

        self.paused = False
        self.transport = None
        self.writer = None
        self._should_close = False

        # Current in-flight payload stream and its parser, if any.
        self._payload = None
        self._payload_parser = None

        # Response parameters consumed by data_received().  They must be
        # initialized here: the previous code set an unused ``_skip_status``
        # attribute instead, so data_received() raised AttributeError when
        # data arrived before set_response_params() was called.
        self._timer = None
        self._skip_payload = False
        self._skip_status_codes = ()
        self._read_until_eof = False

        self._lines = []      # header lines of the head being assembled
        self._tail = b''      # received bytes not yet parsed
        self._upgrade = False
        self._response_parser = HttpResponseParser()

    @property
    def should_close(self):
        """True when this connection must not be reused for keep-alive."""
        return (self._should_close or self._upgrade or
                self.exception() is not None or
                self._payload is not None or
                self._payload_parser is not None or
                self._lines or self._tail)

    def is_connected(self):
        """Return True while a transport is attached."""
        return self.transport is not None

    def connection_made(self, transport):
        self.transport = transport
        self.writer = StreamWriter(transport, self, None, self._loop)

    def connection_lost(self, exc):
        self.transport = self.writer = None

        # Even a "clean" close is an error for a reader that is still
        # waiting on a response, so substitute ServerDisconnectedError.
        if exc is None:
            exc = ServerDisconnectedError()

        if self._payload is not None:
            self._payload.set_exception(exc)
        DataQueue.set_exception(self, exc)

        # NOTE(review): the substituted exception (never None) is forwarded
        # to FlowControlMixin.connection_lost, so a paused drain() is woken
        # with this error even on a clean close — confirm this is intended.
        super().connection_lost(exc)

    def eof_received(self):
        # Returning None lets the transport close; the resulting error is
        # surfaced through connection_lost().
        pass

    def set_exception(self, exc):
        # Any reader-side error poisons the connection for reuse.
        self._should_close = True

        super().set_exception(exc)

    def set_parser(self, parser, payload):
        """Install a payload parser (e.g. WebSocketReader) and replay any
        buffered tail bytes through it."""
        self._payload = payload
        self._payload_parser = parser

        if self._tail:
            data, self._tail = self._tail, None
            self.data_received(data)

    def set_response_params(self, *, timer=None,
                            skip_payload=False,
                            skip_status_codes=(),
                            read_until_eof=False):
        """Configure how the next response's payload is handled."""
        self._timer = timer
        self._skip_payload = skip_payload
        self._skip_status_codes = skip_status_codes
        self._read_until_eof = read_until_eof

    def data_received(self, data,
                      SEP=b'\r\n',
                      CONTENT_LENGTH=hdrs.CONTENT_LENGTH,
                      SEC_WEBSOCKET_KEY1=hdrs.SEC_WEBSOCKET_KEY1):
        # Default-argument locals (SEP, header names) avoid repeated global
        # lookups on this hot path.

        # feed payload
        if self._payload_parser is not None:
            assert not self._lines
            if data:
                eof, tail = self._payload_parser.feed_data(data)
                if eof:
                    self._payload = None
                    self._payload_parser = None

                    if tail:
                        # Bytes past the finished payload belong to the next
                        # response head; re-enter in head-parsing mode.  The
                        # previous ``super().data_received(tail)`` resolved to
                        # asyncio.Protocol's no-op and silently dropped them.
                        self.data_received(tail)

            return

        # read HTTP message (status line + headers), \r\n\r\n
        # and split by lines
        if self._tail:
            data = self._tail + data

        start_pos = 0
        while True:
            pos = data.find(SEP, start_pos)
            if pos >= start_pos:
                # line found
                self._lines.append(data[start_pos:pos])

                # \r\n\r\n found
                start_pos = pos + 2
                if data[start_pos:start_pos+2] == SEP:
                    self._lines.append(b'')

                    msg = None
                    try:
                        msg = self._response_parser.parse_message(self._lines)

                        # payload length
                        length = msg.headers.get(CONTENT_LENGTH)
                        if length is not None:
                            try:
                                length = int(length)
                            except ValueError:
                                raise errors.InvalidHeader(CONTENT_LENGTH)
                            if length < 0:
                                raise errors.InvalidHeader(CONTENT_LENGTH)

                        # do not support old websocket spec
                        if SEC_WEBSOCKET_KEY1 in msg.headers:
                            raise errors.InvalidHeader(SEC_WEBSOCKET_KEY1)
                    except BaseException:
                        # A malformed head poisons the connection; re-raise
                        # so the transport machinery tears it down.
                        self._should_close = True
                        raise
                    else:
                        self._lines.clear()

                    self._should_close = msg.should_close

                    # calculate payload
                    empty_payload = True
                    if (((length is not None and length > 0) or msg.chunked) and
                            (not self._skip_payload and
                             msg.code not in self._skip_status_codes)):

                        if not msg.upgrade:
                            payload = FlowControlStreamReader(
                                self, timer=self._timer, loop=self._loop)
                            payload_parser = HttpPayloadParser(
                                msg, readall=self._read_until_eof)

                            if payload_parser.start(length, payload):
                                empty_payload = False
                                self._payload = payload
                                self._payload_parser = payload_parser
                        else:
                            payload = EMPTY_PAYLOAD
                    else:
                        payload = EMPTY_PAYLOAD

                    self._upgrade = msg.upgrade

                    self.feed_data((msg, payload), 0)

                    # skip the blank line that terminated the head
                    start_pos = start_pos + 2
                    if start_pos < len(data):
                        if self._upgrade:
                            # keep remaining bytes for the upgraded
                            # protocol's parser (installed via set_parser)
                            self._tail = data[start_pos:]
                            return

                        if empty_payload:
                            # a further response head may follow immediately
                            continue

                        # route the rest of this packet to the payload parser
                        self._tail = None
                        self.data_received(data[start_pos:])
                    return
            else:
                # incomplete line; buffer it and wait for more data
                self._tail = data[start_pos:]
                return
45 changes: 17 additions & 28 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def update_proxy(self, proxy, proxy_auth):
self.proxy_auth = proxy_auth

@asyncio.coroutine
def write_bytes(self, request, reader):
def write_bytes(self, request, conn):
"""Support coroutines that yields bytes objects."""
# 100 response
if self._continue is not None:
Expand Down Expand Up @@ -417,7 +417,7 @@ def write_bytes(self, request, reader):
'Can not write request body for %s' % self.url)
new_exc.__context__ = exc
new_exc.__cause__ = exc
reader.set_exception(new_exc)
conn.protocol.set_exception(new_exc)
else:
try:
ret = yield from request.write_eof()
Expand All @@ -431,11 +431,11 @@ def write_bytes(self, request, reader):
'Can not write request body for %s' % self.url)
new_exc.__context__ = exc
new_exc.__cause__ = exc
reader.set_exception(new_exc)
conn.protocol.set_exception(new_exc)

self._writer = None

def send(self, writer, reader):
def send(self, conn):
# Specify request target:
# - CONNECT request must send authority form URI
# - not CONNECT proxy must send absolute form URI
Expand All @@ -450,7 +450,7 @@ def send(self, writer, reader):
path += '?' + self.url.raw_query_string

request = aiohttp.Request(
writer, self.method, path, self.version, loop=self.loop)
conn.writer, self.method, path, self.version, loop=self.loop)

if self.compress:
request.enable_compression(self.compress)
Expand All @@ -469,7 +469,7 @@ def send(self, writer, reader):
request.send_headers()

self._writer = helpers.ensure_future(
self.write_bytes(request, reader), loop=self.loop)
self.write_bytes(request, conn), loop=self.loop)

self.response = self.response_class(
self.method, self.original_url,
Expand Down Expand Up @@ -591,27 +591,21 @@ def history(self):
"""A sequence of responses, if redirects occurred."""
return self._history

def _setup_connection(self, connection):
self._reader = connection.reader
self._connection = connection
self.content = self.flow_control_class(
connection.reader, loop=connection.loop, timer=self._timer)

def _need_parse_response_body(self):
return (self.method.lower() != 'head' and
self.status not in [204, 304])

@asyncio.coroutine
def start(self, connection, read_until_eof=False):
"""Start response processing."""
self._setup_connection(connection)
self._protocol = connection.protocol
self._connection = connection
connection.protocol.set_response_params(
timer=self._timer,
skip_payload=self.method.lower() == 'head',
skip_status_codes=(204, 304),
read_until_eof=read_until_eof)

with self._timer:
while True:
httpstream = self._reader.set_parser(self._response_parser)

# read response
message = yield from httpstream.read()
(message, payload) = yield from self._protocol.read()
if (message.code < 100 or
message.code > 199 or message.code == 101):
break
Expand All @@ -631,12 +625,7 @@ def start(self, connection, read_until_eof=False):
self.raw_headers = tuple(message.raw_headers)

# payload
rwb = self._need_parse_response_body()
self._reader.set_parser(
aiohttp.HttpPayloadParser(message,
readall=read_until_eof,
response_with_body=rwb),
self.content)
self.content = payload

# cookies
for hdr in self.headers.getall(hdrs.SET_COOKIE, ()):
Expand Down Expand Up @@ -690,8 +679,8 @@ def release(self, *, consume=False):
self._closed = True
if self._connection is not None:
self._connection.release()
if self._reader is not None:
self._reader.unset_parser()
#if self._reader is not None:
#self._reader.unset_parser()
self._connection = None
self._cleanup_writer()
self._notify_content()
Expand Down
Loading

0 comments on commit ee234db

Please sign in to comment.