From 50833f28f261362890aee2d603686e38d828bbdd Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 21 Feb 2017 22:51:01 -0800 Subject: [PATCH] add Payload wrapper for BodyPartReader and more tests --- aiohttp/multipart.py | 26 +++++++++++++++++++++++++- aiohttp/payload.py | 16 +++++++++++++++- aiohttp/payload_streamer.py | 10 ++++------ tests/test_multipart.py | 13 +++++++++++++ tests/test_web_functional.py | 26 ++++++++++++++++++++++++++ 5 files changed, 83 insertions(+), 8 deletions(-) diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index 02fe9620186..2670f2a94c1 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -16,7 +16,7 @@ from .helpers import CHAR, PY_35, PY_352, TOKEN, parse_mimetype, reify from .http import HttpParser from .payload import (BytesPayload, LookupError, Payload, StringPayload, - get_payload) + get_payload, payload_type) __all__ = ('MultipartReader', 'MultipartWriter', 'BodyPartReader', 'BadContentDispositionHeader', 'BadContentDispositionParam', @@ -485,6 +485,30 @@ def filename(self): return content_disposition_filename(params, 'filename') +@payload_type(BodyPartReader) +class BodyPartReaderPayload(Payload): + + def __init__(self, value, *args, **kwargs): + super().__init__(value, *args, **kwargs) + + params = {} + if value.name is not None: + params['name'] = value.name + if value.filename is not None: + params['filename'] = value.name + + if params: + self.set_content_disposition('attachment', **params) + + @asyncio.coroutine + def write(self, writer): + field = self._value + chunk = yield from field.read_chunk(size=2**16) + while chunk: + writer.write(field.decode(chunk)) + chunk = yield from field.read_chunk(size=2**16) + + class MultipartReader(object): """Multipart body reader.""" diff --git a/aiohttp/payload.py b/aiohttp/payload.py index 9924448ef45..bd5426164a4 100644 --- a/aiohttp/payload.py +++ b/aiohttp/payload.py @@ -11,7 +11,7 @@ parse_mimetype, sentinel) from .streams import DEFAULT_LIMIT, DataQueue, EofStream, StreamReader -__all__ = ('PAYLOAD_REGISTRY', 'get_payload', 'Payload', +__all__ = ('PAYLOAD_REGISTRY', 'get_payload', 'payload_type', 'Payload', 'BytesPayload', 'StringPayload', 'StreamReaderPayload', 'IOBasePayload', 'BytesIOPayload', 'BufferedReaderPayload', 'TextIOPayload', 'StringIOPayload') @@ -25,6 +25,20 @@ def get_payload(data, *args, **kwargs): return PAYLOAD_REGISTRY.get(data, *args, **kwargs) +def register_payload(ctor, type): + PAYLOAD_REGISTRY.register(ctor, type) + + +class payload_type: + + def __init__(self, type): + self.type = type + + def __call__(self, cls): + PAYLOAD_REGISTRY.register(cls, self.type) + return cls + + class PayloadRegistry: """Payload registry. diff --git a/aiohttp/payload_streamer.py b/aiohttp/payload_streamer.py index 53028115146..5cfc6814503 100644 --- a/aiohttp/payload_streamer.py +++ b/aiohttp/payload_streamer.py @@ -23,7 +23,7 @@ def file_sender(writer, file_name=None): import asyncio -from . import payload +from .payload import Payload, payload_type __all__ = ('streamer',) @@ -49,13 +49,15 @@ def __call__(self, *args, **kwargs): return _stream_wrapper(self.coro, args, kwargs) -class StreamWrapperPayload(payload.Payload): +@payload_type(_stream_wrapper) +class StreamWrapperPayload(Payload): @asyncio.coroutine def write(self, writer): yield from self._value(writer) +@payload_type(streamer) class StreamPayload(StreamWrapperPayload): def __init__(self, value, *args, **kwargs): @@ -64,7 +66,3 @@ def __init__(self, value, *args, **kwargs): @asyncio.coroutine def write(self, writer): yield from self._value(writer) - - -payload.PAYLOAD_REGISTRY.register(StreamPayload, streamer) -payload.PAYLOAD_REGISTRY.register(StreamWrapperPayload, _stream_wrapper) diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 36379d8cde2..0bb7115514f 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -492,6 +492,19 @@ def test_read_form_while_closed(self): result = yield from obj.form() self.assertEqual(None, result) + def test_readline(self): + obj = aiohttp.multipart.BodyPartReader( + self.boundary, {}, Stream(b'Hello\n,\r\nworld!\r\n--:--')) + result = yield from obj.readline() + self.assertEqual(b'Hello\n', result) + result = yield from obj.readline() + self.assertEqual(b',\r\n', result) + result = yield from obj.readline() + self.assertEqual(b'world!', result) + result = yield from obj.readline() + self.assertEqual(b'', result) + self.assertTrue(obj.at_eof()) + def test_release(self): stream = Stream(b'Hello,\r\n--:\r\n\r\nworld!\r\n--:--') obj = aiohttp.multipart.BodyPartReader( diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index 6685d5c49da..132e99bd719 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -1481,3 +1481,29 @@ def handler(request): resp = yield from client.post('/', data=data) assert 200 == resp.status + + +@asyncio.coroutine +def test_response_with_bodypart(loop, test_client): + + @asyncio.coroutine + def handler(request): + reader = yield from request.multipart() + part = yield from reader.next() + return web.Response(body=part) + + app = web.Application(loop=loop, client_max_size=2) + app.router.add_post('/', handler) + client = yield from test_client(app) + + data = {'file': io.BytesIO(b'test')} + resp = yield from client.post('/', data=data) + + assert 200 == resp.status + body = yield from resp.read() + assert body == b'test' + + disp = multipart.parse_content_disposition( + resp.headers['content-disposition']) + assert disp == ('attachment', + {'name': 'file', 'filename': 'file', 'filename*': 'file'})