From f570fed1086b0181d51cdc4cea6df842519ef196 Mon Sep 17 00:00:00 2001
From: Pau Freixes <pfreixes@gmail.com>
Date: Thu, 11 Jan 2018 10:37:20 +0100
Subject: [PATCH] Get rid of legacy class StreamWriter #2109 (#2651)

* Get rid of legacy StreamWriter (#2623)

Legacy StreamWriter as a pure proxy of the transport and the protocol is
no longer needed. All of the functionalities that were behind this class
has been moved to the PayloadWriter.

Some changes that have to be considered that impacted during this change
* TCP Operations have been isolated in a module rather than move them
into the PayloadWriter
* WebSocketWriter had a dependency with the StreamWriter, to get rid of
that dependency the constructor has been modified to take the protocol
and the transport.

A next step changing the name PayLoadWriter for the StreamWriter to have
consistency with the reader part, might be considered.

* Add CHANGES

* Fixed invalid import order

* Fix test broken

* Fix tcp_cork issues

* Test PayloadWriter properties

* Avoid return useless values for tcp_<operations>

* Increase coverage http_writer

* Increase coverage web_protocol
---
 CHANGES/2651.removal               |   1 +
 aiohttp/client.py                  |   9 +-
 aiohttp/client_proto.py            |   6 +-
 aiohttp/client_reqrep.py           |   2 +-
 aiohttp/http.py                    |   3 +-
 aiohttp/http_websocket.py          |  16 +-
 aiohttp/http_writer.py             | 104 ++----------
 aiohttp/tcp_helpers.py             |  61 +++++++
 aiohttp/web_fileresponse.py        |   9 +-
 aiohttp/web_protocol.py            |  29 +---
 aiohttp/web_ws.py                  |   3 +-
 tests/test_client_request.py       |  30 ++--
 tests/test_client_ws_functional.py |   2 +-
 tests/test_http_stream_writer.py   | 257 -----------------------------
 tests/test_http_writer.py          |  89 ++++++----
 tests/test_tcp_helpers.py          | 145 ++++++++++++++++
 tests/test_web_protocol.py         |  54 +++++-
 tests/test_web_sendfile.py         |   8 +-
 tests/test_websocket_writer.py     |  81 +++++----
 19 files changed, 426 insertions(+), 483 deletions(-)
 create mode 100644 CHANGES/2651.removal
 create mode 100644 aiohttp/tcp_helpers.py
 delete mode 100644 tests/test_http_stream_writer.py
 create mode 100644 tests/test_tcp_helpers.py

diff --git a/CHANGES/2651.removal b/CHANGES/2651.removal
new file mode 100644
index 00000000000..0b5f76fd8b6
--- /dev/null
+++ b/CHANGES/2651.removal
@@ -0,0 +1 @@
+Get rid of the legacy class StreamWriter.
diff --git a/aiohttp/client.py b/aiohttp/client.py
index 369eb195454..f259935c536 100644
--- a/aiohttp/client.py
+++ b/aiohttp/client.py
@@ -29,6 +29,7 @@
 from .http import WS_KEY, WebSocketReader, WebSocketWriter
 from .http_websocket import WSHandshakeError, ws_ext_gen, ws_ext_parse
 from .streams import FlowControlDataQueue
+from .tcp_helpers import tcp_cork, tcp_nodelay
 from .tracing import Trace
 
 
@@ -296,7 +297,8 @@ async def _request(self, method, url, *,
                             'Connection timeout '
                             'to host {0}'.format(url)) from exc
 
-                    conn.writer.set_tcp_nodelay(True)
+                    tcp_nodelay(conn.transport, True)
+                    tcp_cork(conn.transport, False)
                     try:
                         resp = req.send(conn)
                         try:
@@ -575,12 +577,13 @@ async def _ws_connect(self, url, *,
                     notakeover = False
 
             proto = resp.connection.protocol
+            transport = resp.connection.transport
             reader = FlowControlDataQueue(
                 proto, limit=2 ** 16, loop=self._loop)
             proto.set_parser(WebSocketReader(reader), reader)
-            resp.connection.writer.set_tcp_nodelay(True)
+            tcp_nodelay(transport, True)
             writer = WebSocketWriter(
-                resp.connection.writer, use_mask=True,
+                proto, transport, use_mask=True,
                 compress=compress, notakeover=notakeover)
         except Exception:
             resp.close()
diff --git a/aiohttp/client_proto.py b/aiohttp/client_proto.py
index 5c51224fd9b..cf0d2f83306 100644
--- a/aiohttp/client_proto.py
+++ b/aiohttp/client_proto.py
@@ -4,7 +4,7 @@
 
 from .client_exceptions import (ClientOSError, ClientPayloadError,
                                 ServerDisconnectedError)
-from .http import HttpResponseParser, StreamWriter
+from .http import HttpResponseParser
 from .streams import EMPTY_PAYLOAD, DataQueue
 
 
@@ -17,7 +17,6 @@ def __init__(self, *, loop=None):
 
         self.paused = False
         self.transport = None
-        self.writer = None
         self._should_close = False
 
         self._message = None
@@ -60,7 +59,6 @@ def is_connected(self):
 
     def connection_made(self, transport):
         self.transport = transport
-        self.writer = StreamWriter(self, transport, self._loop)
 
     def connection_lost(self, exc):
         if self._payload_parser is not None:
@@ -82,7 +80,7 @@ def connection_lost(self, exc):
                 exc = ServerDisconnectedError(uncompleted)
             DataQueue.set_exception(self, exc)
 
-        self.transport = self.writer = None
+        self.transport = None
         self._should_close = True
         self._parser = None
         self._message = None
diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py
index de32f6622e1..2384ca3c0ca 100644
--- a/aiohttp/client_reqrep.py
+++ b/aiohttp/client_reqrep.py
@@ -469,7 +469,7 @@ def send(self, conn):
             if self.url.raw_query_string:
                 path += '?' + self.url.raw_query_string
 
-        writer = PayloadWriter(conn.writer, self.loop)
+        writer = PayloadWriter(conn.protocol, conn.transport, self.loop)
 
         if self.compress:
             writer.enable_compression(self.compress)
diff --git a/aiohttp/http.py b/aiohttp/http.py
index 4dee43b631c..c372426754d 100644
--- a/aiohttp/http.py
+++ b/aiohttp/http.py
@@ -12,7 +12,7 @@
                              WSCloseCode, WSMessage, WSMsgType, ws_ext_gen,
                              ws_ext_parse)
 from .http_writer import (HttpVersion, HttpVersion10, HttpVersion11,
-                          PayloadWriter, StreamWriter)
+                          PayloadWriter)
 
 
 __all__ = (
@@ -20,7 +20,6 @@
 
     # .http_writer
     'PayloadWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11',
-    'StreamWriter',
 
     # .http_parser
     'HttpParser', 'HttpRequestParser', 'HttpResponseParser',
diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py
index 5bb51d69f9d..a5ca686f64e 100644
--- a/aiohttp/http_websocket.py
+++ b/aiohttp/http_websocket.py
@@ -513,11 +513,11 @@ def parse_frame(self, buf):
 
 class WebSocketWriter:
 
-    def __init__(self, stream, *,
+    def __init__(self, protocol, transport, *,
                  use_mask=False, limit=DEFAULT_LIMIT, random=random.Random(),
                  compress=0, notakeover=False):
-        self.stream = stream
-        self.writer = stream.transport
+        self.protocol = protocol
+        self.transport = transport
         self.use_mask = use_mask
         self.randrange = random.randrange
         self.compress = compress
@@ -572,20 +572,20 @@ def _send_frame(self, message, opcode, compress=None):
             mask = mask.to_bytes(4, 'big')
             message = bytearray(message)
             _websocket_mask(mask, message)
-            self.writer.write(header + mask + message)
+            self.transport.write(header + mask + message)
             self._output_size += len(header) + len(mask) + len(message)
         else:
             if len(message) > MSG_SIZE:
-                self.writer.write(header)
-                self.writer.write(message)
+                self.transport.write(header)
+                self.transport.write(message)
             else:
-                self.writer.write(header + message)
+                self.transport.write(header + message)
 
             self._output_size += len(header) + len(message)
 
         if self._output_size > self._limit:
             self._output_size = 0
-            return self.stream.drain()
+            return self.protocol._drain_helper()
 
         return noop()
 
diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py
index b253d7ed946..4b83a3ecd80 100644
--- a/aiohttp/http_writer.py
+++ b/aiohttp/http_writer.py
@@ -2,104 +2,24 @@
 
 import asyncio
 import collections
-import socket
 import zlib
-from contextlib import suppress
 
 from .abc import AbstractPayloadWriter
 from .helpers import noop
 
 
-__all__ = ('PayloadWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11',
-           'StreamWriter')
+__all__ = ('PayloadWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11')
 
 HttpVersion = collections.namedtuple('HttpVersion', ['major', 'minor'])
 HttpVersion10 = HttpVersion(1, 0)
 HttpVersion11 = HttpVersion(1, 1)
 
 
-if hasattr(socket, 'TCP_CORK'):  # pragma: no cover
-    CORK = socket.TCP_CORK
-elif hasattr(socket, 'TCP_NOPUSH'):  # pragma: no cover
-    CORK = socket.TCP_NOPUSH
-else:  # pragma: no cover
-    CORK = None
-
-
-class StreamWriter:
+class PayloadWriter(AbstractPayloadWriter):
 
     def __init__(self, protocol, transport, loop):
         self._protocol = protocol
-        self._loop = loop
-        self._tcp_nodelay = False
-        self._tcp_cork = False
-        self._socket = transport.get_extra_info('socket')
-        self._waiters = []
-        self.transport = transport
-
-    @property
-    def tcp_nodelay(self):
-        return self._tcp_nodelay
-
-    def set_tcp_nodelay(self, value):
-        value = bool(value)
-        if self._tcp_nodelay == value:
-            return
-        if self._socket is None:
-            return
-        if self._socket.family not in (socket.AF_INET, socket.AF_INET6):
-            return
-
-        # socket may be closed already, on windows OSError get raised
-        with suppress(OSError):
-            if self._tcp_cork:
-                if CORK is not None:  # pragma: no branch
-                    self._socket.setsockopt(socket.IPPROTO_TCP, CORK, False)
-                    self._tcp_cork = False
-
-            self._socket.setsockopt(
-                socket.IPPROTO_TCP, socket.TCP_NODELAY, value)
-            self._tcp_nodelay = value
-
-    @property
-    def tcp_cork(self):
-        return self._tcp_cork
-
-    def set_tcp_cork(self, value):
-        value = bool(value)
-        if self._tcp_cork == value:
-            return
-        if self._socket is None:
-            return
-        if self._socket.family not in (socket.AF_INET, socket.AF_INET6):
-            return
-
-        with suppress(OSError):
-            if self._tcp_nodelay:
-                self._socket.setsockopt(
-                    socket.IPPROTO_TCP, socket.TCP_NODELAY, False)
-                self._tcp_nodelay = False
-            if CORK is not None:  # pragma: no branch
-                self._socket.setsockopt(socket.IPPROTO_TCP, CORK, value)
-                self._tcp_cork = value
-
-    async def drain(self):
-        """Flush the write buffer.
-
-        The intended use is to write
-
-          await w.write(data)
-          await w.drain()
-        """
-        if self._protocol.transport is not None:
-            await self._protocol._drain_helper()
-
-
-class PayloadWriter(AbstractPayloadWriter):
-
-    def __init__(self, stream, loop):
-        self._stream = stream
-        self._transport = None
+        self._transport = transport
 
         self.loop = loop
         self.length = None
@@ -110,11 +30,15 @@ def __init__(self, stream, loop):
         self._eof = False
         self._compress = None
         self._drain_waiter = None
-        self._transport = self._stream.transport
 
-    async def get_transport(self):
+    @property
+    def transport(self):
         return self._transport
 
+    @property
+    def protocol(self):
+        return self._protocol
+
     def enable_chunking(self):
         self.chunked = True
 
@@ -204,4 +128,12 @@ async def write_eof(self, chunk=b''):
         self._transport = None
 
     async def drain(self):
-        await self._stream.drain()
+        """Flush the write buffer.
+
+        The intended use is to write
+
+          await w.write(data)
+          await w.drain()
+        """
+        if self._protocol.transport is not None:
+            await self._protocol._drain_helper()
diff --git a/aiohttp/tcp_helpers.py b/aiohttp/tcp_helpers.py
new file mode 100644
index 00000000000..3a016901c9d
--- /dev/null
+++ b/aiohttp/tcp_helpers.py
@@ -0,0 +1,61 @@
+"""Helper methods to tune a TCP connection"""
+
+import socket
+from contextlib import suppress
+
+
+__all__ = ('tcp_keepalive', 'tcp_nodelay', 'tcp_cork')
+
+
+if hasattr(socket, 'TCP_CORK'):  # pragma: no cover
+    CORK = socket.TCP_CORK
+elif hasattr(socket, 'TCP_NOPUSH'):  # pragma: no cover
+    CORK = socket.TCP_NOPUSH
+else:  # pragma: no cover
+    CORK = None
+
+
+if hasattr(socket, 'SO_KEEPALIVE'):
+    def tcp_keepalive(transport):
+        sock = transport.get_extra_info('socket')
+        if sock is not None:
+            sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+else:
+    def tcp_keepalive(transport):  # pragma: no cover
+        pass
+
+
+def tcp_nodelay(transport, value):
+    sock = transport.get_extra_info('socket')
+
+    if sock is None:
+        return
+
+    if sock.family not in (socket.AF_INET, socket.AF_INET6):
+        return
+
+    value = bool(value)
+
+    # socket may be closed already, on windows OSError get raised
+    with suppress(OSError):
+        sock.setsockopt(
+            socket.IPPROTO_TCP, socket.TCP_NODELAY, value)
+
+
+def tcp_cork(transport, value):
+    sock = transport.get_extra_info('socket')
+
+    if CORK is None:
+        return
+
+    if sock is None:
+        return
+
+    if sock.family not in (socket.AF_INET, socket.AF_INET6):
+        return
+
+    value = bool(value)
+
+    with suppress(OSError):
+        sock.setsockopt(
+            socket.IPPROTO_TCP, CORK, value)
diff --git a/aiohttp/web_fileresponse.py b/aiohttp/web_fileresponse.py
index 9eebae3fbb7..57f4b21e3ca 100644
--- a/aiohttp/web_fileresponse.py
+++ b/aiohttp/web_fileresponse.py
@@ -54,9 +54,7 @@ def _sendfile_cb(self, fut, out_fd, in_fd,
             set_result(fut, None)
 
     async def sendfile(self, fobj, count):
-        transport = await self.get_transport()
-
-        out_socket = transport.get_extra_info('socket').dup()
+        out_socket = self.transport.get_extra_info('socket').dup()
         out_socket.setblocking(False)
         out_fd = out_socket.fileno()
         in_fd = fobj.fileno()
@@ -71,7 +69,7 @@ async def sendfile(self, fobj, count):
             await fut
         except Exception:
             server_logger.debug('Socket error')
-            transport.close()
+            self.transport.close()
         finally:
             out_socket.close()
 
@@ -112,7 +110,8 @@ async def _sendfile_system(self, request, fobj, count):
             writer = await self._sendfile_fallback(request, fobj, count)
         else:
             writer = SendfilePayloadWriter(
-                request._protocol.writer,
+                request.protocol,
+                transport,
                 request.loop
             )
             request._payload_writer = writer
diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py
index 76c944150cb..4b636cd0e34 100644
--- a/aiohttp/web_protocol.py
+++ b/aiohttp/web_protocol.py
@@ -1,7 +1,6 @@
 import asyncio
 import asyncio.streams
 import http.server
-import socket
 import traceback
 import warnings
 from collections import deque
@@ -10,10 +9,10 @@
 
 from . import helpers, http
 from .helpers import CeilTimeout
-from .http import (HttpProcessingError, HttpRequestParser, PayloadWriter,
-                   StreamWriter)
+from .http import HttpProcessingError, HttpRequestParser, PayloadWriter
 from .log import access_logger, server_logger
 from .streams import EMPTY_PAYLOAD
+from .tcp_helpers import tcp_cork, tcp_keepalive, tcp_nodelay
 from .web_exceptions import HTTPException
 from .web_request import BaseRequest
 from .web_response import Response
@@ -25,15 +24,6 @@
     'UNKNOWN', '/', http.HttpVersion10, {},
     {}, True, False, False, False, http.URL('/'))
 
-if hasattr(socket, 'SO_KEEPALIVE'):
-    def tcp_keepalive(server, transport):
-        sock = transport.get_extra_info('socket')
-        if sock is not None:
-            sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
-else:
-    def tcp_keepalive(server, transport):  # pragma: no cover
-        pass
-
 
 class RequestPayloadError(Exception):
     """Payload parsing error."""
@@ -181,13 +171,12 @@ def connection_made(self, transport):
         super().connection_made(transport)
 
         self.transport = transport
-        self.writer = StreamWriter(self, transport, self._loop)
 
         if self._tcp_keepalive:
-            tcp_keepalive(self, transport)
+            tcp_keepalive(transport)
 
-        self.writer.set_tcp_cork(False)
-        self.writer.set_tcp_nodelay(True)
+        tcp_cork(transport, False)
+        tcp_nodelay(transport, True)
         self._manager.connection_made(self, transport)
 
     def connection_lost(self, exc):
@@ -200,7 +189,7 @@ def connection_lost(self, exc):
         self._request_factory = None
         self._request_handler = None
         self._request_parser = None
-        self.transport = self.writer = None
+        self.transport = None
 
         if self._keepalive_handle is not None:
             self._keepalive_handle.cancel()
@@ -241,14 +230,14 @@ def data_received(self, data):
                 # something happened during parsing
                 self._error_handler = self._loop.create_task(
                     self.handle_parse_error(
-                        PayloadWriter(self.writer, self._loop),
+                        PayloadWriter(self, self.transport, self._loop),
                         400, exc, exc.message))
                 self.close()
             except Exception as exc:
                 # 500: internal error
                 self._error_handler = self._loop.create_task(
                     self.handle_parse_error(
-                        PayloadWriter(self.writer, self._loop),
+                        PayloadWriter(self, self.transport, self._loop),
                         500, exc))
                 self.close()
             else:
@@ -371,7 +360,7 @@ async def start(self):
                 now = loop.time()
 
             manager.requests_count += 1
-            writer = PayloadWriter(self.writer, loop)
+            writer = PayloadWriter(self, self.transport, loop)
             request = self._request_factory(
                 message, payload, self, writer, handler)
             try:
diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py
index 63d74bfdda5..ceaf09b4c56 100644
--- a/aiohttp/web_ws.py
+++ b/aiohttp/web_ws.py
@@ -187,7 +187,8 @@ def _pre_start(self, request):
         self.headers.update(headers)
         self.force_close()
         self._compress = compress
-        writer = WebSocketWriter(request._protocol.writer,
+        writer = WebSocketWriter(request._protocol,
+                                 request._protocol.transport,
                                  compress=compress,
                                  notakeover=notakeover)
 
diff --git a/tests/test_client_request.py b/tests/test_client_request.py
index 5345e73ecfd..12f006bf00c 100644
--- a/tests/test_client_request.py
+++ b/tests/test_client_request.py
@@ -38,6 +38,14 @@ def buf():
     return bytearray()
 
 
+@pytest.fixture
+def protocol(loop):
+    protocol = mock.Mock()
+    protocol._drain_helper.return_value = loop.create_future()
+    protocol._drain_helper.return_value.set_result(None)
+    return protocol
+
+
 @pytest.yield_fixture
 def transport(buf):
     transport = mock.Mock()
@@ -56,22 +64,11 @@ async def write_eof():
 
 
 @pytest.fixture
-def conn(stream):
-    return mock.Mock(writer=stream)
-
-
-@pytest.fixture
-def stream(buf, transport, loop):
-    stream = mock.Mock()
-    stream.transport = transport
-
-    def acquire(writer):
-        writer.set_transport(transport)
-
-    stream.acquire.side_effect = acquire
-    stream.drain.return_value = loop.create_future()
-    stream.drain.return_value.set_result(None)
-    return stream
+def conn(transport, protocol):
+    return mock.Mock(
+        transport=transport,
+        protocol=protocol
+    )
 
 
 def test_method1(make_request):
@@ -845,7 +842,6 @@ def gen(writer):
     assert asyncio.isfuture(req._writer)
     await resp.wait_for_close()
     assert req._writer is None
-
     assert buf.split(b'\r\n\r\n', 1)[1] == \
         b'b\r\nbinary data\r\n7\r\n result\r\n0\r\n\r\n'
     await req.close()
diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py
index ca395b9efe7..214a7c085a3 100644
--- a/tests/test_client_ws_functional.py
+++ b/tests/test_client_ws_functional.py
@@ -440,7 +440,7 @@ async def handler(request):
         await ws.prepare(request)
 
         await ws.receive_str()
-        ws._writer.writer.write(b'01234' * 100)
+        ws._writer.transport.write(b'01234' * 100)
         await ws.close()
         return ws
 
diff --git a/tests/test_http_stream_writer.py b/tests/test_http_stream_writer.py
deleted file mode 100644
index b4fdb2288a5..00000000000
--- a/tests/test_http_stream_writer.py
+++ /dev/null
@@ -1,257 +0,0 @@
-import socket
-from unittest import mock
-
-import pytest
-
-from aiohttp.http_writer import CORK, StreamWriter
-
-
-has_ipv6 = socket.has_ipv6
-if has_ipv6:
-    # The socket.has_ipv6 flag may be True if Python was built with IPv6
-    # support, but the target system still may not have it.
-    # So let's ensure that we really have IPv6 support.
-    try:
-        socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
-    except OSError:
-        has_ipv6 = False
-
-
-# nodelay
-
-def test_nodelay_and_cork_default(loop):
-    transport = mock.Mock()
-    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    transport.get_extra_info.return_value = s
-    proto = mock.Mock()
-    writer = StreamWriter(proto, transport, loop)
-    assert not writer.tcp_nodelay
-    assert not writer.tcp_cork
-    assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
-
-
-def test_set_nodelay_no_change(loop):
-    transport = mock.Mock()
-    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    transport.get_extra_info.return_value = s
-    proto = mock.Mock()
-    writer = StreamWriter(proto, transport, loop)
-    writer.set_tcp_nodelay(False)
-    assert not writer.tcp_nodelay
-    assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
-
-
-def test_set_nodelay_exception(loop):
-    transport = mock.Mock()
-    s = mock.Mock()
-    s.setsockopt = mock.Mock()
-    s.family = socket.AF_INET
-    s.setsockopt.side_effect = OSError
-    transport.get_extra_info.return_value = s
-    proto = mock.Mock()
-    writer = StreamWriter(proto, transport, loop)
-    writer.set_tcp_nodelay(True)
-    assert not writer.tcp_nodelay
-
-
-def test_set_nodelay_enable(loop):
-    transport = mock.Mock()
-    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    transport.get_extra_info.return_value = s
-    proto = mock.Mock()
-    writer = StreamWriter(proto, transport, loop)
-    writer.set_tcp_nodelay(True)
-    assert writer.tcp_nodelay
-    assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
-
-
-def test_set_nodelay_enable_and_disable(loop):
-    transport = mock.Mock()
-    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    transport.get_extra_info.return_value = s
-    proto = mock.Mock()
-    writer = StreamWriter(proto, transport, loop)
-    writer.set_tcp_nodelay(True)
-    writer.set_tcp_nodelay(False)
-    assert not writer.tcp_nodelay
-    assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
-
-
-@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
-def test_set_nodelay_and_cork(loop):
-    transport = mock.Mock()
-    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    transport.get_extra_info.return_value = s
-    proto = mock.Mock()
-    writer = StreamWriter(proto, transport, loop)
-    writer.set_tcp_cork(True)
-    writer.set_tcp_nodelay(True)
-    assert writer.tcp_nodelay
-    assert not writer.tcp_cork
-    assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
-
-
-@pytest.mark.skipif(not has_ipv6, reason="IPv6 is not available")
-def test_set_nodelay_enable_ipv6(loop):
-    transport = mock.Mock()
-    s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
-    transport.get_extra_info.return_value = s
-    proto = mock.Mock()
-    writer = StreamWriter(proto, transport, loop)
-    writer.set_tcp_nodelay(True)
-    assert writer.tcp_nodelay
-    assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
-
-
-@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'),
-                    reason="requires unix sockets")
-def test_set_nodelay_enable_unix(loop):
-    # do not set nodelay for unix socket
-    transport = mock.Mock()
-    s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
-    transport.get_extra_info.return_value = s
-    proto = mock.Mock()
-    writer = StreamWriter(proto, transport, loop)
-    writer.set_tcp_nodelay(True)
-    assert not writer.tcp_nodelay
-
-
-def test_set_nodelay_enable_no_socket(loop):
-    transport = mock.Mock()
-    transport.get_extra_info.return_value = None
-    proto = mock.Mock()
-    writer = StreamWriter(proto, transport, loop)
-    writer.set_tcp_nodelay(True)
-    assert not writer.tcp_nodelay
-    assert writer._socket is None
-
-
-# cork
-
-@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
-def test_cork_default(loop):
-    transport = mock.Mock()
-    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    transport.get_extra_info.return_value = s
-    proto = mock.Mock()
-    writer = StreamWriter(proto, transport, loop)
-    assert not writer.tcp_cork
-    assert not s.getsockopt(socket.IPPROTO_TCP, CORK)
-
-
-@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
-def test_set_cork_no_change(loop):
-    transport = mock.Mock()
-    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    transport.get_extra_info.return_value = s
-    proto = mock.Mock()
-    writer = StreamWriter(proto, transport, loop)
-    writer.set_tcp_cork(False)
-    assert not writer.tcp_cork
-    assert not s.getsockopt(socket.IPPROTO_TCP, CORK)
-
-
-@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
-def test_set_cork_enable(loop):
-    transport = mock.Mock()
-    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    transport.get_extra_info.return_value = s
-    proto = mock.Mock()
-    writer = StreamWriter(proto, transport, loop)
-    writer.set_tcp_cork(True)
-    assert writer.tcp_cork
-    assert s.getsockopt(socket.IPPROTO_TCP, CORK)
-
-
-@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
-def test_set_cork_enable_and_disable(loop):
-    transport = mock.Mock()
-    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    transport.get_extra_info.return_value = s
-    proto = mock.Mock()
-    writer = StreamWriter(proto, transport, loop)
-    writer.set_tcp_cork(True)
-    writer.set_tcp_cork(False)
-    assert not writer.tcp_cork
-    assert not s.getsockopt(socket.IPPROTO_TCP, CORK)
-
-
-@pytest.mark.skipif(not has_ipv6, reason="IPv6 is not available")
-@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
-def test_set_cork_enable_ipv6(loop):
-    transport = mock.Mock()
-    s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
-    transport.get_extra_info.return_value = s
-    proto = mock.Mock()
-    writer = StreamWriter(proto, transport, loop)
-    writer.set_tcp_cork(True)
-    assert writer.tcp_cork
-    assert s.getsockopt(socket.IPPROTO_TCP, CORK)
-
-
-@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'),
-                    reason="requires unix sockets")
-@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
-def test_set_cork_enable_unix(loop):
-    transport = mock.Mock()
-    s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
-    transport.get_extra_info.return_value = s
-    proto = mock.Mock()
-    writer = StreamWriter(proto, transport, loop)
-    writer.set_tcp_cork(True)
-    assert not writer.tcp_cork
-
-
-@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
-def test_set_cork_enable_no_socket(loop):
-    transport = mock.Mock()
-    transport.get_extra_info.return_value = None
-    proto = mock.Mock()
-    writer = StreamWriter(proto, transport, loop)
-    writer.set_tcp_cork(True)
-    assert not writer.tcp_cork
-    assert writer._socket is None
-
-
-def test_set_cork_exception(loop):
-    transport = mock.Mock()
-    s = mock.Mock()
-    s.setsockopt = mock.Mock()
-    s.family = socket.AF_INET
-    s.setsockopt.side_effect = OSError
-    proto = mock.Mock()
-    writer = StreamWriter(proto, transport, loop)
-    writer.set_tcp_cork(True)
-    assert not writer.tcp_cork
-
-
-# cork and nodelay interference
-
-@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
-def test_set_enabling_cork_disables_nodelay(loop):
-    transport = mock.Mock()
-    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    transport.get_extra_info.return_value = s
-    proto = mock.Mock()
-    writer = StreamWriter(proto, transport, loop)
-    writer.set_tcp_nodelay(True)
-    writer.set_tcp_cork(True)
-    assert not writer.tcp_nodelay
-    assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
-    assert writer.tcp_cork
-    assert s.getsockopt(socket.IPPROTO_TCP, CORK)
-
-
-@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
-def test_set_enabling_nodelay_disables_cork(loop):
-    transport = mock.Mock()
-    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    transport.get_extra_info.return_value = s
-    proto = mock.Mock()
-    writer = StreamWriter(proto, transport, loop)
-    writer.set_tcp_cork(True)
-    writer.set_tcp_nodelay(True)
-    assert writer.tcp_nodelay
-    assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
-    assert not writer.tcp_cork
-    assert not s.getsockopt(socket.IPPROTO_TCP, CORK)
diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py
index c54f609e3f8..317794e7ab8 100644
--- a/tests/test_http_writer.py
+++ b/tests/test_http_writer.py
@@ -26,21 +26,22 @@ def write(chunk):
 
 
 @pytest.fixture
-def stream(transport, loop):
-    stream = mock.Mock(transport=transport)
+def protocol(loop, transport):
+    protocol = mock.Mock(transport=transport)
+    protocol._drain_helper.return_value = loop.create_future()
+    protocol._drain_helper.return_value.set_result(None)
+    return protocol
 
-    def acquire(writer):
-        writer.set_transport(transport)
 
-    stream.acquire = acquire
-    stream.drain.return_value = loop.create_future()
-    stream.drain.return_value.set_result(None)
-    return stream
+def test_payloadwriter_properties(transport, protocol, loop):
+    writer = http.PayloadWriter(protocol, transport, loop)
+    assert writer.protocol == protocol
+    assert writer.transport == transport
 
 
-async def test_write_payload_eof(stream, loop):
-    write = stream.transport.write = mock.Mock()
-    msg = http.PayloadWriter(stream, loop)
+async def test_write_payload_eof(transport, protocol, loop):
+    write = transport.write = mock.Mock()
+    msg = http.PayloadWriter(protocol, transport, loop)
 
     msg.write(b'data1')
     msg.write(b'data2')
@@ -50,8 +51,8 @@ async def test_write_payload_eof(stream, loop):
     assert b'data1data2' == content.split(b'\r\n\r\n', 1)[-1]
 
 
-async def test_write_payload_chunked(buf, stream, loop):
-    msg = http.PayloadWriter(stream, loop)
+async def test_write_payload_chunked(buf, protocol, transport, loop):
+    msg = http.PayloadWriter(protocol, transport, loop)
     msg.enable_chunking()
     msg.write(b'data')
     await msg.write_eof()
@@ -59,8 +60,8 @@ async def test_write_payload_chunked(buf, stream, loop):
     assert b'4\r\ndata\r\n0\r\n\r\n' == buf
 
 
-async def test_write_payload_chunked_multiple(buf, stream, loop):
-    msg = http.PayloadWriter(stream, loop)
+async def test_write_payload_chunked_multiple(buf, protocol, transport, loop):
+    msg = http.PayloadWriter(protocol, transport, loop)
     msg.enable_chunking()
     msg.write(b'data1')
     msg.write(b'data2')
@@ -69,10 +70,10 @@ async def test_write_payload_chunked_multiple(buf, stream, loop):
     assert b'5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n' == buf
 
 
-async def test_write_payload_length(stream, loop):
-    write = stream.transport.write = mock.Mock()
+async def test_write_payload_length(protocol, transport, loop):
+    write = transport.write = mock.Mock()
 
-    msg = http.PayloadWriter(stream, loop)
+    msg = http.PayloadWriter(protocol, transport, loop)
     msg.length = 2
     msg.write(b'd')
     msg.write(b'ata')
@@ -82,10 +83,10 @@ async def test_write_payload_length(stream, loop):
     assert b'da' == content.split(b'\r\n\r\n', 1)[-1]
 
 
-async def test_write_payload_chunked_filter(stream, loop):
-    write = stream.transport.write = mock.Mock()
+async def test_write_payload_chunked_filter(protocol, transport, loop):
+    write = transport.write = mock.Mock()
 
-    msg = http.PayloadWriter(stream, loop)
+    msg = http.PayloadWriter(protocol, transport, loop)
     msg.enable_chunking()
     msg.write(b'da')
     msg.write(b'ta')
@@ -95,9 +96,12 @@ async def test_write_payload_chunked_filter(stream, loop):
     assert content.endswith(b'2\r\nda\r\n2\r\nta\r\n0\r\n\r\n')
 
 
-async def test_write_payload_chunked_filter_mutiple_chunks(stream, loop):
-    write = stream.transport.write = mock.Mock()
-    msg = http.PayloadWriter(stream, loop)
+async def test_write_payload_chunked_filter_mutiple_chunks(
+        protocol,
+        transport,
+        loop):
+    write = transport.write = mock.Mock()
+    msg = http.PayloadWriter(protocol, transport, loop)
     msg.enable_chunking()
     msg.write(b'da')
     msg.write(b'ta')
@@ -115,9 +119,9 @@ async def test_write_payload_chunked_filter_mutiple_chunks(stream, loop):
 COMPRESSED = b''.join([compressor.compress(b'data'), compressor.flush()])
 
 
-async def test_write_payload_deflate_compression(stream, loop):
-    write = stream.transport.write = mock.Mock()
-    msg = http.PayloadWriter(stream, loop)
+async def test_write_payload_deflate_compression(protocol, transport, loop):
+    write = transport.write = mock.Mock()
+    msg = http.PayloadWriter(protocol, transport, loop)
     msg.enable_compression('deflate')
     msg.write(b'data')
     await msg.write_eof()
@@ -128,8 +132,12 @@ async def test_write_payload_deflate_compression(stream, loop):
     assert COMPRESSED == content.split(b'\r\n\r\n', 1)[-1]
 
 
-async def test_write_payload_deflate_and_chunked(buf, stream, loop):
-    msg = http.PayloadWriter(stream, loop)
+async def test_write_payload_deflate_and_chunked(
+        buf,
+        protocol,
+        transport,
+        loop):
+    msg = http.PayloadWriter(protocol, transport, loop)
     msg.enable_compression('deflate')
     msg.enable_chunking()
 
@@ -140,8 +148,8 @@ async def test_write_payload_deflate_and_chunked(buf, stream, loop):
     assert b'6\r\nKI,I\x04\x00\r\n0\r\n\r\n' == buf
 
 
-def test_write_drain(stream, loop):
-    msg = http.PayloadWriter(stream, loop)
+def test_write_drain(protocol, transport, loop):
+    msg = http.PayloadWriter(protocol, transport, loop)
     msg.drain = mock.Mock()
     msg.write(b'1' * (64 * 1024 * 2), drain=False)
     assert not msg.drain.called
@@ -151,11 +159,24 @@ def test_write_drain(stream, loop):
     assert msg.buffer_size == 0
 
 
-def test_write_to_closing_transport(stream, loop):
-    msg = http.PayloadWriter(stream, loop)
+def test_write_to_closing_transport(protocol, transport, loop):
+    msg = http.PayloadWriter(protocol, transport, loop)
 
     msg.write(b'Before closing')
-    stream.transport.is_closing.return_value = True
+    transport.is_closing.return_value = True
 
     with pytest.raises(asyncio.CancelledError):
         msg.write(b'After closing')
+
+
+async def test_drain(protocol, transport, loop):
+    msg = http.PayloadWriter(protocol, transport, loop)
+    await msg.drain()
+    assert protocol._drain_helper.called
+
+
+async def test_drain_no_transport(protocol, transport, loop):
+    msg = http.PayloadWriter(protocol, transport, loop)
+    msg._protocol.transport = None
+    await msg.drain()
+    assert not protocol._drain_helper.called
diff --git a/tests/test_tcp_helpers.py b/tests/test_tcp_helpers.py
new file mode 100644
index 00000000000..ebe8271d820
--- /dev/null
+++ b/tests/test_tcp_helpers.py
@@ -0,0 +1,145 @@
+import socket
+from unittest import mock
+
+import pytest
+
+from aiohttp.tcp_helpers import CORK, tcp_cork, tcp_nodelay
+
+
+has_ipv6 = socket.has_ipv6
+if has_ipv6:
+    # The socket.has_ipv6 flag may be True if Python was built with IPv6
+    # support, but the target system still may not have it.
+    # So let's ensure that we really have IPv6 support.
+    try:
+        socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+    except OSError:
+        has_ipv6 = False
+
+
+# nodelay
+
+def test_tcp_nodelay_exception(loop):
+    transport = mock.Mock()
+    s = mock.Mock()
+    s.setsockopt = mock.Mock()
+    s.family = socket.AF_INET
+    s.setsockopt.side_effect = OSError
+    transport.get_extra_info.return_value = s
+    tcp_nodelay(transport, True)
+    s.setsockopt.assert_called_with(
+        socket.IPPROTO_TCP,
+        socket.TCP_NODELAY,
+        True
+    )
+
+
+def test_tcp_nodelay_enable(loop):
+    transport = mock.Mock()
+    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    transport.get_extra_info.return_value = s
+    tcp_nodelay(transport, True)
+    assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
+
+
+def test_tcp_nodelay_enable_and_disable(loop):
+    transport = mock.Mock()
+    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    transport.get_extra_info.return_value = s
+    tcp_nodelay(transport, True)
+    assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
+    tcp_nodelay(transport, False)
+    assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
+
+
+@pytest.mark.skipif(not has_ipv6, reason="IPv6 is not available")
+def test_tcp_nodelay_enable_ipv6(loop):
+    transport = mock.Mock()
+    s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+    transport.get_extra_info.return_value = s
+    tcp_nodelay(transport, True)
+    assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
+
+
+@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'),
+                    reason="requires unix sockets")
+def test_tcp_nodelay_enable_unix(loop):
+    # do not set nodelay for unix socket
+    transport = mock.Mock()
+    s = mock.Mock(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
+    transport.get_extra_info.return_value = s
+    tcp_nodelay(transport, True)
+    assert not s.setsockopt.called
+
+
+def test_tcp_nodelay_enable_no_socket(loop):
+    transport = mock.Mock()
+    transport.get_extra_info.return_value = None
+    tcp_nodelay(transport, True)
+
+
+# cork
+
+
+@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
+def test_tcp_cork_enable(loop):
+    transport = mock.Mock()
+    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    transport.get_extra_info.return_value = s
+    tcp_cork(transport, True)
+    assert s.getsockopt(socket.IPPROTO_TCP, CORK)
+
+
+@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
+def test_set_cork_enable_and_disable(loop):
+    transport = mock.Mock()
+    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    transport.get_extra_info.return_value = s
+    tcp_cork(transport, True)
+    assert s.getsockopt(socket.IPPROTO_TCP, CORK)
+    tcp_cork(transport, False)
+    assert not s.getsockopt(socket.IPPROTO_TCP, CORK)
+
+
+@pytest.mark.skipif(not has_ipv6, reason="IPv6 is not available")
+@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
+def test_set_cork_enable_ipv6(loop):
+    transport = mock.Mock()
+    s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+    transport.get_extra_info.return_value = s
+    tcp_cork(transport, True)
+    assert s.getsockopt(socket.IPPROTO_TCP, CORK)
+
+
+@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'),
+                    reason="requires unix sockets")
+@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
+def test_set_cork_enable_unix(loop):
+    transport = mock.Mock()
+    s = mock.Mock(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
+    transport.get_extra_info.return_value = s
+    tcp_cork(transport, True)
+    assert not s.setsockopt.called
+
+
+@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
+def test_set_cork_enable_no_socket(loop):
+    transport = mock.Mock()
+    transport.get_extra_info.return_value = None
+    tcp_cork(transport, True)
+
+
+@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required")
+def test_set_cork_exception(loop):
+    transport = mock.Mock()
+    s = mock.Mock()
+    s.setsockopt = mock.Mock()
+    s.family = socket.AF_INET
+    s.setsockopt.side_effect = OSError
+    transport.get_extra_info.return_value = s
+    tcp_cork(transport, True)
+    s.setsockopt.assert_called_with(
+        socket.IPPROTO_TCP,
+        CORK,
+        True
+    )
diff --git a/tests/test_web_protocol.py b/tests/test_web_protocol.py
index 00a772afe9f..e68ab144242 100644
--- a/tests/test_web_protocol.py
+++ b/tests/test_web_protocol.py
@@ -38,6 +38,8 @@ def srv(make_srv, transport):
     srv = make_srv()
     srv.connection_made(transport)
     transport.close.side_effect = partial(srv.connection_lost, None)
+    srv._drain_helper = mock.Mock()
+    srv._drain_helper.side_effect = helpers.noop
     return srv
 
 
@@ -72,7 +74,7 @@ async def handle(request):
 
 @pytest.yield_fixture
 def writer(srv):
-    return http.PayloadWriter(srv.writer, srv._loop)
+    return http.PayloadWriter(srv, srv.transport, srv._loop)
 
 
 @pytest.yield_fixture
@@ -83,7 +85,6 @@ def write(chunk):
         buf.extend(chunk)
 
     transport.write.side_effect = write
-    transport.drain.side_effect = helpers.noop
     transport.is_closing.return_value = False
 
     return transport
@@ -131,6 +132,16 @@ async def test_double_shutdown(srv, transport):
     assert srv.transport is None
 
 
+async def test_shutdown_wait_error_handler(loop, srv, transport):
+
+    async def _error_handle():
+        pass
+
+    srv._error_handler = loop.create_task(_error_handle())
+    await srv.shutdown()
+    assert srv._error_handler.done()
+
+
 async def test_close_after_response(srv, loop, transport):
     srv.data_received(
         b'GET / HTTP/1.0\r\n'
@@ -227,7 +238,7 @@ async def test_bad_method(srv, loop, buf):
 
 
 async def test_data_received_error(srv, loop, buf):
-    srv.transport = mock.Mock()
+    transport = srv.transport
     srv._request_parser = mock.Mock()
     srv._request_parser.feed_data.side_effect = TypeError
 
@@ -237,7 +248,7 @@ async def test_data_received_error(srv, loop, buf):
 
     await asyncio.sleep(0, loop=loop)
     assert buf.startswith(b'HTTP/1.0 500 Internal Server Error\r\n')
-    assert srv.transport.close.called
+    assert transport.close.called
     assert srv._error_handler is None
 
 
@@ -737,3 +748,38 @@ def test_data_received_force_close(srv):
         b'Content-Length: 0\r\n\r\n')
 
     assert not srv._messages
+
+
+async def test__process_keepalive(loop, srv):
+    # wait till the waiter is waiting
+    await asyncio.sleep(0)
+
+    srv._keepalive_time = 1
+    srv._keepalive_timeout = 1
+    expired_time = srv._keepalive_time + srv._keepalive_timeout + 1
+    with mock.patch.object(loop, "time", return_value=expired_time):
+        srv._process_keepalive()
+        assert srv._force_close
+
+
+async def test__process_keepalive_schedule_next(loop, srv):
+    # wait till the waiter is waiting
+    await asyncio.sleep(0)
+
+    srv._keepalive_time = 1
+    srv._keepalive_timeout = 1
+    expire_time = srv._keepalive_time + srv._keepalive_timeout
+    with mock.patch.object(loop, "time", return_value=expire_time):
+        with mock.patch.object(loop, "call_at") as call_at_patched:
+            srv._process_keepalive()
+            call_at_patched.assert_called_with(
+                expire_time,
+                srv._process_keepalive
+            )
+
+
+def test__process_keepalive_force_close(loop, srv):
+    srv._force_close = True
+    with mock.patch.object(loop, "call_at") as call_at_patched:
+        srv._process_keepalive()
+        assert not call_at_patched.called
diff --git a/tests/test_web_sendfile.py b/tests/test_web_sendfile.py
index 2bec965893b..7f7520ddb4f 100644
--- a/tests/test_web_sendfile.py
+++ b/tests/test_web_sendfile.py
@@ -12,7 +12,7 @@ def test_static_handle_eof(loop):
         in_fd = 31
         fut = loop.create_future()
         m_os.sendfile.return_value = 0
-        writer = SendfilePayloadWriter(fake_loop, mock.Mock())
+        writer = SendfilePayloadWriter(mock.Mock(), mock.Mock(), fake_loop)
         writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False)
         m_os.sendfile.assert_called_with(out_fd, in_fd, 0, 100)
         assert fut.done()
@@ -28,7 +28,7 @@ def test_static_handle_again(loop):
         in_fd = 31
         fut = loop.create_future()
         m_os.sendfile.side_effect = BlockingIOError()
-        writer = SendfilePayloadWriter(fake_loop, mock.Mock())
+        writer = SendfilePayloadWriter(mock.Mock(), mock.Mock(), fake_loop)
         writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False)
         m_os.sendfile.assert_called_with(out_fd, in_fd, 0, 100)
         assert not fut.done()
@@ -47,7 +47,7 @@ def test_static_handle_exception(loop):
         fut = loop.create_future()
         exc = OSError()
         m_os.sendfile.side_effect = exc
-        writer = SendfilePayloadWriter(fake_loop, mock.Mock())
+        writer = SendfilePayloadWriter(mock.Mock(), mock.Mock(), fake_loop)
         writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False)
         m_os.sendfile.assert_called_with(out_fd, in_fd, 0, 100)
         assert fut.done()
@@ -63,7 +63,7 @@ def test__sendfile_cb_return_on_cancelling(loop):
         in_fd = 31
         fut = loop.create_future()
         fut.cancel()
-        writer = SendfilePayloadWriter(fake_loop, mock.Mock())
+        writer = SendfilePayloadWriter(mock.Mock(), mock.Mock(), fake_loop)
         writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False)
         assert fut.done()
         assert not fake_loop.add_writer.called
