Streams are iterable + receive_some doesn't require an explicit size #1123

Merged (6 commits) on Jul 30, 2019
2 changes: 1 addition & 1 deletion docs/source/reference-io.rst
@@ -98,7 +98,7 @@ Abstract base classes
* - :class:`ReceiveStream`
- :class:`AsyncResource`
- :meth:`~ReceiveStream.receive_some`
-
- ``__aiter__``, ``__anext__``
- :class:`~trio.testing.MemoryReceiveStream`
* - :class:`Stream`
- :class:`SendStream`, :class:`ReceiveStream`
39 changes: 21 additions & 18 deletions docs/source/tutorial.rst
@@ -908,12 +908,10 @@ And the second task's job is to process the data the server sends back:
:lineno-match:
:pyobject: receiver

It repeatedly calls ``await client_stream.receive_some(...)`` to get
more data from the server (again, all Trio streams provide this
method), and then checks to see if the server has closed the
connection. ``receive_some`` only returns an empty bytestring if the
connection has been closed; otherwise, it waits until some data has
arrived, up to a maximum of ``BUFSIZE`` bytes.
It uses an ``async for`` loop to fetch data from the server.
Alternatively, it could use `~trio.abc.ReceiveStream.receive_some`,
which is the opposite of `~trio.abc.SendStream.send_all`, but using
``async for`` saves some boilerplate.
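
For reference, here is a minimal sketch (not part of this diff) of the two
equivalent receiver loops, assuming ``client_stream`` is any ``ReceiveStream``::

    # With ``async for``, iteration ends automatically at end-of-file:
    async for data in client_stream:
        print("receiver: got data {!r}".format(data))

    # The equivalent manual loop:
    while True:
        data = await client_stream.receive_some()
        if not data:  # b"" means the connection was closed
            break
        print("receiver: got data {!r}".format(data))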

And now we're ready to look at the server.

@@ -974,11 +972,11 @@ functions we saw in the last section:

The argument ``server_stream`` is provided by :func:`serve_tcp`, and
is the other end of the connection we made in the client: so the data
that the client passes to ``send_all`` will come out of
``receive_some`` here, and vice-versa. Then we have a ``try`` block
discussed below, and finally the server loop which alternates between
reading some data from the socket and then sending it back out again
(unless the socket was closed, in which case we quit).
that the client passes to ``send_all`` will come out here. Then we
have a ``try`` block discussed below, and finally the server loop
which alternates between reading some data from the socket and then
sending it back out again (unless the socket was closed, in which case
we quit).

So what's that ``try`` block for? Remember that in Trio, like Python
in general, exceptions keep propagating until they're caught. Here we
@@ -1029,7 +1027,7 @@ our client could use a single task like::
while True:
data = ...
await client_stream.send_all(data)
received = await client_stream.receive_some(BUFSIZE)
received = await client_stream.receive_some()
if not received:
sys.exit()
await trio.sleep(1)
@@ -1046,18 +1044,23 @@ line, any time we're expecting more than one byte of data, we have to
be prepared to call ``receive_some`` multiple times.

And where this would go especially wrong is if we find ourselves in
the situation where ``len(data) > BUFSIZE``. On each pass through the
loop, we send ``len(data)`` bytes, but only read *at most* ``BUFSIZE``
bytes. The result is something like a memory leak: we'll end up with
more and more data backed up in the network, until eventually
something breaks.
the situation where ``data`` is big enough that it passes some
internal threshold, and the operating system or network decide to
always break it up into multiple pieces. Now on each pass through the
loop, we send ``len(data)`` bytes, but read less than that. The result
is something like a memory leak: we'll end up with more and more data
backed up in the network, until eventually something breaks.

.. note:: If you're curious *how* things break, then you can use
`~trio.abc.ReceiveStream.receive_some`\'s optional argument to put
a limit on how many bytes you read each time, and see what happens.

We could fix this by keeping track of how much data we're expecting at
each moment, and then keep calling ``receive_some`` until we get it all::

expected = len(data)
while expected > 0:
received = await client_stream.receive_some(BUFSIZE)
received = await client_stream.receive_some(expected)
if not received:
sys.exit(1)
expected -= len(received)
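
As a side note (not in the tutorial itself), a real client would usually
also want to keep the data it reads, along these lines::

    chunks = []
    expected = len(data)
    while expected > 0:
        received = await client_stream.receive_some(expected)
        if not received:
            sys.exit(1)
        chunks.append(received)
        expected -= len(received)
    reply = b"".join(chunks)  # the complete echoed payload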
11 changes: 3 additions & 8 deletions docs/source/tutorial/echo-client.py
@@ -8,9 +8,6 @@
# - can't be in use by some other program on your computer
# - must match what we set in our echo server
PORT = 12345
# How much memory to spend (at most) on each call to recv. Pretty arbitrary,
# but shouldn't be too big or too small.
BUFSIZE = 16384

async def sender(client_stream):
print("sender: started!")
@@ -22,12 +19,10 @@ async def sender(client_stream):

async def receiver(client_stream):
print("receiver: started!")
while True:
data = await client_stream.receive_some(BUFSIZE)
async for data in client_stream:
print("receiver: got data {!r}".format(data))
if not data:
print("receiver: connection closed")
sys.exit()
print("receiver: connection closed")
sys.exit()

async def parent():
print("parent: connecting to 127.0.0.1:{}".format(PORT))
12 changes: 3 additions & 9 deletions docs/source/tutorial/echo-server.py
@@ -8,9 +8,6 @@
# - can't be in use by some other program on your computer
# - must match what we set in our echo client
PORT = 12345
# How much memory to spend (at most) on each call to recv. Pretty arbitrary,
# but shouldn't be too big or too small.
BUFSIZE = 16384

CONNECTION_COUNTER = count()

@@ -20,14 +17,11 @@ async def echo_server(server_stream):
ident = next(CONNECTION_COUNTER)
print("echo_server {}: started".format(ident))
try:
while True:
data = await server_stream.receive_some(BUFSIZE)
async for data in server_stream:
print("echo_server {}: received data {!r}".format(ident, data))
if not data:
print("echo_server {}: connection closed".format(ident))
return
print("echo_server {}: sending data {!r}".format(ident, data))
await server_stream.send_all(data)
print("echo_server {}: connection closed".format(ident))
return
# FIXME: add discussion of MultiErrors to the tutorial, and use
# MultiError.catch here. (Not important in this case, but important if the
# server code uses nurseries internally.)
8 changes: 8 additions & 0 deletions newsfragments/959.feature.rst
@@ -0,0 +1,8 @@
If you have a `~trio.abc.ReceiveStream` object, you can now use
``async for data in stream: ...`` instead of calling
`~trio.abc.ReceiveStream.receive_some` repeatedly. And the best part
is, it automatically checks for EOF for you, so you don't have to.
Also, you no longer have to choose a magic buffer size value before
calling `~trio.abc.ReceiveStream.receive_some`; you can now call
``await stream.receive_some()`` and the stream will automatically pick
a reasonable value for you.
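
For example, here's a minimal sketch, where ``stream`` stands for any
connected `~trio.abc.ReceiveStream`::

    async for chunk in stream:
        print("received {} bytes".format(len(chunk)))
    # the loop exits on its own once the stream reaches end-of-file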
2 changes: 1 addition & 1 deletion notes-to-self/graceful-shutdown-idea.py
@@ -30,7 +30,7 @@ def shutting_down(self):
async def stream_handler(stream):
while True:
with gsm.cancel_on_graceful_shutdown():
data = await stream.receive_some(...)
data = await stream.receive_some()
if gsm.shutting_down:
break

24 changes: 17 additions & 7 deletions trio/_abc.py
@@ -378,26 +378,26 @@ class ReceiveStream(AsyncResource):
If you want to receive Python objects rather than raw bytes, see
:class:`ReceiveChannel`.

`ReceiveStream` objects can be used in ``async for`` loops. Each iteration
will produce an arbitrary sized chunk of bytes, like calling
``receive_some`` with no arguments, and the loop exits automatically once
the stream reaches end-of-file.

"""
__slots__ = ()

@abstractmethod
async def receive_some(self, max_bytes):
async def receive_some(self, max_bytes=None):
"""Wait until there is data available on this stream, and then return
at most ``max_bytes`` of it.
some of it.

A return value of ``b""`` (an empty bytestring) indicates that the
stream has reached end-of-file. Implementations should be careful that
they return ``b""`` if, and only if, the stream has reached
end-of-file!

This method will return as soon as any data is available, so it may
return fewer than ``max_bytes`` of data. But it will never return
more.

Args:
max_bytes (int): The maximum number of bytes to return. Must be
greater than zero.
greater than zero. Optional; if omitted, then the stream object
is free to pick a reasonable default.

Returns:
bytes or bytearray: The data received.
@@ -413,6 +413,16 @@ async def receive_some(self, max_bytes):

"""

@aiter_compat
def __aiter__(self):
return self

async def __anext__(self):
data = await self.receive_some()
if not data:
raise StopAsyncIteration
return data


class Stream(SendStream, ReceiveStream):
"""A standard interface for interacting with bidirectional byte streams.
4 changes: 2 additions & 2 deletions trio/_highlevel_generic.py
@@ -52,7 +52,7 @@ class StapledStream(HalfCloseableStream):
left, right = trio.testing.memory_stream_pair()
echo_stream = StapledStream(SocketStream(left), SocketStream(right))
await echo_stream.send_all(b"x")
assert await echo_stream.receive_some(1) == b"x"
assert await echo_stream.receive_some() == b"x"

:class:`StapledStream` objects implement the methods in the
:class:`~trio.abc.HalfCloseableStream` interface. They also have two
@@ -96,7 +96,7 @@ async def send_eof(self):
else:
return await self.send_stream.aclose()

async def receive_some(self, max_bytes):
async def receive_some(self, max_bytes=None):
"""Calls ``self.receive_stream.receive_some``.

"""
10 changes: 9 additions & 1 deletion trio/_highlevel_socket.py
@@ -10,6 +10,12 @@

__all__ = ["SocketStream", "SocketListener"]

# XX TODO: this number was picked arbitrarily. We should do experiments to
# tune it. (Or make it dynamic -- one idea is to start small and increase it
# if we observe single reads filling up the whole buffer, at least within some
# limits.)
DEFAULT_RECEIVE_SIZE = 65536

Review comment (Member):

One wrinkle: AFAIK, each call to socket.recv() allocates a new bytes object that is large enough for the entire given chunksize. If large allocations are more expensive, passing a too-large buffer is probably bad for performance. (The allocators I know of use 128KB as their threshold for "this is big, mmap it instead of finding a free chunk" but if one used 64KB instead and we got a mmap/munmap pair on each receive, that feels maybe bad?)

My intuition favors a much lower buffer size, like 4KB or 8KB, but I also do most of my work on systems that are rarely backlogged, so my intuition might well be off when it comes to a high-throughput Trio application.

Another option we could consider: the socket owns a receive buffer (bytearray) which it reuses, calls recv_into(), and extracts just the amount actually received into a bytes for returning. Downside: spends 64KB (or whatever) per socket in steady state. Counterpoint: the OS-level socket buffers are probably much larger than that (but I don't know how much memory they occupy when the socket isn't backlogged).

Reply (Member Author):

This is an interesting discussion but I don't want it to hold up merging the basic functionality, so I split it off into #1139.

(Twisted has apparently used 64 KiB receive buffers for its entire existence and I can't find any evidence that anyone has ever thought twice about it. So we're probably not risking any disaster by starting with 64 KiB for now :-).)
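
To illustrate the buffer-reuse idea from the comment above, here's a
hypothetical sketch (not part of this PR; the class name and default size
are invented)::

    import trio

    class BufferReusingReceiver:
        """Hypothetical wrapper around a trio socket using recv_into()."""

        def __init__(self, sock, bufsize=65536):
            self._sock = sock               # a trio.socket socket
            self._buf = bytearray(bufsize)  # reused across calls

        async def receive_some(self):
            # recv_into() fills the preallocated buffer, avoiding a fresh
            # bufsize-large allocation on every call...
            nbytes = await self._sock.recv_into(self._buf)
            # ...then we copy out only the bytes actually received.
            return bytes(memoryview(self._buf)[:nbytes])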

_closed_stream_errnos = {
# Unix
errno.EBADF,
@@ -129,7 +135,9 @@ async def send_eof(self):
with _translate_socket_errors_to_stream_errors():
self.socket.shutdown(tsocket.SHUT_WR)

async def receive_some(self, max_bytes):
async def receive_some(self, max_bytes=None):
if max_bytes is None:
max_bytes = DEFAULT_RECEIVE_SIZE
if max_bytes < 1:
raise ValueError("max_bytes must be >= 1")
with _translate_socket_errors_to_stream_errors():
51 changes: 26 additions & 25 deletions trio/_ssl.py
@@ -159,11 +159,16 @@
from ._highlevel_generic import aclose_forcefully
from . import _sync
from ._util import ConflictDetector
from ._deprecate import warn_deprecated

################################################################
# SSLStream
################################################################

# XX TODO: this number was pulled out of a hat. We should tune it with
# science.
DEFAULT_RECEIVE_SIZE = 65536


class NeedHandshakeError(Exception):
"""Some :class:`SSLStream` methods can't return any meaningful data until
@@ -197,8 +202,6 @@ def done(self):

_State = _Enum("_State", ["OK", "BROKEN", "CLOSED"])

_default_max_refill_bytes = 32 * 1024


class SSLStream(Stream):
r"""Encrypted communication using SSL/TLS.
@@ -269,15 +272,6 @@ class SSLStream(Stream):
that :class:`~ssl.SSLSocket` implements the
``https_compatible=True`` behavior by default.

max_refill_bytes (int): :class:`~ssl.SSLSocket` maintains an internal
buffer of incoming data, and when it runs low then it calls
:meth:`receive_some` on the underlying transport stream to refill
it. This argument lets you set the ``max_bytes`` argument passed to
the *underlying* :meth:`receive_some` call. It doesn't affect calls
to *this* class's :meth:`receive_some`, or really anything else
user-observable except possibly performance. You probably don't need
to worry about this.

Attributes:
transport_stream (trio.abc.Stream): The underlying transport stream
that was passed to ``__init__``. An example of when this would be
@@ -313,11 +307,14 @@ def __init__(
server_hostname=None,
server_side=False,
https_compatible=False,
max_refill_bytes=_default_max_refill_bytes
max_refill_bytes="unused and deprecated"
):
self.transport_stream = transport_stream
self._state = _State.OK
self._max_refill_bytes = max_refill_bytes
if max_refill_bytes != "unused and deprecated":
warn_deprecated(
"max_refill_bytes=...", "0.12.0", issue=959, instead=None
)
self._https_compatible = https_compatible
self._outgoing = _stdlib_ssl.MemoryBIO()
self._incoming = _stdlib_ssl.MemoryBIO()
@@ -536,9 +533,7 @@ async def _retry(self, fn, *args, ignore_want_read=False):
async with self._inner_recv_lock:
yielded = True
if recv_count == self._inner_recv_count:
data = await self.transport_stream.receive_some(
self._max_refill_bytes
)
data = await self.transport_stream.receive_some()
if not data:
self._incoming.write_eof()
else:
@@ -590,7 +585,7 @@ async def do_handshake(self):
# https://bugs.python.org/issue30141
# So we *definitely* have to make sure that do_handshake is called
# before doing anything else.
async def receive_some(self, max_bytes):
async def receive_some(self, max_bytes=None):
"""Read some data from the underlying transport, decrypt it, and
return it.

@@ -621,9 +616,15 @@ async def receive_some(self, max_bytes):
return b""
else:
raise
max_bytes = _operator.index(max_bytes)
if max_bytes < 1:
raise ValueError("max_bytes must be >= 1")
if max_bytes is None:
# Heuristic: normally we use DEFAULT_RECEIVE_SIZE, but if
# the transport gave us a bunch of data last time then we'll
# try to decrypt and pass it all back at once.
max_bytes = max(DEFAULT_RECEIVE_SIZE, self._incoming.pending)
Review comment (Member):
I'm a little confused at what the benefit is of having a DEFAULT_RECEIVE_SIZE for SSLStream at all. It seems like we could instead have a nice magic-number-free policy of "ask the transport stream to receive_some() with no size specified, then return all the data we decrypted from whatever we got in that chunk, or loop and receive_some() again if we didn't get any decrypted data".

Reply (Member Author):
Yeah, this is complicated... what you say makes logical sense, but, openssl's API is super awkward. There isn't any way to say "please decrypt all the data in your receive buffer". You have to pick a value to pass to SSLObject.read. And even more annoying: you don't find out until after you've picked a value whether you have to go back to the underlying transport for more data. So you have to pick the value before you know how much data the underlying transport wants to give you. And once you've picked a value, you have to keep using that value until some data is returned.

So my logic was: well, if we already have a bunch of data in the receive buffer because the underlying transport was generous, then likely we can just decrypt and return that, and the size of the encrypted data is a plausible upper bound on the size of the decrypted data, so self._incoming.pending is a good value to pass to SSLObject.read.

But, sometimes there won't be a lot of data in the receive buffer – for example, because our heuristic worked well the previous time, and cleared everything out, or almost everything. Like, imagine there's 1 byte left in the receive buffer. The way TLS works, you generally can't decrypt just 1 byte – everything's transmitted in frames, and you need to get the whole frame with its header and MAC and everything before you can decrypt any of it. So if we call ssl_object.read(1), then openssl will end up requesting another large chunk of data from the underlying transport, then our read(1) call will decrypt the first byte and return it, leaving the rest of the data sitting in the buffer for next time. And that would be unfortunate.

So my first attempt at a heuristic is: use the receive buffer size, but never anything smaller than DEFAULT_RECEIVE_SIZE.

I guess this has a weird effect if the underlying transport likes to return more than DEFAULT_RECEIVE_SIZE. Say it gives us 65 KiB, while DEFAULT_RECEIVE_SIZE is 64 KiB. On our first call to SSLStream.receive_some, the buffer size is zero, so we do read(64 KiB). This drains 65 KiB from the underlying transport, then decrypts and returns the first 64 KiB. The next time we call SSLStream.receive_some, we do read(64 KiB) again, but there's already 1 KiB of data in the buffer, so we just return that immediately without refilling the buffer. Then this repeats indefinitely, so we alternate between doing a big receive and a small receive every time. Seems wasteful – it'd be better to return 65 KiB each time.

So maybe a better strategy would be to start with some smallish default receive size, and then increase it over time if we observe the underlying transport giving us more data.

Reply (Member Author):
I rewrote the SSLStream stuff to hopefully address the above issues...
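
To make the "start small and grow" idea from this thread concrete, here's a
hypothetical sketch (invented names and thresholds, not the code this PR
ended up with)::

    class AdaptiveReceiveSize:
        def __init__(self, initial=16384, maximum=262144):
            self.size = initial
            self.maximum = maximum

        def update(self, nbytes_received):
            # A read that filled the whole buffer suggests the transport
            # had even more data ready, so try a bigger buffer next time,
            # up to a cap.
            if nbytes_received == self.size:
                self.size = min(2 * self.size, self.maximum)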

else:
max_bytes = _operator.index(max_bytes)
if max_bytes < 1:
raise ValueError("max_bytes must be >= 1")
try:
return await self._retry(self._ssl_object.read, max_bytes)
except trio.BrokenResourceError as exc:
@@ -837,8 +838,6 @@ class SSLListener(Listener[SSLStream]):

https_compatible (bool): Passed on to :class:`SSLStream`.

max_refill_bytes (int): Passed on to :class:`SSLStream`.

Attributes:
transport_listener (trio.abc.Listener): The underlying listener that was
passed to ``__init__``.
@@ -851,12 +850,15 @@ def __init__(
ssl_context,
*,
https_compatible=False,
max_refill_bytes=_default_max_refill_bytes
max_refill_bytes="unused and deprecated"
):
if max_refill_bytes != "unused and deprecated":
warn_deprecated(
"max_refill_bytes=...", "0.12.0", issue=959, instead=None
)
self.transport_listener = transport_listener
self._ssl_context = ssl_context
self._https_compatible = https_compatible
self._max_refill_bytes = max_refill_bytes

async def accept(self):
"""Accept the next connection and wrap it in an :class:`SSLStream`.
@@ -870,7 +872,6 @@ async def accept(self):
self._ssl_context,
server_side=True,
https_compatible=self._https_compatible,
max_refill_bytes=self._max_refill_bytes,
)

async def aclose(self):
5 changes: 1 addition & 4 deletions trio/_subprocess.py
@@ -478,10 +478,7 @@ async def feed_input():

async def read_output(stream, chunks):
async with stream:
while True:
chunk = await stream.receive_some(32768)
if not chunk:
break
async for chunk in stream:
chunks.append(chunk)

async with trio.open_nursery() as nursery: