Skip to content

Commit

Permalink
refactor MultipartWriter to use Payload
Browse files Browse the repository at this point in the history
  • Loading branch information
fafhrd91 committed Feb 21, 2017
1 parent 7e9a381 commit 0e765a1
Show file tree
Hide file tree
Showing 12 changed files with 787 additions and 863 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ cov-dev-full: .develop
@echo "Run without extensions"
@AIOHTTP_NO_EXTENSIONS=1 py.test --cov=aiohttp tests
@echo "Run in debug mode"
@PYTHONASYNCIODEBUG=1 py.test --cov=aiohttp --cov-append tests
@PYTHONASYNCIODEBUG=1 py.test -s -v --cov=aiohttp --cov-append tests
@echo "Regular run"
@py.test --cov=aiohttp --cov-report=term --cov-report=html --cov-append tests
@echo "open file://`pwd`/coverage/index.html"
Expand Down
6 changes: 4 additions & 2 deletions aiohttp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from . import hdrs # noqa
from .client import * # noqa
from .formdata import * # noqa
from .helpers import * # noqa
from .http_message import HttpVersion, HttpVersion10, HttpVersion11 # noqa
from .http_websocket import WSMsgType, WSCloseCode, WSMessage, WebSocketError # noqa
Expand All @@ -25,11 +26,12 @@


__all__ = (client.__all__ + # noqa
formdata.__all__ + # noqa
helpers.__all__ + # noqa
streams.__all__ + # noqa
multipart.__all__ + # noqa
payload.__all__ + # noqa
payload_streamer.__all__ + # noqa
multipart.__all__ + # noqa
streams.__all__ + # noqa
('hdrs', 'FileSender',
'HttpVersion', 'HttpVersion10', 'HttpVersion11',
'WSMsgType', 'MsgType', 'WSCloseCode',
Expand Down
85 changes: 34 additions & 51 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
import aiohttp

from . import hdrs, helpers, http, payload
from .formdata import FormData
from .helpers import PY_35, HeadersMixin, SimpleCookie, _TimeServiceTimeoutNoop
from .http import HttpMessage
from .log import client_logger
from .multipart import MultipartWriter
from .streams import FlowControlStreamReader

try:
Expand Down Expand Up @@ -217,71 +217,54 @@ def update_auth(self, auth):

self.headers[hdrs.AUTHORIZATION] = auth.encode()

def update_body_from_data(self, data, skip_auto_headers):
if not data:
return

try:
self.body = payload.PAYLOAD_REGISTRY.get(data)

# enable chunked encoding if needed
if not self.chunked:
if hdrs.CONTENT_LENGTH not in self.headers:
size = self.body.size
if size is None:
self.chunked = True
else:
if hdrs.CONTENT_LENGTH not in self.headers:
self.headers[hdrs.CONTENT_LENGTH] = str(size)

# set content-type
if (hdrs.CONTENT_TYPE not in self.headers and
hdrs.CONTENT_TYPE not in skip_auto_headers):
self.headers[hdrs.CONTENT_TYPE] = self.body.content_type

# copy payload headers
if self.body.headers:
for (key, value) in self.body.headers.items():
if key not in self.headers:
self.headers[key] = value

except payload.LookupError:
pass
else:
def update_body_from_data(self, body, skip_auto_headers):
if not body:
return

if asyncio.iscoroutine(data):
if asyncio.iscoroutine(body):
warnings.warn(
'coroutine as data object is deprecated, '
'use aiohttp.streamer #1664',
DeprecationWarning, stacklevel=2)

self.body = data
self.body = body
if (hdrs.CONTENT_LENGTH not in self.headers and
self.chunked is None):
self.chunked = True

elif isinstance(data, MultipartWriter):
self.body = data.serialize()
self.headers.update(data.headers)
self.chunked = True
return

else:
if not isinstance(data, helpers.FormData):
data = helpers.FormData(data)
# FormData
if isinstance(body, FormData):
body = body(self.encoding)

self.body = data(self.encoding)
try:
body = payload.PAYLOAD_REGISTRY.get(body)
except payload.LookupError:
body = FormData(body)(self.encoding)

if (hdrs.CONTENT_TYPE not in self.headers and
hdrs.CONTENT_TYPE not in skip_auto_headers):
self.headers[hdrs.CONTENT_TYPE] = data.content_type
self.body = body

if data.is_multipart:
self.chunked = True
else:
if (hdrs.CONTENT_LENGTH not in self.headers and
not self.chunked):
self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))
# enable chunked encoding if needed
if not self.chunked:
if hdrs.CONTENT_LENGTH not in self.headers:
size = body.size
if size is None:
self.chunked = True
else:
if hdrs.CONTENT_LENGTH not in self.headers:
self.headers[hdrs.CONTENT_LENGTH] = str(size)

