Skip to content

Commit

Permalink
add Payload wrapper for BodyPartReader and more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fafhrd91 committed Feb 22, 2017
1 parent 6b65dda commit 50833f2
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 8 deletions.
26 changes: 25 additions & 1 deletion aiohttp/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .helpers import CHAR, PY_35, PY_352, TOKEN, parse_mimetype, reify
from .http import HttpParser
from .payload import (BytesPayload, LookupError, Payload, StringPayload,
get_payload)
get_payload, payload_type)

__all__ = ('MultipartReader', 'MultipartWriter', 'BodyPartReader',
'BadContentDispositionHeader', 'BadContentDispositionParam',
Expand Down Expand Up @@ -485,6 +485,30 @@ def filename(self):
return content_disposition_filename(params, 'filename')


@payload_type(BodyPartReader)
class BodyPartReaderPayload(Payload):

def __init__(self, value, *args, **kwargs):
super().__init__(value, *args, **kwargs)

params = {}
if value.name is not None:
params['name'] = value.name
if value.filename is not None:
params['filename'] = value.name

if params:
self.set_content_disposition('attachment', **params)

@asyncio.coroutine
def write(self, writer):
field = self._value
chunk = yield from field.read_chunk(size=2**16)
while chunk:
writer.write(field.decode(chunk))
chunk = yield from field.read_chunk(size=2**16)


class MultipartReader(object):
"""Multipart body reader."""

Expand Down
16 changes: 15 additions & 1 deletion aiohttp/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
parse_mimetype, sentinel)
from .streams import DEFAULT_LIMIT, DataQueue, EofStream, StreamReader

__all__ = ('PAYLOAD_REGISTRY', 'get_payload', 'Payload',
__all__ = ('PAYLOAD_REGISTRY', 'get_payload', 'payload_type', 'Payload',
'BytesPayload', 'StringPayload', 'StreamReaderPayload',
'IOBasePayload', 'BytesIOPayload', 'BufferedReaderPayload',
'TextIOPayload', 'StringIOPayload')
Expand All @@ -25,6 +25,20 @@ def get_payload(data, *args, **kwargs):
return PAYLOAD_REGISTRY.get(data, *args, **kwargs)


def register_payload(ctor, type):
PAYLOAD_REGISTRY.register(ctor, type)


class payload_type:

def __init__(self, type):
self.type = type

def __call__(self, cls):
PAYLOAD_REGISTRY.register(cls, self.type)
return cls


class PayloadRegistry:
"""Payload registry.
Expand Down
10 changes: 4 additions & 6 deletions aiohttp/payload_streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def file_sender(writer, file_name=None):

import asyncio

from . import payload
from .payload import Payload, payload_type

__all__ = ('streamer',)

Expand All @@ -49,13 +49,15 @@ def __call__(self, *args, **kwargs):
return _stream_wrapper(self.coro, args, kwargs)


class StreamWrapperPayload(payload.Payload):
@payload_type(_stream_wrapper)
class StreamWrapperPayload(Payload):

@asyncio.coroutine
def write(self, writer):
yield from self._value(writer)


@payload_type(streamer)
class StreamPayload(StreamWrapperPayload):

def __init__(self, value, *args, **kwargs):
Expand All @@ -64,7 +66,3 @@ def __init__(self, value, *args, **kwargs):
@asyncio.coroutine
def write(self, writer):
yield from self._value(writer)


payload.PAYLOAD_REGISTRY.register(StreamPayload, streamer)
payload.PAYLOAD_REGISTRY.register(StreamWrapperPayload, _stream_wrapper)
13 changes: 13 additions & 0 deletions tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,19 @@ def test_read_form_while_closed(self):
result = yield from obj.form()
self.assertEqual(None, result)

def test_readline(self):
obj = aiohttp.multipart.BodyPartReader(
self.boundary, {}, Stream(b'Hello\n,\r\nworld!\r\n--:--'))
result = yield from obj.readline()
self.assertEqual(b'Hello\n', result)
result = yield from obj.readline()
self.assertEqual(b',\r\n', result)
result = yield from obj.readline()
self.assertEqual(b'world!', result)
result = yield from obj.readline()
self.assertEqual(b'', result)
self.assertTrue(obj.at_eof())

def test_release(self):
stream = Stream(b'Hello,\r\n--:\r\n\r\nworld!\r\n--:--')
obj = aiohttp.multipart.BodyPartReader(
Expand Down
26 changes: 26 additions & 0 deletions tests/test_web_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1481,3 +1481,29 @@ def handler(request):
resp = yield from client.post('/', data=data)

assert 200 == resp.status


@asyncio.coroutine
def test_response_with_bodypart(loop, test_client):

@asyncio.coroutine
def handler(request):
reader = yield from request.multipart()
part = yield from reader.next()
return web.Response(body=part)

app = web.Application(loop=loop, client_max_size=2)
app.router.add_post('/', handler)
client = yield from test_client(app)

data = {'file': io.BytesIO(b'test')}
resp = yield from client.post('/', data=data)

assert 200 == resp.status
body = yield from resp.read()
assert body == b'test'

disp = multipart.parse_content_disposition(
resp.headers['content-disposition'])
assert disp == ('attachment',
{'name': 'file', 'filename': 'file', 'filename*': 'file'})

0 comments on commit 50833f2

Please sign in to comment.