From ee4cedb1e7c73af4749f3f7026af70984f529e89 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Mon, 24 Jun 2019 22:48:59 -0700 Subject: [PATCH 1/5] Streams are iterable + receive_some doesn't require an explicit size This came out of discussion in gh-959 --- docs/source/reference-io.rst | 2 +- docs/source/tutorial.rst | 39 ++++++++++--------- docs/source/tutorial/echo-client.py | 11 ++---- docs/source/tutorial/echo-server.py | 12 ++---- newsfragments/959.feature.rst | 8 ++++ notes-to-self/graceful-shutdown-idea.py | 2 +- trio/_abc.py | 23 +++++++---- trio/_highlevel_generic.py | 4 +- trio/_highlevel_socket.py | 10 ++++- trio/_ssl.py | 51 +++++++++++++------------ trio/_subprocess.py | 5 +-- trio/_unix_pipes.py | 18 ++++++--- trio/_windows_pipes.py | 17 ++++++--- trio/testing/_check_streams.py | 11 +++++- trio/testing/_memory_streams.py | 26 ++++++------- trio/tests/test_highlevel_generic.py | 2 +- trio/tests/test_ssl.py | 12 +++--- trio/tests/test_subprocess.py | 7 +--- trio/tests/test_testing.py | 4 +- 19 files changed, 145 insertions(+), 119 deletions(-) create mode 100644 newsfragments/959.feature.rst diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst index 4bebbad6fa..fea79a4708 100644 --- a/docs/source/reference-io.rst +++ b/docs/source/reference-io.rst @@ -98,7 +98,7 @@ Abstract base classes * - :class:`ReceiveStream` - :class:`AsyncResource` - :meth:`~ReceiveStream.receive_some` - - + - ``__aiter__``, ``__anext__`` - :class:`~trio.testing.MemoryReceiveStream` * - :class:`Stream` - :class:`SendStream`, :class:`ReceiveStream` diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index 27fedc103f..5896933871 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -908,12 +908,10 @@ And the second task's job is to process the data the server sends back: :lineno-match: :pyobject: receiver -It repeatedly calls ``await client_stream.receive_some(...)`` to get -more data from the server (again, all Trio streams provide this -method), and then checks to see if the server has closed the -connection. ``receive_some`` only returns an empty bytestring if the -connection has been closed; otherwise, it waits until some data has -arrived, up to a maximum of ``BUFSIZE`` bytes. +It uses an ``async for`` loop to fetch data from the server. +Alternatively, it could use `~trio.abc.ReceiveStream.receive_some`, +which is the opposite of `~trio.abc.SendStream.send_all`, but using +``async for`` saves some boilerplate. And now we're ready to look at the server. @@ -974,11 +972,11 @@ functions we saw in the last section: The argument ``server_stream`` is provided by :func:`serve_tcp`, and is the other end of the connection we made in the client: so the data -that the client passes to ``send_all`` will come out of -``receive_some`` here, and vice-versa. Then we have a ``try`` block -discussed below, and finally the server loop which alternates between -reading some data from the socket and then sending it back out again -(unless the socket was closed, in which case we quit). +that the client passes to ``send_all`` will come out here. Then we +have a ``try`` block discussed below, and finally the server loop +which alternates between reading some data from the socket and then +sending it back out again (unless the socket was closed, in which case +we quit). So what's that ``try`` block for? Remember that in Trio, like Python in general, exceptions keep propagating until they're caught. Here we @@ -1029,7 +1027,7 @@ our client could use a single task like:: while True: data = ... await client_stream.send_all(data) - received = await client_stream.receive_some(BUFSIZE) + received = await client_stream.receive_some() if not received: sys.exit() await trio.sleep(1) @@ -1046,18 +1044,23 @@ line, any time we're expecting more than one byte of data, we have to be prepared to call ``receive_some`` multiple times. And where this would go especially wrong is if we find ourselves in -the situation where ``len(data) > BUFSIZE``. On each pass through the -loop, we send ``len(data)`` bytes, but only read *at most* ``BUFSIZE`` -bytes. The result is something like a memory leak: we'll end up with -more and more data backed up in the network, until eventually -something breaks. +the situation where ``data`` is big enough that it passes some +internal threshold, and the operating system or network decide to +always break it up into multiple pieces. Now on each pass through the +loop, we send ``len(data)`` bytes, but read less than that. The result +is something like a memory leak: we'll end up with more and more data +backed up in the network, until eventually something breaks. + +.. note:: If you're curious *how* things break, then you can use + `~trio.abc.ReceiveStream.receive_some`\'s optional argument to put + a limit on how many bytes you read each time, and see what happens. We could fix this by keeping track of how much data we're expecting at each moment, and then keep calling ``receive_some`` until we get it all:: expected = len(data) while expected > 0: - received = await client_stream.receive_some(BUFSIZE) + received = await client_stream.receive_some(expected) if not received: sys.exit(1) expected -= len(received) diff --git a/docs/source/tutorial/echo-client.py b/docs/source/tutorial/echo-client.py index f6468f32fc..06f6a81e7e 100644 --- a/docs/source/tutorial/echo-client.py +++ b/docs/source/tutorial/echo-client.py @@ -8,9 +8,6 @@ # - can't be in use by some other program on your computer # - must match what we set in our echo server PORT = 12345 -# How much memory to spend (at most) on each call to recv. Pretty arbitrary, -# but shouldn't be too big or too small. -BUFSIZE = 16384 async def sender(client_stream): print("sender: started!") @@ -22,12 +19,10 @@ async def sender(client_stream): async def receiver(client_stream): print("receiver: started!") - while True: - data = await client_stream.receive_some(BUFSIZE) + async for data in client_stream: print("receiver: got data {!r}".format(data)) - if not data: - print("receiver: connection closed") - sys.exit() + print("receiver: connection closed") + sys.exit() async def parent(): print("parent: connecting to 127.0.0.1:{}".format(PORT)) diff --git a/docs/source/tutorial/echo-server.py b/docs/source/tutorial/echo-server.py index c341925341..08cac7a815 100644 --- a/docs/source/tutorial/echo-server.py +++ b/docs/source/tutorial/echo-server.py @@ -8,9 +8,6 @@ # - can't be in use by some other program on your computer # - must match what we set in our echo client PORT = 12345 -# How much memory to spend (at most) on each call to recv. Pretty arbitrary, -# but shouldn't be too big or too small. -BUFSIZE = 16384 CONNECTION_COUNTER = count() @@ -20,14 +17,11 @@ async def echo_server(server_stream): ident = next(CONNECTION_COUNTER) print("echo_server {}: started".format(ident)) try: - while True: - data = await server_stream.receive_some(BUFSIZE) + async for data in server_stream: print("echo_server {}: received data {!r}".format(ident, data)) - if not data: - print("echo_server {}: connection closed".format(ident)) - return - print("echo_server {}: sending data {!r}".format(ident, data)) await server_stream.send_all(data) + print("echo_server {}: connection closed".format(ident)) + return # FIXME: add discussion of MultiErrors to the tutorial, and use # MultiError.catch here. (Not important in this case, but important if the # server code uses nurseries internally.) diff --git a/newsfragments/959.feature.rst b/newsfragments/959.feature.rst new file mode 100644 index 0000000000..4e5bfc97cd --- /dev/null +++ b/newsfragments/959.feature.rst @@ -0,0 +1,8 @@ +If you have a `~trio.abc.ReceiveStream` object, you can now use +``async for data in stream: ...`` instead of calling +`~trio.abc.ReceiveStream.receive_some` repeatedly. And the best part +is, it automatically checks for EOF for you, so you don't have to. +Also, you no longer have to choose a magic buffer size value before +calling `~trio.abc.ReceiveStream.receive_some`; you can now call +``await stream.receive_some()`` and the stream will automatically pick +a reasonable value for you. diff --git a/notes-to-self/graceful-shutdown-idea.py b/notes-to-self/graceful-shutdown-idea.py index d76402dbf8..792344de02 100644 --- a/notes-to-self/graceful-shutdown-idea.py +++ b/notes-to-self/graceful-shutdown-idea.py @@ -30,7 +30,7 @@ def shutting_down(self): async def stream_handler(stream): while True: with gsm.cancel_on_graceful_shutdown(): - data = await stream.receive_some(...) + data = await stream.receive_some() if gsm.shutting_down: break diff --git a/trio/_abc.py b/trio/_abc.py index 08c145851e..da381b2ed2 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -378,26 +378,26 @@ class ReceiveStream(AsyncResource): If you want to receive Python objects rather than raw bytes, see :class:`ReceiveChannel`. + `ReceiveStream` objects can be used in ``async for`` loops. Each iteration + will produce an arbitrary size + """ __slots__ = () @abstractmethod - async def receive_some(self, max_bytes): + async def receive_some(self, max_bytes=None): """Wait until there is data available on this stream, and then return - at most ``max_bytes`` of it. + some of it. A return value of ``b""`` (an empty bytestring) indicates that the stream has reached end-of-file. Implementations should be careful that they return ``b""`` if, and only if, the stream has reached end-of-file! - This method will return as soon as any data is available, so it may - return fewer than ``max_bytes`` of data. But it will never return - more. - Args: max_bytes (int): The maximum number of bytes to return. Must be - greater than zero. + greater than zero. Optional; if omitted, then the stream object + is free to pick a reasonable default. Returns: bytes or bytearray: The data received. @@ -413,6 +413,15 @@ async def receive_some(self, max_bytes): """ + def __aiter__(self): + return self + + async def __anext__(self): + data = await self.receive_some() + if not data: + raise StopAsyncIteration + return data + class Stream(SendStream, ReceiveStream): """A standard interface for interacting with bidirectional byte streams. diff --git a/trio/_highlevel_generic.py b/trio/_highlevel_generic.py index 8ff5f6b3c0..601a8ff437 100644 --- a/trio/_highlevel_generic.py +++ b/trio/_highlevel_generic.py @@ -52,7 +52,7 @@ class StapledStream(HalfCloseableStream): left, right = trio.testing.memory_stream_pair() echo_stream = StapledStream(SocketStream(left), SocketStream(right)) await echo_stream.send_all(b"x") - assert await echo_stream.receive_some(1) == b"x" + assert await echo_stream.receive_some() == b"x" :class:`StapledStream` objects implement the methods in the :class:`~trio.abc.HalfCloseableStream` interface. They also have two @@ -96,7 +96,7 @@ async def send_eof(self): else: return await self.send_stream.aclose() - async def receive_some(self, max_bytes): + async def receive_some(self, max_bytes=None): """Calls ``self.receive_stream.receive_some``. """ diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index e25f498b9a..47e253e799 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -10,6 +10,12 @@ __all__ = ["SocketStream", "SocketListener"] +# XX TODO: this number was picked arbitrarily. We should do experiments to +# tune it. (Or make it dynamic -- one idea is to start small and increase it +# if we observe single reads filling up the whole buffer, at least within some +# limits.) +DEFAULT_RECEIVE_SIZE = 65536 + _closed_stream_errnos = { # Unix errno.EBADF, @@ -129,7 +135,9 @@ async def send_eof(self): with _translate_socket_errors_to_stream_errors(): self.socket.shutdown(tsocket.SHUT_WR) - async def receive_some(self, max_bytes): + async def receive_some(self, max_bytes=None): + if max_bytes is None: + max_bytes = DEFAULT_RECEIVE_SIZE if max_bytes < 1: raise ValueError("max_bytes must be >= 1") with _translate_socket_errors_to_stream_errors(): diff --git a/trio/_ssl.py b/trio/_ssl.py index 57a3659290..16d9623243 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -159,11 +159,16 @@ from ._highlevel_generic import aclose_forcefully from . import _sync from ._util import ConflictDetector +from ._deprecate import warn_deprecated ################################################################ # SSLStream ################################################################ +# XX TODO: this number was pulled out of a hat. We should tune it with +# science. +DEFAULT_RECEIVE_SIZE = 65536 + class NeedHandshakeError(Exception): """Some :class:`SSLStream` methods can't return any meaningful data until @@ -197,8 +202,6 @@ def done(self): _State = _Enum("_State", ["OK", "BROKEN", "CLOSED"]) -_default_max_refill_bytes = 32 * 1024 - class SSLStream(Stream): r"""Encrypted communication using SSL/TLS. @@ -269,15 +272,6 @@ class SSLStream(Stream): that :class:`~ssl.SSLSocket` implements the ``https_compatible=True`` behavior by default. - max_refill_bytes (int): :class:`~ssl.SSLSocket` maintains an internal - buffer of incoming data, and when it runs low then it calls - :meth:`receive_some` on the underlying transport stream to refill - it. This argument lets you set the ``max_bytes`` argument passed to - the *underlying* :meth:`receive_some` call. It doesn't affect calls - to *this* class's :meth:`receive_some`, or really anything else - user-observable except possibly performance. You probably don't need - to worry about this. - Attributes: transport_stream (trio.abc.Stream): The underlying transport stream that was passed to ``__init__``. An example of when this would be @@ -313,11 +307,14 @@ def __init__( server_hostname=None, server_side=False, https_compatible=False, - max_refill_bytes=_default_max_refill_bytes + max_refill_bytes="unused and deprecated", ): self.transport_stream = transport_stream self._state = _State.OK - self._max_refill_bytes = max_refill_bytes + if max_refill_bytes != "unused and deprecated": + warn_deprecated( + "max_refill_bytes=...", "0.12.0", issue=959, instead=None + ) self._https_compatible = https_compatible self._outgoing = _stdlib_ssl.MemoryBIO() self._incoming = _stdlib_ssl.MemoryBIO() @@ -536,9 +533,7 @@ async def _retry(self, fn, *args, ignore_want_read=False): async with self._inner_recv_lock: yielded = True if recv_count == self._inner_recv_count: - data = await self.transport_stream.receive_some( - self._max_refill_bytes - ) + data = await self.transport_stream.receive_some() if not data: self._incoming.write_eof() else: @@ -590,7 +585,7 @@ async def do_handshake(self): # https://bugs.python.org/issue30141 # So we *definitely* have to make sure that do_handshake is called # before doing anything else. - async def receive_some(self, max_bytes): + async def receive_some(self, max_bytes=None): """Read some data from the underlying transport, decrypt it, and return it. @@ -621,9 +616,15 @@ async def receive_some(self, max_bytes): return b"" else: raise - max_bytes = _operator.index(max_bytes) - if max_bytes < 1: - raise ValueError("max_bytes must be >= 1") + if max_bytes is None: + # Heuristic: normally we use DEFAULT_RECEIVE_SIZE, but if + # the transport gave us a bunch of data last time then we'll + # try to decrypt and pass it all back at once. + max_bytes = max(DEFAULT_RECEIVE_SIZE, self._incoming.pending) + else: + max_bytes = _operator.index(max_bytes) + if max_bytes < 1: + raise ValueError("max_bytes must be >= 1") try: return await self._retry(self._ssl_object.read, max_bytes) except trio.BrokenResourceError as exc: @@ -837,8 +838,6 @@ class SSLListener(Listener[SSLStream]): https_compatible (bool): Passed on to :class:`SSLStream`. - max_refill_bytes (int): Passed on to :class:`SSLStream`. - Attributes: transport_listener (trio.abc.Listener): The underlying listener that was passed to ``__init__``. @@ -851,12 +850,15 @@ def __init__( ssl_context, *, https_compatible=False, - max_refill_bytes=_default_max_refill_bytes + max_refill_bytes="unused and deprecated", ): + if max_refill_bytes != "unused and deprecated": + warn_deprecated( + "max_refill_bytes=...", "0.12.0", issue=959, instead=None + ) self.transport_listener = transport_listener self._ssl_context = ssl_context self._https_compatible = https_compatible - self._max_refill_bytes = max_refill_bytes async def accept(self): """Accept the next connection and wrap it in an :class:`SSLStream`. @@ -870,7 +872,6 @@ async def accept(self): self._ssl_context, server_side=True, https_compatible=self._https_compatible, - max_refill_bytes=self._max_refill_bytes, ) async def aclose(self): diff --git a/trio/_subprocess.py b/trio/_subprocess.py index ae530e3ae8..15bb98a491 100644 --- a/trio/_subprocess.py +++ b/trio/_subprocess.py @@ -478,10 +478,7 @@ async def feed_input(): async def read_output(stream, chunks): async with stream: - while True: - chunk = await stream.receive_some(32768) - if not chunk: - break + async for chunk in stream: chunks.append(chunk) async with trio.open_nursery() as nursery: diff --git a/trio/_unix_pipes.py b/trio/_unix_pipes.py index 7cd205e93f..a9f0a98c3b 100644 --- a/trio/_unix_pipes.py +++ b/trio/_unix_pipes.py @@ -7,6 +7,10 @@ import trio +# XX TODO: is this a good number? who knows... it does match the default Linux +# pipe capacity though. +DEFAULT_RECEIVE_SIZE = 65536 + class _FdHolder: # This class holds onto a raw file descriptor, in non-blocking mode, and @@ -126,13 +130,15 @@ def __init__(self, fd: int): "another task is using this pipe" ) - async def receive_some(self, max_bytes: int) -> bytes: + async def receive_some(self, max_bytes=None) -> bytes: with self._conflict_detector: - if not isinstance(max_bytes, int): - raise TypeError("max_bytes must be integer >= 1") - - if max_bytes < 1: - raise ValueError("max_bytes must be integer >= 1") + if max_bytes is None: + max_bytes = DEFAULT_RECEIVE_SIZE + else: + if not isinstance(max_bytes, int): + raise TypeError("max_bytes must be integer >= 1") + if max_bytes < 1: + raise ValueError("max_bytes must be integer >= 1") await trio.hazmat.checkpoint() while True: diff --git a/trio/_windows_pipes.py b/trio/_windows_pipes.py index 82919b836b..e213e27664 100644 --- a/trio/_windows_pipes.py +++ b/trio/_windows_pipes.py @@ -3,6 +3,9 @@ from ._util import ConflictDetector from ._core._windows_cffi import _handle, raise_winerror, kernel32, ffi +# XX TODO: don't just make this up based on nothing. +DEFAULT_RECEIVE_SIZE = 65536 + # See the comments on _unix_pipes._FdHolder for discussion of why we set the # handle to -1 when it's closed. @@ -86,16 +89,18 @@ def __init__(self, handle: int) -> None: "another task is currently using this pipe" ) - async def receive_some(self, max_bytes: int) -> bytes: + async def receive_some(self, max_bytes=None) -> bytes: with self._conflict_detector: if self._handle_holder.closed: raise _core.ClosedResourceError("this pipe is already closed") - if not isinstance(max_bytes, int): - raise TypeError("max_bytes must be integer >= 1") - - if max_bytes < 1: - raise ValueError("max_bytes must be integer >= 1") + if max_bytes is None: + max_bytes = DEFAULT_RECEIVE_SIZE + else: + if not isinstance(max_bytes, int): + raise TypeError("max_bytes must be integer >= 1") + if max_bytes < 1: + raise ValueError("max_bytes must be integer >= 1") buffer = bytearray(max_bytes) try: diff --git a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py index 91c10798dc..2216692df4 100644 --- a/trio/testing/_check_streams.py +++ b/trio/testing/_check_streams.py @@ -67,9 +67,9 @@ async def do_send_all(data): with assert_checkpoints(): assert await s.send_all(data) is None - async def do_receive_some(max_bytes): + async def do_receive_some(*args): with assert_checkpoints(): - return await r.receive_some(1) + return await r.receive_some(*args) async def checked_receive_1(expected): assert await do_receive_some(1) == expected @@ -111,6 +111,13 @@ async def send_empty_then_y(): await r.receive_some(0) with _assert_raises(TypeError): await r.receive_some(1.5) + # it can also be missing or None + async with _core.open_nursery() as nursery: + nursery.start_soon(do_send_all, b"x") + assert await do_receive_some() == b"x" + async with _core.open_nursery() as nursery: + nursery.start_soon(do_send_all, b"x") + assert await do_receive_some(None) == b"x" with _assert_raises(_core.BusyResourceError): async with _core.open_nursery() as nursery: diff --git a/trio/testing/_memory_streams.py b/trio/testing/_memory_streams.py index 8457c82f7a..30c868de57 100644 --- a/trio/testing/_memory_streams.py +++ b/trio/testing/_memory_streams.py @@ -225,7 +225,7 @@ def __init__(self, receive_some_hook=None, close_hook=None): self.receive_some_hook = receive_some_hook self.close_hook = close_hook - async def receive_some(self, max_bytes): + async def receive_some(self, max_bytes=None): """Calls the :attr:`receive_some_hook` (if any), and then retrieves data from the internal buffer, blocking if necessary. @@ -235,8 +235,6 @@ async def receive_some(self, max_bytes): with self._conflict_detector: await _core.checkpoint() await _core.checkpoint() - if max_bytes is None: - raise TypeError("max_bytes must not be None") if self._closed: raise _core.ClosedResourceError if self.receive_some_hook is not None: @@ -382,9 +380,9 @@ def memory_stream_pair(): left, right = memory_stream_pair() await left.send_all(b"123") - assert await right.receive_some(10) == b"123" + assert await right.receive_some() == b"123" await right.send_all(b"456") - assert await left.receive_some(10) == b"456" + assert await left.receive_some() == b"456" But if you read the docs for :class:`~trio.StapledStream` and :func:`memory_stream_one_way_pair`, you'll see that all the pieces @@ -411,10 +409,7 @@ async def sender(): await left.send_eof() async def receiver(): - while True: - data = await right.receive_some(10) - if data == b"": - return + async for data in right: print(data) async with trio.open_nursery() as nursery: @@ -508,12 +503,13 @@ async def wait_send_all_might_not_block(self): if self._sender_closed: raise _core.ClosedResourceError - async def receive_some(self, max_bytes): + async def receive_some(self, max_bytes=None): with self._receive_conflict_detector: # Argument validation - max_bytes = operator.index(max_bytes) - if max_bytes < 1: - raise ValueError("max_bytes must be >= 1") + if max_bytes is not None: + max_bytes = operator.index(max_bytes) + if max_bytes < 1: + raise ValueError("max_bytes must be >= 1") # State validation if self._receiver_closed: raise _core.ClosedResourceError @@ -528,6 +524,8 @@ async def receive_some(self, max_bytes): raise _core.ClosedResourceError # Get data, possibly waking send_all if self._data: + # Neat trick: if max_bytes is None, then obj[:max_bytes] is + # the same as obj[:]. got = self._data[:max_bytes] del self._data[:max_bytes] self._something_happened() @@ -566,7 +564,7 @@ async def aclose(self): self.close() await _core.checkpoint() - async def receive_some(self, max_bytes): + async def receive_some(self, max_bytes=None): return await self._lbq.receive_some(max_bytes) diff --git a/trio/tests/test_highlevel_generic.py b/trio/tests/test_highlevel_generic.py index 8278b889cd..df2b2cecf7 100644 --- a/trio/tests/test_highlevel_generic.py +++ b/trio/tests/test_highlevel_generic.py @@ -24,7 +24,7 @@ async def aclose(self): class RecordReceiveStream(ReceiveStream): record = attr.ib(factory=list) - async def receive_some(self, max_bytes): + async def receive_some(self, max_bytes=None): self.record.append(("receive_some", max_bytes)) async def aclose(self): diff --git a/trio/tests/test_ssl.py b/trio/tests/test_ssl.py index 8178f396fd..6a51669648 100644 --- a/trio/tests/test_ssl.py +++ b/trio/tests/test_ssl.py @@ -222,8 +222,10 @@ async def send_all(self, data): await self.sleeper("send_all") print(" <-- transport_stream.send_all finished") - async def receive_some(self, nbytes): + async def receive_some(self, nbytes=None): print(" --> transport_stream.receive_some") + if nbytes is None: + nbytes = 65536 # arbitrary with self._receive_some_conflict_detector: try: await _core.checkpoint() @@ -1232,16 +1234,12 @@ async def setup(**kwargs): ################ - # Test https_compatible and max_refill_bytes - _, ssl_listener, ssl_client = await setup( - https_compatible=True, - max_refill_bytes=100, - ) + # Test https_compatible + _, ssl_listener, ssl_client = await setup(https_compatible=True,) ssl_server = await ssl_listener.accept() assert ssl_server._https_compatible - assert ssl_server._max_refill_bytes == 100 await aclose_forcefully(ssl_listener) await aclose_forcefully(ssl_client) diff --git a/trio/tests/test_subprocess.py b/trio/tests/test_subprocess.py index 8822982f04..fe55b85aba 100644 --- a/trio/tests/test_subprocess.py +++ b/trio/tests/test_subprocess.py @@ -120,11 +120,8 @@ async def feed_input(): async def check_output(stream, expected): seen = bytearray() - while True: - chunk = await stream.receive_some(4096) - if not chunk: - break - seen.extend(chunk) + async for chunk in stream: + seen += chunk assert seen == expected async with _core.open_nursery() as nursery: diff --git a/trio/tests/test_testing.py b/trio/tests/test_testing.py index d112e66d2e..e73624a67c 100644 --- a/trio/tests/test_testing.py +++ b/trio/tests/test_testing.py @@ -603,8 +603,8 @@ async def do_receive_some(max_bytes): mrs.put_data(b"abc") assert await do_receive_some(1) == b"a" assert await do_receive_some(10) == b"bc" - with pytest.raises(TypeError): - await do_receive_some(None) + mrs.put_data(b"abc") + assert await do_receive_some(None) == b"abc" with pytest.raises(_core.BusyResourceError): async with _core.open_nursery() as nursery: From aab5fe32499ce2f39d7e33a34a082d61bac4200c Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Tue, 25 Jun 2019 16:39:25 -0700 Subject: [PATCH 2/5] Attempt to fix 3.5 compat --- trio/_abc.py | 1 + trio/_ssl.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/trio/_abc.py b/trio/_abc.py index da381b2ed2..b9868e455a 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -413,6 +413,7 @@ async def receive_some(self, max_bytes=None): """ + @aiter_compat def __aiter__(self): return self diff --git a/trio/_ssl.py b/trio/_ssl.py index 16d9623243..d0e63d92f1 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -307,7 +307,7 @@ def __init__( server_hostname=None, server_side=False, https_compatible=False, - max_refill_bytes="unused and deprecated", + max_refill_bytes="unused and deprecated" ): self.transport_stream = transport_stream self._state = _State.OK @@ -850,7 +850,7 @@ def __init__( ssl_context, *, https_compatible=False, - max_refill_bytes="unused and deprecated", + max_refill_bytes="unused and deprecated" ): if max_refill_bytes != "unused and deprecated": warn_deprecated( From ccf637df5301367ee8cbdd8b4fcbcda75ed9c9e4 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Tue, 25 Jun 2019 17:03:06 -0700 Subject: [PATCH 3/5] Add deprecation test to get coverage up --- trio/tests/test_ssl.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/trio/tests/test_ssl.py b/trio/tests/test_ssl.py index 6a51669648..62081b0b09 100644 --- a/trio/tests/test_ssl.py +++ b/trio/tests/test_ssl.py @@ -1244,3 +1244,15 @@ async def setup(**kwargs): await aclose_forcefully(ssl_listener) await aclose_forcefully(ssl_client) await aclose_forcefully(ssl_server) + + +async def test_deprecated_max_refill_bytes(): + stream1, stream2 = memory_stream_pair() + with pytest.warns(trio.TrioDeprecationWarning): + SSLStream(stream1, CLIENT_CTX, max_refill_bytes=100) + with pytest.warns(trio.TrioDeprecationWarning): + # passing None is wrong here, but I'm too lazy to make a fake Listener + # and we get away with it for now. And this test will be deleted in a + # release or two anyway, so hopefully we'll keep getting away with it + # for long enough. + SSLListener(None, CLIENT_CTX, max_refill_bytes=100) From 1c169472c4c41b4ba2d0c66427bdcc8262d65f6d Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Fri, 5 Jul 2019 03:35:08 -0700 Subject: [PATCH 4/5] Respond to review feedback --- docs/source/tutorial/echo-server.py | 1 - newsfragments/959.feature.rst | 14 ++++++++------ trio/_abc.py | 4 +++- trio/tests/test_ssl.py | 2 +- 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/docs/source/tutorial/echo-server.py b/docs/source/tutorial/echo-server.py index 08cac7a815..29683cafea 100644 --- a/docs/source/tutorial/echo-server.py +++ b/docs/source/tutorial/echo-server.py @@ -21,7 +21,6 @@ async def echo_server(server_stream): print("echo_server {}: received data {!r}".format(ident, data)) await server_stream.send_all(data) print("echo_server {}: connection closed".format(ident)) - return # FIXME: add discussion of MultiErrors to the tutorial, and use # MultiError.catch here. (Not important in this case, but important if the # server code uses nurseries internally.) diff --git a/newsfragments/959.feature.rst b/newsfragments/959.feature.rst index 4e5bfc97cd..0e96ba35c4 100644 --- a/newsfragments/959.feature.rst +++ b/newsfragments/959.feature.rst @@ -1,8 +1,10 @@ If you have a `~trio.abc.ReceiveStream` object, you can now use ``async for data in stream: ...`` instead of calling -`~trio.abc.ReceiveStream.receive_some` repeatedly. And the best part -is, it automatically checks for EOF for you, so you don't have to. -Also, you no longer have to choose a magic buffer size value before -calling `~trio.abc.ReceiveStream.receive_some`; you can now call -``await stream.receive_some()`` and the stream will automatically pick -a reasonable value for you. +`~trio.abc.ReceiveStream.receive_some`. Each iteration gives an +arbitrary sized chunk of bytes. And the best part is, the loop +automatically exits when you reach EOF, so you don't have to check for +it yourself anymore. Relatedly, you no longer need to pick a magic +buffer size value before calling +`~trio.abc.ReceiveStream.receive_some`; you can ``await +stream.receive_some()`` with no arguments, and the stream will +automatically pick a reasonable size for you. diff --git a/trio/_abc.py b/trio/_abc.py index b9868e455a..b8f521e23b 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -379,7 +379,9 @@ class ReceiveStream(AsyncResource): :class:`ReceiveChannel`. `ReceiveStream` objects can be used in ``async for`` loops. Each iteration - will produce an arbitrary size + will produce an arbitrary sized chunk of bytes, like calling + `receive_some` with no arguments. Every chunk will contain at least one + byte, and the loop automatically exits when reaching end-of-file. """ __slots__ = () diff --git a/trio/tests/test_ssl.py b/trio/tests/test_ssl.py index 62081b0b09..78e09c35ac 100644 --- a/trio/tests/test_ssl.py +++ b/trio/tests/test_ssl.py @@ -1235,7 +1235,7 @@ async def setup(**kwargs): ################ # Test https_compatible - _, ssl_listener, ssl_client = await setup(https_compatible=True,) + _, ssl_listener, ssl_client = await setup(https_compatible=True) ssl_server = await ssl_listener.accept() From 754fd302259eb1e29eb2a0bb9e7e155069658d81 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Fri, 5 Jul 2019 03:35:17 -0700 Subject: [PATCH 5/5] Rework SSLStream's receive size handling --- trio/_ssl.py | 42 +++++++++++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/trio/_ssl.py b/trio/_ssl.py index d0e63d92f1..f576af8c0f 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -165,9 +165,30 @@ # SSLStream ################################################################ -# XX TODO: this number was pulled out of a hat. We should tune it with -# science. -DEFAULT_RECEIVE_SIZE = 65536 +# Ideally, when the user calls SSLStream.receive_some() with no argument, then +# we should do exactly one call to self.transport_stream.receive_some(), +# decrypt everything we got, and return it. Unfortunately, the way openssl's +# API works, we have to pick how much data we want to allow when we call +# read(), and then it (potentially) triggers a call to +# transport_stream.receive_some(). So at the time we pick the amount of data +# to decrypt, we don't know how much data we've read. As a simple heuristic, +# we record the max amount of data returned by previous calls to +# transport_stream.receive_some(), and we use that for future calls to read(). +# But what do we use for the very first call? That's what this constant sets. +# +# Note that the value passed to read() is a limit on the amount of +# *decrypted* data, but we can only see the size of the *encrypted* data +# returned by transport_stream.receive_some(). TLS adds a small amount of +# framing overhead, and TLS compression is rarely used these days because it's +# insecure. So the size of the encrypted data should be a slight over-estimate +# of the size of the decrypted data, which is exactly what we want. +# +# The specific value is not really based on anything; it might be worth tuning +# at some point. But, if you have an TCP connection with the typical 1500 byte +# MTU and an initial window of 10 (see RFC 6928), then the initial burst of +# data will be limited to ~15000 bytes (or a bit less due to IP-level framing +# overhead), so this is chosen to be larger than that. +STARTING_RECEIVE_SIZE = 16384 class NeedHandshakeError(Exception): @@ -342,6 +363,8 @@ def __init__( "another task is currently receiving data on this SSLStream" ) + self._estimated_receive_size = STARTING_RECEIVE_SIZE + _forwarded = { "context", "server_side", @@ -537,6 +560,9 @@ async def _retry(self, fn, *args, ignore_want_read=False): if not data: self._incoming.write_eof() else: + self._estimated_receive_size = max( + self._estimated_receive_size, len(data) + ) self._incoming.write(data) self._inner_recv_count += 1 if not yielded: @@ -617,10 +643,12 @@ async def receive_some(self, max_bytes=None): else: raise if max_bytes is None: - # Heuristic: normally we use DEFAULT_RECEIVE_SIZE, but if - # the transport gave us a bunch of data last time then we'll - # try to decrypt and pass it all back at once. - max_bytes = max(DEFAULT_RECEIVE_SIZE, self._incoming.pending) + # If we somehow have more data already in our pending buffer + # than the estimate receive size, bump up our size a bit for + # this read only. + max_bytes = max( + self._estimated_receive_size, self._incoming.pending + ) else: max_bytes = _operator.index(max_bytes) if max_bytes < 1: