diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index 8ebd8ed5c8d..a90a9dfd1ce 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -14,7 +14,7 @@ CONTENT_TRANSFER_ENCODING, CONTENT_TYPE) from .helpers import CHAR, TOKEN, parse_mimetype, reify from .http import HttpParser -from .payload import (JsonPayload, LookupError, Payload, StringPayload, +from .payload import (JsonPayload, LookupError, Order, Payload, StringPayload, get_payload, payload_type) @@ -434,7 +434,7 @@ def filename(self): return content_disposition_filename(params, 'filename') -@payload_type(BodyPartReader) +@payload_type(BodyPartReader, order=Order.try_first) class BodyPartReaderPayload(Payload): def __init__(self, value, *args, **kwargs): diff --git a/aiohttp/payload.py b/aiohttp/payload.py index a43e6379864..e3161961dd0 100644 --- a/aiohttp/payload.py +++ b/aiohttp/payload.py @@ -1,9 +1,12 @@ +import enum import io import json import mimetypes import os import warnings from abc import ABC, abstractmethod +from collections.abc import AsyncIterable +from itertools import chain from multidict import CIMultiDict @@ -16,30 +19,38 @@ __all__ = ('PAYLOAD_REGISTRY', 'get_payload', 'payload_type', 'Payload', 'BytesPayload', 'StringPayload', 'IOBasePayload', 'BytesIOPayload', 'BufferedReaderPayload', - 'TextIOPayload', 'StringIOPayload', 'JsonPayload') + 'TextIOPayload', 'StringIOPayload', 'JsonPayload', + 'AsyncIterablePayload') -TOO_LARGE_BYTES_BODY = 2 ** 20 +TOO_LARGE_BYTES_BODY = 2 ** 20 # 1 MB class LookupError(Exception): pass +class Order(enum.Enum): + normal = 'normal' + try_first = 'try_first' + try_last = 'try_last' + + def get_payload(data, *args, **kwargs): return PAYLOAD_REGISTRY.get(data, *args, **kwargs) -def register_payload(factory, type): - PAYLOAD_REGISTRY.register(factory, type) +def register_payload(factory, type, *, order=Order.normal): + PAYLOAD_REGISTRY.register(factory, type, order=order) class payload_type: - def __init__(self, type): + def __init__(self, type, *, order=Order.normal): self.type = type + self.order = order def __call__(self, factory): - register_payload(factory, self.type) + register_payload(factory, self.type, order=self.order) return factory @@ -50,19 +61,28 @@ class PayloadRegistry: """ def __init__(self): - self._registry = [] + self._first = [] + self._normal = [] + self._last = [] - def get(self, data, *args, **kwargs): + def get(self, data, *args, _CHAIN=chain, **kwargs): if isinstance(data, Payload): return data - for factory, type in self._registry: + for factory, type in _CHAIN(self._first, self._normal, self._last): if isinstance(data, type): return factory(data, *args, **kwargs) raise LookupError() - def register(self, factory, type): - self._registry.append((factory, type)) + def register(self, factory, type, *, order=Order.normal): + if order is Order.try_first: + self._first.append((factory, type)) + elif order is Order.normal: + self._normal.append((factory, type)) + elif order is Order.try_last: + self._last.append((factory, type)) + else: + raise ValueError("Unsupported order {!r}".format(order)) class Payload(ABC): @@ -136,8 +156,9 @@ async def write(self, writer): class BytesPayload(Payload): def __init__(self, value, *args, **kwargs): - assert isinstance(value, (bytes, bytearray, memoryview)), \ - "value argument must be byte-ish (%r)" % type(value) + if not isinstance(value, (bytes, bytearray, memoryview)): + raise TypeError("value argument must be byte-ish, not (!r)" + .format(type(value))) if 'content_type' not in kwargs: kwargs['content_type'] = 'application/octet-stream' @@ -278,6 +299,32 @@ def __init__(self, value, content_type=content_type, encoding=encoding, *args, **kwargs) +class AsyncIterablePayload(Payload): + + def __init__(self, value, *args, **kwargs): + if not isinstance(value, AsyncIterable): + raise TypeError("value argument must support " + "collections.abc.AsyncIterablebe interface, " + "got {!r}".format(type(value))) + + if 'content_type' not in kwargs: + kwargs['content_type'] = 'application/octet-stream' + + super().__init__(value, *args, **kwargs) + + self._iter = value.__aiter__() + + async def write(self, writer): + try: + # iter is not None check prevents rare cases + # when the case iterable is used twice + while True: + chunk = await self._iter.__anext__() + await writer.write(chunk) + except StopAsyncIteration: + self._iter = None + + PAYLOAD_REGISTRY = PayloadRegistry() PAYLOAD_REGISTRY.register(BytesPayload, (bytes, bytearray, memoryview)) PAYLOAD_REGISTRY.register(StringPayload, str) @@ -287,3 +334,7 @@ def __init__(self, value, PAYLOAD_REGISTRY.register( BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom)) PAYLOAD_REGISTRY.register(IOBasePayload, io.IOBase) +# try_last for giving a chance to more specialized async interables like +# multidict.BodyPartReaderPayload override the default +PAYLOAD_REGISTRY.register(AsyncIterablePayload, AsyncIterable, + order=Order.try_last) diff --git a/aiohttp/payload_streamer.py b/aiohttp/payload_streamer.py index 3b45888aeaa..c400aa013c8 100644 --- a/aiohttp/payload_streamer.py +++ b/aiohttp/payload_streamer.py @@ -22,6 +22,7 @@ async def file_sender(writer, file_name=None): """ import asyncio +import warnings from .payload import Payload, payload_type @@ -43,6 +44,9 @@ async def __call__(self, writer): class streamer: def __init__(self, coro): + warnings.warn("@streamer is deprecated, use async generators instead", + DeprecationWarning, + stacklevel=2) self.coro = coro def __call__(self, *args, **kwargs): diff --git a/docs/client_quickstart.rst b/docs/client_quickstart.rst index bd4e1c5967a..4a28f6f8c06 100644 --- a/docs/client_quickstart.rst +++ b/docs/client_quickstart.rst @@ -1,7 +1,8 @@ .. _aiohttp-client-quickstart: -Client Quickstart -================= +=================== + Client Quickstart +=================== .. currentmodule:: aiohttp @@ -16,7 +17,7 @@ Let's get started with some simple examples. Make a Request --------------- +============== Begin by importing the aiohttp module:: @@ -62,7 +63,7 @@ Other HTTP methods are available as well:: Passing Parameters In URLs --------------------------- +========================== You often want to send some sort of data in the URL's query string. If you were constructing the URL by hand, this data would be given as key/value @@ -123,7 +124,7 @@ is not encoded by library. Note that ``+`` is not encoded:: Passing *params* overrides ``encoded=True``, never use both options. Response Content and Status Code --------------------------------- +================================ We can read the content of the server's response and it's status code. Consider the GitHub time-line again:: @@ -144,7 +145,7 @@ specify custom encoding for the :meth:`~ClientResponse.text` method:: Binary Response Content ------------------------ +======================= You can also access the response body as bytes, for non-text requests:: @@ -161,7 +162,7 @@ You can enable ``brotli`` transfer-encodings support, just install `brotlipy `_. JSON Request ------------- +============ Any of session's request methods like :func:`request`, :meth:`ClientSession.get`, :meth:`ClientSesssion.post` etc. accept @@ -188,7 +189,7 @@ parameter:: incompatible. JSON Response Content ---------------------- +===================== There's also a built-in JSON decoder, in case you're dealing with JSON data:: @@ -207,7 +208,7 @@ decoder functions for the :meth:`~ClientResponse.json` call. Streaming Response Content --------------------------- +========================== While methods :meth:`~ClientResponse.read`, :meth:`~ClientResponse.json` and :meth:`~ClientResponse.text` are very @@ -237,7 +238,7 @@ It is not possible to use :meth:`~ClientResponse.read`, explicit reading from :attr:`~ClientResponse.content`. More complicated POST requests ------------------------------- +============================== Typically, you want to send some form-encoded data -- much like an HTML form. To do this, simply pass a dictionary to the *data* argument. Your @@ -278,7 +279,7 @@ To send text with appropriate content-type just use ``text`` attribute :: ... POST a Multipart-Encoded File ------------------------------ +============================= To upload Multipart-encoded files:: @@ -306,7 +307,7 @@ for supported format information. Streaming uploads ------------------ +================= :mod:`aiohttp` supports multiple types of streaming uploads, which allows you to send large files without reading them into memory. @@ -317,15 +318,14 @@ As a simple case, simply provide a file-like object for your body:: await session.post('http://httpbin.org/post', data=f) -Or you can use :class:`aiohttp.streamer` decorator:: +Or you can use *asynchronous generator*:: - @aiohttp.streamer - async def file_sender(writer, file_name=None): - with open(file_name, 'rb') as f: - chunk = f.read(2**16) + async def file_sender(file_name=None): + async with aiofiles.open(file_name, 'rb') as f: + chunk = await f.read(64*1024) while chunk: - await writer.write(chunk) - chunk = f.read(2**16) + yield chunk + chunk = await f.read(64*1024) # Then you can use file_sender as a data provider: @@ -333,12 +333,23 @@ Or you can use :class:`aiohttp.streamer` decorator:: data=file_sender(file_name='huge_file')) as resp: print(await resp.text()) +.. note:: + + Python 3.5 has no support for asynchronous generators, use + ``async_generator`` library as workaround. + +.. deprecated:: 3.1 + + ``aiohttp`` still supports ``aiohttp.streamer`` decorator but this + approach is deprecated in favor of *asynchronous generators* as + shown above. + .. _aiohttp-client-websockets: WebSockets ----------- +========== :mod:`aiohttp` works with client websockets out-of-the-box. @@ -372,7 +383,7 @@ multiple writer tasks which can only send data asynchronously (by Timeouts --------- +======== By default all IO operations have 5min timeout. The timeout may be overridden by passing ``timeout`` parameter into diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 92e7eeb98f5..45ce69da7e9 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -331,7 +331,8 @@ The client session supports the context manager protocol for self closing. .. versionadded:: 2.3 :param trace_request_ctx: Object used to give as a kw param for each new - :class:`TraceConfig` object instantiated, used to give information to the + :class:`TraceConfig` object instantiated, + used to give information to the tracers that is only available at request time. .. versionadded:: 3.0 diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 4faea7514fb..1124e55e3f6 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -10,6 +10,7 @@ from unittest import mock import pytest +from async_generator import async_generator, yield_ from multidict import MultiDict import aiohttp @@ -1483,13 +1484,14 @@ async def handler(request): with fname.open('rb') as f: data_size = len(f.read()) - @aiohttp.streamer - async def stream(writer, fname): - with fname.open('rb') as f: - data = f.read(100) - while data: - await writer.write(data) + with pytest.warns(DeprecationWarning): + @aiohttp.streamer + async def stream(writer, fname): + with fname.open('rb') as f: data = f.read(100) + while data: + await writer.write(data) + data = f.read(100) resp = await client.post( '/', data=stream(fname), headers={'Content-Length': str(data_size)}) @@ -1516,13 +1518,14 @@ async def handler(request): with fname.open('rb') as f: data_size = len(f.read()) - @aiohttp.streamer - async def stream(writer): - with fname.open('rb') as f: - data = f.read(100) - while data: - await writer.write(data) + with pytest.warns(DeprecationWarning): + @aiohttp.streamer + async def stream(writer): + with fname.open('rb') as f: data = f.read(100) + while data: + await writer.write(data) + data = f.read(100) resp = await client.post( '/', data=stream, headers={'Content-Length': str(data_size)}) @@ -2523,3 +2526,24 @@ async def canceller(): fut2.cancel() await asyncio.gather(fetch1(), fetch2(), canceller()) + + +async def test_async_payload_generator(aiohttp_client): + + async def handler(request): + data = await request.read() + assert data == b'1234567890' * 100 + return web.Response() + + app = web.Application() + app.add_routes([web.post('/', handler)]) + + client = await aiohttp_client(app) + + @async_generator + async def gen(): + for i in range(100): + await yield_(b'1234567890') + + resp = await client.post('/', data=gen()) + assert resp.status == 200 diff --git a/tests/test_client_request.py b/tests/test_client_request.py index f54ea0150cc..b862434da4d 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -10,6 +10,7 @@ from unittest import mock import pytest +from async_generator import async_generator, yield_ from multidict import CIMultiDict, CIMultiDictProxy, istr from yarl import URL @@ -834,10 +835,31 @@ async def test_expect_100_continue_header(loop, conn): async def test_data_stream(loop, buf, conn): - @aiohttp.streamer - async def gen(writer): - await writer.write(b'binary data') - await writer.write(b' result') + @async_generator + async def gen(): + await yield_(b'binary data') + await yield_(b' result') + + req = ClientRequest( + 'POST', URL('http://python.org/'), data=gen(), loop=loop) + assert req.chunked + assert req.headers['TRANSFER-ENCODING'] == 'chunked' + + resp = await req.send(conn) + assert asyncio.isfuture(req._writer) + await resp.wait_for_close() + assert req._writer is None + assert buf.split(b'\r\n\r\n', 1)[1] == \ + b'b\r\nbinary data\r\n7\r\n result\r\n0\r\n\r\n' + await req.close() + + +async def test_data_stream_deprecated(loop, buf, conn): + with pytest.warns(DeprecationWarning): + @aiohttp.streamer + async def gen(writer): + await writer.write(b'binary data') + await writer.write(b' result') req = ClientRequest( 'POST', URL('http://python.org/'), data=gen(), loop=loop) @@ -874,9 +896,9 @@ async def test_data_file(loop, buf, conn): async def test_data_stream_exc(loop, conn): fut = loop.create_future() - @aiohttp.streamer - async def gen(writer): - await writer.write(b'binary data') + @async_generator + async def gen(): + await yield_(b'binary data') await fut req = ClientRequest( @@ -897,11 +919,38 @@ async def throw_exc(): await req.close() +async def test_data_stream_exc_deprecated(loop, conn): + fut = loop.create_future() + + with pytest.warns(DeprecationWarning): + @aiohttp.streamer + async def gen(writer): + await writer.write(b'binary data') + await fut + + req = ClientRequest( + 'POST', URL('http://python.org/'), data=gen(), loop=loop) + assert req.chunked + assert req.headers['TRANSFER-ENCODING'] == 'chunked' + + async def throw_exc(): + await asyncio.sleep(0.01, loop=loop) + fut.set_exception(ValueError) + + loop.create_task(throw_exc()) + + await req.send(conn) + await req._writer + # assert conn.close.called + assert conn.protocol.set_exception.called + await req.close() + + async def test_data_stream_exc_chain(loop, conn): fut = loop.create_future() - @aiohttp.streamer - async def gen(writer): + @async_generator + async def gen(): await fut req = ClientRequest('POST', URL('http://python.org/'), @@ -926,12 +975,68 @@ async def throw_exc(): await req.close() +async def test_data_stream_exc_chain_deprecated(loop, conn): + fut = loop.create_future() + + with pytest.warns(DeprecationWarning): + @aiohttp.streamer + async def gen(writer): + await fut + + req = ClientRequest('POST', URL('http://python.org/'), + data=gen(), loop=loop) + + inner_exc = ValueError() + + async def throw_exc(): + await asyncio.sleep(0.01, loop=loop) + fut.set_exception(inner_exc) + + loop.create_task(throw_exc()) + + await req.send(conn) + await req._writer + # assert connection.close.called + assert conn.protocol.set_exception.called + outer_exc = conn.protocol.set_exception.call_args[0][0] + assert isinstance(outer_exc, ValueError) + assert inner_exc is outer_exc + assert inner_exc is outer_exc + await req.close() + + async def test_data_stream_continue(loop, buf, conn): - @aiohttp.streamer - async def gen(writer): - await writer.write(b'binary data') - await writer.write(b' result') - await writer.write_eof() + @async_generator + async def gen(): + await yield_(b'binary data') + await yield_(b' result') + + req = ClientRequest( + 'POST', URL('http://python.org/'), data=gen(), + expect100=True, loop=loop) + assert req.chunked + + async def coro(): + await asyncio.sleep(0.0001, loop=loop) + req._continue.set_result(1) + + loop.create_task(coro()) + + resp = await req.send(conn) + await req._writer + assert buf.split(b'\r\n\r\n', 1)[1] == \ + b'b\r\nbinary data\r\n7\r\n result\r\n0\r\n\r\n' + await req.close() + resp.close() + + +async def test_data_stream_continue_deprecated(loop, buf, conn): + with pytest.warns(DeprecationWarning): + @aiohttp.streamer + async def gen(writer): + await writer.write(b'binary data') + await writer.write(b' result') + await writer.write_eof() req = ClientRequest( 'POST', URL('http://python.org/'), data=gen(), @@ -972,10 +1077,26 @@ async def coro(): async def test_close(loop, buf, conn): - @aiohttp.streamer - async def gen(writer): - await asyncio.sleep(0.00001, loop=loop) - await writer.write(b'result') + @async_generator + async def gen(): + await asyncio.sleep(0.00001) + await yield_(b'result') + + req = ClientRequest( + 'POST', URL('http://python.org/'), data=gen(), loop=loop) + resp = await req.send(conn) + await req.close() + assert buf.split(b'\r\n\r\n', 1)[1] == b'6\r\nresult\r\n0\r\n\r\n' + await req.close() + resp.close() + + +async def test_close_deprecated(loop, buf, conn): + with pytest.warns(DeprecationWarning): + @aiohttp.streamer + async def gen(writer): + await asyncio.sleep(0.00001, loop=loop) + await writer.write(b'result') req = ClientRequest( 'POST', URL('http://python.org/'), data=gen(), loop=loop) diff --git a/tests/test_payload.py b/tests/test_payload.py index bba666a17d6..26efd2a5221 100644 --- a/tests/test_payload.py +++ b/tests/test_payload.py @@ -1,6 +1,7 @@ from io import StringIO import pytest +from async_generator import async_generator from aiohttp import payload @@ -28,6 +29,14 @@ class TestProvider: assert isinstance(p, Payload) +def test_register_unsupported_order(registry): + class TestProvider: + pass + + with pytest.raises(ValueError): + payload.register_payload(Payload, TestProvider, order=object()) + + def test_payload_ctor(): p = Payload('test', encoding='utf-8', filename='test.txt') assert p._value == 'test' @@ -42,6 +51,21 @@ def test_payload_content_type(): assert p.content_type == 'application/json' +def test_bytes_payload_default_content_type(): + p = payload.BytesPayload(b'data') + assert p.content_type == 'application/octet-stream' + + +def test_bytes_payload_explicit_content_type(): + p = payload.BytesPayload(b'data', content_type='application/custom') + assert p.content_type == 'application/custom' + + +def test_bytes_payload_bad_type(): + with pytest.raises(TypeError): + payload.BytesPayload(object()) + + def test_string_payload(): p = payload.StringPayload('test') assert p.encoding == 'utf-8' @@ -63,3 +87,27 @@ def test_string_io_payload(): assert p.encoding == 'utf-8' assert p.content_type == 'text/plain; charset=utf-8' assert p.size == 10000 + + +def test_async_iterable_payload_default_content_type(): + @async_generator + async def gen(): + pass + + p = payload.AsyncIterablePayload(gen()) + assert p.content_type == 'application/octet-stream' + + +def test_async_iterable_payload_explicit_content_type(): + @async_generator + async def gen(): + pass + + p = payload.AsyncIterablePayload(gen(), content_type='application/custom') + assert p.content_type == 'application/custom' + + +def test_async_iterable_payload_not_async_iterable(): + + with pytest.raises(TypeError): + payload.AsyncIterablePayload(object()) diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index 4183fcdea11..4999c043912 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -7,6 +7,7 @@ from unittest import mock import pytest +from async_generator import async_generator, yield_ from multidict import MultiDict from yarl import URL @@ -730,19 +731,19 @@ async def handler(request): assert 200 == resp.status -async def test_response_with_streamer(aiohttp_client, fname): +async def test_response_with_async_gen(aiohttp_client, fname): with fname.open('rb') as f: data = f.read() data_size = len(data) - @aiohttp.streamer - async def stream(writer, f_name): + @async_generator + async def stream(f_name): with f_name.open('rb') as f: data = f.read(100) while data: - await writer.write(data) + await yield_(data) data = f.read(100) async def handler(request): @@ -760,20 +761,82 @@ async def handler(request): assert resp.headers.get('Content-Length') == str(len(resp_data)) -async def test_response_with_streamer_no_params(aiohttp_client, fname): +async def test_response_with_streamer(aiohttp_client, fname): with fname.open('rb') as f: data = f.read() data_size = len(data) - @aiohttp.streamer - async def stream(writer): + with pytest.warns(DeprecationWarning): + @aiohttp.streamer + async def stream(writer, f_name): + with f_name.open('rb') as f: + data = f.read(100) + while data: + await writer.write(data) + data = f.read(100) + + async def handler(request): + headers = {'Content-Length': str(data_size)} + return web.Response(body=stream(fname), headers=headers) + + app = web.Application() + app.router.add_get('/', handler) + client = await aiohttp_client(app) + + resp = await client.get('/') + assert 200 == resp.status + resp_data = await resp.read() + assert resp_data == data + assert resp.headers.get('Content-Length') == str(len(resp_data)) + + +async def test_response_with_async_gen_no_params(aiohttp_client, fname): + + with fname.open('rb') as f: + data = f.read() + + data_size = len(data) + + @async_generator + async def stream(): with fname.open('rb') as f: data = f.read(100) while data: - await writer.write(data) + await yield_(data) + data = f.read(100) + + async def handler(request): + headers = {'Content-Length': str(data_size)} + return web.Response(body=stream(), headers=headers) + + app = web.Application() + app.router.add_get('/', handler) + client = await aiohttp_client(app) + + resp = await client.get('/') + assert 200 == resp.status + resp_data = await resp.read() + assert resp_data == data + assert resp.headers.get('Content-Length') == str(len(resp_data)) + + +async def test_response_with_streamer_no_params(aiohttp_client, fname): + + with fname.open('rb') as f: + data = f.read() + + data_size = len(data) + + with pytest.warns(DeprecationWarning): + @aiohttp.streamer + async def stream(writer): + with fname.open('rb') as f: data = f.read(100) + while data: + await writer.write(data) + data = f.read(100) async def handler(request): headers = {'Content-Length': str(data_size)}