Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get rid of legacy class StreamWriter #2109 #2651

Merged
merged 11 commits into from
Jan 11, 2018
1 change: 1 addition & 0 deletions CHANGES/2651.removal
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Get rid of the legacy class StreamWriter.
9 changes: 6 additions & 3 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .http import WS_KEY, WebSocketReader, WebSocketWriter
from .http_websocket import WSHandshakeError, ws_ext_gen, ws_ext_parse
from .streams import FlowControlDataQueue
from .tcp_helpers import tcp_cork, tcp_nodelay
from .tracing import Trace


Expand Down Expand Up @@ -296,7 +297,8 @@ async def _request(self, method, url, *,
'Connection timeout '
'to host {0}'.format(url)) from exc

conn.writer.set_tcp_nodelay(True)
tcp_nodelay(conn.transport, True)
tcp_cork(conn.transport, False)
try:
resp = req.send(conn)
try:
Expand Down Expand Up @@ -575,12 +577,13 @@ async def _ws_connect(self, url, *,
notakeover = False

proto = resp.connection.protocol
transport = resp.connection.transport
reader = FlowControlDataQueue(
proto, limit=2 ** 16, loop=self._loop)
proto.set_parser(WebSocketReader(reader), reader)
resp.connection.writer.set_tcp_nodelay(True)
tcp_nodelay(transport, True)
writer = WebSocketWriter(
resp.connection.writer, use_mask=True,
proto, transport, use_mask=True,
compress=compress, notakeover=notakeover)
except Exception:
resp.close()
Expand Down
6 changes: 2 additions & 4 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from .client_exceptions import (ClientOSError, ClientPayloadError,
ServerDisconnectedError)
from .http import HttpResponseParser, StreamWriter
from .http import HttpResponseParser
from .streams import EMPTY_PAYLOAD, DataQueue


Expand All @@ -17,7 +17,6 @@ def __init__(self, *, loop=None):

self.paused = False
self.transport = None
self.writer = None
self._should_close = False

self._message = None
Expand Down Expand Up @@ -60,7 +59,6 @@ def is_connected(self):

def connection_made(self, transport):
self.transport = transport
self.writer = StreamWriter(self, transport, self._loop)

def connection_lost(self, exc):
if self._payload_parser is not None:
Expand All @@ -82,7 +80,7 @@ def connection_lost(self, exc):
exc = ServerDisconnectedError(uncompleted)
DataQueue.set_exception(self, exc)

self.transport = self.writer = None
self.transport = None
self._should_close = True
self._parser = None
self._message = None
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def send(self, conn):
if self.url.raw_query_string:
path += '?' + self.url.raw_query_string

writer = PayloadWriter(conn.writer, self.loop)
writer = PayloadWriter(conn.protocol, conn.transport, self.loop)

if self.compress:
writer.enable_compression(self.compress)
Expand Down
3 changes: 1 addition & 2 deletions aiohttp/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
WSCloseCode, WSMessage, WSMsgType, ws_ext_gen,
ws_ext_parse)
from .http_writer import (HttpVersion, HttpVersion10, HttpVersion11,
PayloadWriter, StreamWriter)
PayloadWriter)