diff --git a/tests/test_websocket_writer.py b/tests/test_websocket_writer.py
index a31808549f6..af30b1e3910 100644
--- a/tests/test_websocket_writer.py
+++ b/tests/test_websocket_writer.py
@@ -7,88 +7,97 @@
 
 
 @pytest.fixture
-def stream():
+def protocol():
     return mock.Mock()
 
 
 @pytest.fixture
-def writer(stream):
-    return WebSocketWriter(stream, use_mask=False)
+def transport():
+    return mock.Mock()
+
+
+@pytest.fixture
+def writer(protocol, transport):
+    return WebSocketWriter(protocol, transport, use_mask=False)
 
 
-def test_pong(stream, writer):
+def test_pong(writer):
     writer.pong()
-    stream.transport.write.assert_called_with(b'\x8a\x00')
+    writer.transport.write.assert_called_with(b'\x8a\x00')
 
 
-def test_ping(stream, writer):
+def test_ping(writer):
     writer.ping()
-    stream.transport.write.assert_called_with(b'\x89\x00')
+    writer.transport.write.assert_called_with(b'\x89\x00')
 
 
-def test_send_text(stream, writer):
+def test_send_text(writer):
     writer.send(b'text')
-    stream.transport.write.assert_called_with(b'\x81\x04text')
+    writer.transport.write.assert_called_with(b'\x81\x04text')
 
 
-def test_send_binary(stream, writer):
+def test_send_binary(writer):
     writer.send('binary', True)