# set content-type
if (hdrs.CONTENT_TYPE not in self.headers and
hdrs.CONTENT_TYPE not in skip_auto_headers):
self.headers[hdrs.CONTENT_TYPE] = body.content_type

# copy payload headers
if body.headers:
for (key, value) in body.headers.items():
if key not in self.headers:
self.headers[key] = value

def update_transfer_encoding(self):
"""Analyze transfer-encoding header."""
Expand Down
122 changes: 122 additions & 0 deletions aiohttp/formdata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import io
from urllib.parse import urlencode

from multidict import MultiDict, MultiDictProxy

from . import hdrs, multipart, payload
from .helpers import guess_filename

__all__ = ('FormData',)


class FormData:
"""Helper class for multipart/form-data and
application/x-www-form-urlencoded body generation."""

def __init__(self, fields=(), quote_fields=True):
self._writer = multipart.MultipartWriter('form-data')
self._fields = []
self._is_multipart = False
self._quote_fields = quote_fields

if isinstance(fields, dict):
fields = list(fields.items())
elif not isinstance(fields, (list, tuple)):
fields = (fields,)
self.add_fields(*fields)

def add_field(self, name, value, *, content_type=None, filename=None,
content_transfer_encoding=None):

if isinstance(value, io.IOBase):
self._is_multipart = True
elif isinstance(value, (bytes, bytearray, memoryview)):
if filename is None and content_transfer_encoding is None:
filename = name

type_options = MultiDict({'name': name})
if filename is not None and not isinstance(filename, str):
raise TypeError('filename must be an instance of str. '
'Got: %s' % filename)
if filename is None and isinstance(value, io.IOBase):
filename = guess_filename(value, name)
if filename is not None:
type_options['filename'] = filename
self._is_multipart = True

headers = {}
if content_type is not None:
if not isinstance(content_type, str):
raise TypeError('content_type must be an instance of str. '
'Got: %s' % content_type)
headers[hdrs.CONTENT_TYPE] = content_type
self._is_multipart = True
if content_transfer_encoding is not None:
if not isinstance(content_transfer_encoding, str):
raise TypeError('content_transfer_encoding must be an instance'
' of str. Got: %s' % content_transfer_encoding)
headers[hdrs.CONTENT_TRANSFER_ENCODING] = content_transfer_encoding
self._is_multipart = True

self._fields.append((type_options, headers, value))

def add_fields(self, *fields):
to_add = list(fields)

while to_add:
rec = to_add.pop(0)

if isinstance(rec, io.IOBase):
k = guess_filename(rec, 'unknown')
self.add_field(k, rec)

elif isinstance(rec, (MultiDictProxy, MultiDict)):
to_add.extend(rec.items())

elif isinstance(rec, (list, tuple)) and len(rec) == 2:
k, fp = rec
self.add_field(k, fp)

else:
raise TypeError('Only io.IOBase, multidict and (name, file) '
'pairs allowed, use .add_field() for passing '
'more complex parameters, got {!r}'
.format(rec))

def _gen_form_urlencoded(self, encoding):
# form data (x-www-form-urlencoded)
data = []
for type_options, _, value in self._fields:
data.append((type_options['name'], value))

return payload.BytesPayload(
urlencode(data, doseq=True).encode(encoding),
content_type='application/x-www-form-urlencoded')

def _gen_form_data(self, encoding):
"""Encode a list of fields using the multipart/form-data MIME format"""
for dispparams, headers, value in self._fields:
if hdrs.CONTENT_TYPE in headers:
part = payload.get_payload(
value, content_type=headers[hdrs.CONTENT_TYPE],
headers=headers, encoding=encoding)
else:
part = payload.get_payload(
value, headers=headers, encoding=encoding)
if dispparams:
part.set_content_disposition(
'form-data', quote_fields=self._quote_fields, **dispparams
)
# FIXME cgi.FieldStorage doesn't likes body parts with
# Content-Length which were sent via chunked transfer encoding
part.headers.pop(hdrs.CONTENT_LENGTH, None)

self._writer.append_payload(part)

return self._writer

def __call__(self, encoding):
if self._is_multipart:
return self._gen_form_data(encoding)
else:
return self._gen_form_urlencoded(encoding)
Loading

0 comments on commit 0e765a1

Please sign in to comment.