__all__ = (
'HttpProcessingError', 'RESPONSES', 'SERVER_SOFTWARE',

# .http_writer
'PayloadWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11',
'StreamWriter',

# .http_parser
'HttpParser', 'HttpRequestParser', 'HttpResponseParser',
Expand Down
16 changes: 8 additions & 8 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,11 +513,11 @@ def parse_frame(self, buf):

class WebSocketWriter:

def __init__(self, stream, *,
def __init__(self, protocol, transport, *,
use_mask=False, limit=DEFAULT_LIMIT, random=random.Random(),
compress=0, notakeover=False):
self.stream = stream
self.writer = stream.transport
self.protocol = protocol
self.transport = transport
self.use_mask = use_mask
self.randrange = random.randrange
self.compress = compress
Expand Down Expand Up @@ -572,20 +572,20 @@ def _send_frame(self, message, opcode, compress=None):
mask = mask.to_bytes(4, 'big')
message = bytearray(message)
_websocket_mask(mask, message)
self.writer.write(header + mask + message)
self.transport.write(header + mask + message)
self._output_size += len(header) + len(mask) + len(message)
else:
if len(message) > MSG_SIZE:
self.writer.write(header)
self.writer.write(message)
self.transport.write(header)
self.transport.write(message)
else:
self.writer.write(header + message)
self.transport.write(header + message)

self._output_size += len(header) + len(message)

if self._output_size > self._limit:
self._output_size = 0
return self.stream.drain()
return self.protocol._drain_helper()

return noop()

Expand Down
104 changes: 18 additions & 86 deletions aiohttp/http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,104 +2,24 @@

import asyncio
import collections
import socket
import zlib
from contextlib import suppress

from .abc import AbstractPayloadWriter
from .helpers import noop


__all__ = ('PayloadWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11',
'StreamWriter')
__all__ = ('PayloadWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11')

HttpVersion = collections.namedtuple('HttpVersion', ['major', 'minor'])
HttpVersion10 = HttpVersion(1, 0)
HttpVersion11 = HttpVersion(1, 1)


if hasattr(socket, 'TCP_CORK'): # pragma: no cover
CORK = socket.TCP_CORK
elif hasattr(socket, 'TCP_NOPUSH'): # pragma: no cover
CORK = socket.TCP_NOPUSH
else: # pragma: no cover
CORK = None


class StreamWriter:
class PayloadWriter(AbstractPayloadWriter):

def __init__(self, protocol, transport, loop):
self._protocol = protocol
self._loop = loop
self._tcp_nodelay = False
self._tcp_cork = False
self._socket = transport.get_extra_info('socket')
self._waiters = []
self.transport = transport

@property
def tcp_nodelay(self):
return self._tcp_nodelay

def set_tcp_nodelay(self, value):
value = bool(value)
if self._tcp_nodelay == value:
return
if self._socket is None:
return
if self._socket.family not in (socket.AF_INET, socket.AF_INET6):
return

# socket may be closed already, on windows OSError get raised
with suppress(OSError):
if self._tcp_cork:
if CORK is not None: # pragma: no branch
self._socket.setsockopt(socket.IPPROTO_TCP, CORK, False)
self._tcp_cork = False

self._socket.setsockopt(
socket.IPPROTO_TCP, socket.TCP_NODELAY, value)
self._tcp_nodelay = value

@property
def tcp_cork(self):
return self._tcp_cork

def set_tcp_cork(self, value):
value = bool(value)
if self._tcp_cork == value:
return
if self._socket is None:
return
if self._socket.family not in (socket.AF_INET, socket.AF_INET6):
return

with suppress(OSError):
if self._tcp_nodelay:
self._socket.setsockopt(
socket.IPPROTO_TCP, socket.TCP_NODELAY, False)
self._tcp_nodelay = False
if CORK is not None: # pragma: no branch
self._socket.setsockopt(socket.IPPROTO_TCP, CORK, value)
self._tcp_cork = value

async def drain(self):
"""Flush the write buffer.

The intended use is to write

await w.write(data)
await w.drain()
"""
if self._protocol.transport is not None:
await self._protocol._drain_helper()


class PayloadWriter(AbstractPayloadWriter):

def __init__(self, stream, loop):
self._stream = stream
self._transport = None
self._transport = transport

self.loop = loop
self.length = None
Expand All @@ -110,11 +30,15 @@ def __init__(self, stream, loop):
self._eof = False
self._compress = None
self._drain_waiter = None
self._transport = self._stream.transport

async def get_transport(self):
@property
def transport(self):
return self._transport

@property
def protocol(self):
return self._protocol
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a test for the property

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


def enable_chunking(self):
self.chunked = True

Expand Down Expand Up @@ -204,4 +128,12 @@ async def write_eof(self, chunk=b''):
self._transport = None

async def drain(self):
await self._stream.drain()
"""Flush the write buffer.

The intended use is to write

await w.write(data)
await w.drain()
"""
if self._protocol.transport is not None:
await self._protocol._drain_helper()
67 changes: 67 additions & 0 deletions aiohttp/tcp_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Helper methods to tune a TCP connection"""

import socket
from contextlib import suppress


__all__ = ('tcp_keepalive', 'tcp_nodelay', 'tcp_cork')


if hasattr(socket, 'TCP_CORK'): # pragma: no cover
CORK = socket.TCP_CORK
elif hasattr(socket, 'TCP_NOPUSH'): # pragma: no cover
CORK = socket.TCP_NOPUSH
else: # pragma: no cover
CORK = None


if hasattr(socket, 'SO_KEEPALIVE'):
def tcp_keepalive(transport):
sock = transport.get_extra_info('socket')
if sock is not None:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
else:
def tcp_keepalive(transport): # pragma: no cover
pass


def tcp_nodelay(transport, value):
sock = transport.get_extra_info('socket')

if sock is None:
return None

if sock.family not in (socket.AF_INET, socket.AF_INET6):
return None

value = bool(value)

# socket may be closed already, on windows OSError get raised
with suppress(OSError):
sock.setsockopt(
socket.IPPROTO_TCP, socket.TCP_NODELAY, value)
return value

return None


def tcp_cork(transport, value):
sock = transport.get_extra_info('socket')

if CORK is None:
return None

if sock is None:
return None

if sock.family not in (socket.AF_INET, socket.AF_INET6):
return None

value = bool(value)

with suppress(OSError):
sock.setsockopt(
socket.IPPROTO_TCP, CORK, value)
return value

return None
9 changes: 4 additions & 5 deletions aiohttp/web_fileresponse.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ def _sendfile_cb(self, fut, out_fd, in_fd,
set_result(fut, None)

async def sendfile(self, fobj, count):
transport = await self.get_transport()

out_socket = transport.get_extra_info('socket').dup()
out_socket = self.transport.get_extra_info('socket').dup()
out_socket.setblocking(False)
out_fd = out_socket.fileno()
in_fd = fobj.fileno()
Expand All @@ -71,7 +69,7 @@ async def sendfile(self, fobj, count):
await fut
except Exception:
server_logger.debug('Socket error')
transport.close()
self.transport.close()
finally:
out_socket.close()

Expand Down Expand Up @@ -112,7 +110,8 @@ async def _sendfile_system(self, request, fobj, count):
writer = await self._sendfile_fallback(request, fobj, count)
else:
writer = SendfilePayloadWriter(
request._protocol.writer,
request.protocol,
transport,
request.loop
)
request._payload_writer = writer
Expand Down
Loading