-    stream.transport.write.assert_called_with(b'\x82\x06binary')
+    writer.transport.write.assert_called_with(b'\x82\x06binary')
 
 
-def test_send_binary_long(stream, writer):
+def test_send_binary_long(writer):
     writer.send(b'b' * 127, True)
-    assert stream.transport.write.call_args[0][0].startswith(b'\x82~\x00\x7fb')
+    assert writer.transport.write.call_args[0][0].startswith(b'\x82~\x00\x7fb')
 
 
-def test_send_binary_very_long(stream, writer):
+def test_send_binary_very_long(writer):
     writer.send(b'b' * 65537, True)
-    assert (stream.transport.write.call_args_list[0][0][0] ==
+    assert (writer.transport.write.call_args_list[0][0][0] ==
             b'\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x01')
-    assert stream.transport.write.call_args_list[1][0][0] == b'b' * 65537
+    assert writer.transport.write.call_args_list[1][0][0] == b'b' * 65537
 
 
-def test_close(stream, writer):
+def test_close(writer):
     writer.close(1001, 'msg')
-    stream.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg')
+    writer.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg')
 
     writer.close(1001, b'msg')
-    stream.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg')
+    writer.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg')
 
     # Test that Service Restart close code is also supported
     writer.close(1012, b'msg')
-    stream.transport.write.assert_called_with(b'\x88\x05\x03\xf4msg')
+    writer.transport.write.assert_called_with(b'\x88\x05\x03\xf4msg')
 
 
-def test_send_text_masked(stream):
-    writer = WebSocketWriter(stream,
+def test_send_text_masked(protocol, transport):
+    writer = WebSocketWriter(protocol,
+                             transport,
                              use_mask=True,
                              random=random.Random(123))
     writer.send(b'text')
-    stream.transport.write.assert_called_with(b'\x81\x84\rg\xb3fy\x02\xcb\x12')
+    writer.transport.write.assert_called_with(b'\x81\x84\rg\xb3fy\x02\xcb\x12')
 
 
-def test_send_compress_text(stream):
-    writer = WebSocketWriter(stream, compress=15)
+def test_send_compress_text(protocol, transport):
+    writer = WebSocketWriter(protocol, transport, compress=15)
     writer.send(b'text')
-    stream.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')
+    writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')
     writer.send(b'text')
-    stream.transport.write.assert_called_with(b'\xc1\x05*\x01b\x00\x00')
+    writer.transport.write.assert_called_with(b'\xc1\x05*\x01b\x00\x00')
 
 
-def test_send_compress_text_notakeover(stream):
-    writer = WebSocketWriter(stream, compress=15, notakeover=True)
+def test_send_compress_text_notakeover(protocol, transport):
+    writer = WebSocketWriter(protocol,
+                             transport,
+                             compress=15,
+                             notakeover=True)
     writer.send(b'text')
-    stream.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')
+    writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')
     writer.send(b'text')
-    stream.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')
+    writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')
 
 
-def test_send_compress_text_per_message(stream):
-    writer = WebSocketWriter(stream)
+def test_send_compress_text_per_message(protocol, transport):
+    writer = WebSocketWriter(protocol, transport)
     writer.send(b'text', compress=15)
-    stream.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')
+    writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')
     writer.send(b'text')
-    stream.transport.write.assert_called_with(b'\x81\x04text')
+    writer.transport.write.assert_called_with(b'\x81\x04text')
     writer.send(b'text', compress=15)
-    stream.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')
+    writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')