Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add typings for Payload #3294

Merged
merged 7 commits into from
Oct 1, 2018
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/1749.feature
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
Add type hints to Application and Response
Add type hints to Exceptions
Upgrade mypy to 0.630
Add type hints to payload.py
184 changes: 129 additions & 55 deletions aiohttp/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@
import os
import warnings
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable
from itertools import chain
from typing import (IO, TYPE_CHECKING, Any, ByteString, Callable, Dict, # noqa
Iterable, List, Optional, Text, TextIO, Tuple, Type, Union)

from multidict import CIMultiDict

from . import hdrs
from .abc import AbstractStreamWriter
from .helpers import (PY_36, content_disposition_header, guess_filename,
parse_mimetype, sentinel)
from .streams import DEFAULT_LIMIT, StreamReader
from .typedefs import JSON, JSONEncoder


__all__ = ('PAYLOAD_REGISTRY', 'get_payload', 'payload_type', 'Payload',
Expand All @@ -29,27 +32,30 @@ class LookupError(Exception):
pass


class Order(enum.Enum):
class Order(str, enum.Enum):
asvetlov marked this conversation as resolved.
Show resolved Hide resolved
normal = 'normal'
try_first = 'try_first'
try_last = 'try_last'


def get_payload(data, *args, **kwargs):
def get_payload(data: Any, *args: Any, **kwargs: Any) -> 'Payload':
return PAYLOAD_REGISTRY.get(data, *args, **kwargs)


def register_payload(factory, type, *, order=Order.normal):
def register_payload(factory: Type['Payload'],
type: Any,
*,
order: Order=Order.normal) -> None:
PAYLOAD_REGISTRY.register(factory, type, order=order)


class payload_type:

def __init__(self, type, *, order=Order.normal):
def __init__(self, type: Any, *, order: Order=Order.normal) -> None:
self.type = type
self.order = order

def __call__(self, factory):
def __call__(self, factory: Type['Payload']) -> Type['Payload']:
register_payload(factory, self.type, order=self.order)
return factory

Expand All @@ -60,12 +66,16 @@ class PayloadRegistry:
note: we need zope.interface for more efficient adapter search
"""

def __init__(self):
self._first = []
self._normal = []
self._last = []
def __init__(self) -> None:
self._first = [] # type: List[Tuple[Type[Payload], Any]]
self._normal = [] # type: List[Tuple[Type[Payload], Any]]
self._last = [] # type: List[Tuple[Type[Payload], Any]]

def get(self, data, *args, _CHAIN=chain, **kwargs):
def get(self,
data: Any,
*args: Any,
_CHAIN: Any=chain,
**kwargs: Any) -> 'Payload':
if isinstance(data, Payload):
return data
for factory, type in _CHAIN(self._first, self._normal, self._last):
Expand All @@ -74,7 +84,11 @@ def get(self, data, *args, _CHAIN=chain, **kwargs):

raise LookupError()

def register(self, factory, type, *, order=Order.normal):
def register(self,
factory: Type['Payload'],
type: Any,
*,
order: Order=Order.normal) -> None:
if order is Order.try_first:
self._first.append((factory, type))
elif order is Order.normal:
Expand All @@ -87,12 +101,23 @@ def register(self, factory, type, *, order=Order.normal):

class Payload(ABC):

_size = None
_headers = None
_content_type = 'application/octet-stream'

def __init__(self, value, *, headers=None, content_type=sentinel,
filename=None, encoding=None, **kwargs):
_size = None # type: Optional[float]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit late for the party, but why size is typed as float? This value is used for Content-Length header which defined number of bytes of payload size - it's hard to pass over HTTP some fraction of a byte. What was the case for float type?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_headers = None # type: Optional[CIMultiDict]
_content_type = 'application/octet-stream' # type: Optional[str]

def __init__(self,
value: Any,
headers: Optional[
Union[
CIMultiDict,
Dict[str, Any],
Iterable[Tuple[str, Any]]
]
] = None,
content_type: Optional[str]=sentinel,
filename: Optional[str]=None,
encoding: Optional[str]=None,
**kwargs: Any) -> None:
self._value = value
self._encoding = encoding
self._filename = filename
Expand All @@ -107,27 +132,27 @@ def __init__(self, value, *, headers=None, content_type=sentinel,
self._content_type = content_type

@property
def size(self):
def size(self) -> Optional[float]:
Copy link
Contributor Author

@kornicameister kornicameister Sep 26, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All those Optional seem a little awkward, but honestly speaking that's how I read the logic. Might as well be that I got that wrong.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a pretty common case.
For example AsyncIterablePayload has no size but sends a payload chunk-by-chunk

"""Size of the payload."""
return self._size

@property
def filename(self):
def filename(self) -> Optional[str]:
"""Filename of the payload."""
return self._filename

@property
def headers(self):
def headers(self) -> Optional[CIMultiDict]:
"""Custom item headers"""
return self._headers

@property
def encoding(self):
def encoding(self) -> Optional[str]:
"""Payload encoding"""
return self._encoding

@property
def content_type(self):
def content_type(self) -> Optional[str]:
"""Content type"""
if self._content_type is not None:
return self._content_type
Expand All @@ -137,7 +162,10 @@ def content_type(self):
else:
return Payload._content_type

def set_content_disposition(self, disptype, quote_fields=True, **params):
def set_content_disposition(self,
disptype: str,
quote_fields: bool=True,
**params: Any) -> None:
"""Sets ``Content-Disposition`` header."""
if self._headers is None:
self._headers = CIMultiDict()
Expand All @@ -146,7 +174,7 @@ def set_content_disposition(self, disptype, quote_fields=True, **params):
disptype, quote_fields=quote_fields, **params)

@abstractmethod
async def write(self, writer):
async def write(self, writer: AbstractStreamWriter) -> None:
"""Write payload.

writer is an AbstractStreamWriter instance:
Expand All @@ -155,7 +183,10 @@ async def write(self, writer):

class BytesPayload(Payload):

def __init__(self, value, *args, **kwargs):
def __init__(self,
value: ByteString,
*args: Any,
**kwargs: Any) -> None:
if not isinstance(value, (bytes, bytearray, memoryview)):
raise TypeError("value argument must be byte-ish, not (!r)"
.format(type(value)))
Expand All @@ -177,14 +208,18 @@ def __init__(self, value, *args, **kwargs):
"io.BytesIO object instead", ResourceWarning,
**kwargs)

async def write(self, writer):
async def write(self, writer: AbstractStreamWriter) -> None:
await writer.write(self._value)


class StringPayload(BytesPayload):

def __init__(self, value, *args,
encoding=None, content_type=None, **kwargs):
def __init__(self,
value: Text,
*args: Any,
encoding: Optional[str]=None,
content_type: Optional[str]=None,
**kwargs: Any) -> None:

if encoding is None:
if content_type is None:
Expand All @@ -197,20 +232,33 @@ def __init__(self, value, *args,
if content_type is None:
content_type = 'text/plain; charset=%s' % encoding

super().__init__(
value.encode(encoding),
encoding=encoding, content_type=content_type, *args, **kwargs)
if encoding is None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, this bit I completely do not get. It's literally impossible for encoding to be None, based on the logic. But the typing says otherwise and we can't exactly redefine the value with mypy. We can't also do the

_encoding: str
_content_type: str

letting the assignment to come later, because tests on 3.5 complains. We could've do the typing.cast but not really sure if that's correct. Sounds a bit hacky to me.

@asvetlov any suggestions ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, we can keep that raise ValueError and add some noqa there, but honestly that sounds bad as well to me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, and for the sake of discussion, this is the alternative:

class StringPayload(BytesPayload):

    def __init__(self,
                 value: str,
                 *args: Any,
                 encoding: Optional[str]=None,
                 content_type: Optional[str]=None,
                 **kwargs: Any) -> None:
        _encoding, _content_type = self._parse_args(encoding, content_type)
        super().__init__(
            value.encode(_encoding),
            encoding=_encoding,
            content_type=_content_type,
            *args,
            **kwargs)

    @staticmethod
    def _parse_args(
        encoding: Optional[str]=None,
        content_type: Optional[str]=None,
    ) -> Tuple[str, str]:
        if encoding and content_type:
            return encoding, content_type
        elif encoding is None and content_type is None:
            return 'utf-8', 'text/plain; charset=utf-8'
        elif encoding is None and content_type:
            mimetype = parse_mimetype(content_type)
            return (
                mimetype.parameters.get('charset', 'utf-8'),
                content_type,
            )
        else:
            return (
                cast(str, encoding),
                'text/plain; charset=%s' % cast(str, encoding),
            )

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert back the change, drop if encoding is None check.
Use the following trick:

        if encoding is None:
            if content_type is None:
                 real_encoding = 'utf-8'
                 content_type = 'text/plain; charset=utf-8'
             else:
                 mimetype = parse_mimetype(content_type)
                 readl_encoding = mimetype.parameters.get('charset', 'utf-8')
         else:
            if content_type is None:
                content_type = 'text/plain; charset=%s' % encoding
            real_encoding = encoding

 super().__init__(
                value.encode(encoding),
                encoding=real_encoding,
                content_type=content_type,
                *args,
                **kwargs)

real_encoding variable has strict string type, not Optional[str].

I believe mypy understands it pretty well.

raise ValueError('Encoding must be set')
else:
super().__init__(
value.encode(encoding),
encoding=encoding,
content_type=content_type,
*args,
**kwargs)


class StringIOPayload(StringPayload):

def __init__(self, value, *args, **kwargs):
def __init__(self,
value: IO[str],
*args: Any,
**kwargs: Any) -> None:
super().__init__(value.read(), *args, **kwargs)


class IOBasePayload(Payload):

def __init__(self, value, disposition='attachment', *args, **kwargs):
def __init__(self,
value: IO[Any],
disposition: str='attachment',
*args: Any,
**kwargs: Any) -> None:
if 'filename' not in kwargs:
kwargs['filename'] = guess_filename(value)

Expand All @@ -219,7 +267,7 @@ def __init__(self, value, disposition='attachment', *args, **kwargs):
if self._filename is not None and disposition is not None:
self.set_content_disposition(disposition, filename=self._filename)

async def write(self, writer):
async def write(self, writer: AbstractStreamWriter) -> None:
try:
chunk = self._value.read(DEFAULT_LIMIT)
while chunk:
Expand All @@ -231,8 +279,12 @@ async def write(self, writer):

class TextIOPayload(IOBasePayload):

def __init__(self, value, *args,
encoding=None, content_type=None, **kwargs):
def __init__(self,
value: TextIO,
*args: Any,
encoding: Optional[str]=None,
content_type: Optional[str]=None,
**kwargs: Any) -> None:

if encoding is None:
if content_type is None:
Expand All @@ -250,13 +302,13 @@ def __init__(self, value, *args,
content_type=content_type, encoding=encoding, *args, **kwargs)

@property
def size(self):
def size(self) -> Optional[float]:
try:
return os.fstat(self._value.fileno()).st_size - self._value.tell()
except OSError:
return None

async def write(self, writer):
async def write(self, writer: AbstractStreamWriter) -> None:
try:
chunk = self._value.read(DEFAULT_LIMIT)
while chunk:
Expand All @@ -269,7 +321,7 @@ async def write(self, writer):
class BytesIOPayload(IOBasePayload):

@property
def size(self):
def size(self) -> float:
position = self._value.tell()
end = self._value.seek(0, os.SEEK_END)
self._value.seek(position)
Expand All @@ -279,7 +331,7 @@ def size(self):
class BufferedReaderPayload(IOBasePayload):

@property
def size(self):
def size(self) -> Optional[float]:
try:
return os.fstat(self._value.fileno()).st_size - self._value.tell()
except OSError:
Expand All @@ -290,18 +342,39 @@ def size(self):

class JsonPayload(BytesPayload):

def __init__(self, value,
encoding='utf-8', content_type='application/json',
dumps=json.dumps, *args, **kwargs):
def __init__(self,
value: JSON,
encoding: str='utf-8',
content_type: str='application/json',
dumps: JSONEncoder=json.dumps,
*args: Any,
**kwargs: Any) -> None:

super().__init__(
dumps(value).encode(encoding),
content_type=content_type, encoding=encoding, *args, **kwargs)


if TYPE_CHECKING:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add # pragma: no cover comment to the line for hinting coverage tool

from typing import AsyncIterator, AsyncIterable

_AsyncIterator = AsyncIterator[bytes]
_AsyncIterable = AsyncIterable[bytes]
else:
from collections.abc import AsyncIterable, AsyncIterator

_AsyncIterator = AsyncIterator
_AsyncIterable = AsyncIterable


class AsyncIterablePayload(Payload):

def __init__(self, value, *args, **kwargs):
_iter = None # type: Optional[_AsyncIterator]

def __init__(self,
value: _AsyncIterable,
*args: Any,
**kwargs: Any) -> None:
if not isinstance(value, AsyncIterable):
raise TypeError("value argument must support "
"collections.abc.AsyncIterablebe interface, "
Expand All @@ -314,20 +387,21 @@ def __init__(self, value, *args, **kwargs):

self._iter = value.__aiter__()

async def write(self, writer):
try:
# iter is not None check prevents rare cases
# when the case iterable is used twice
while True:
chunk = await self._iter.__anext__()
await writer.write(chunk)
except StopAsyncIteration:
self._iter = None
async def write(self, writer: AbstractStreamWriter) -> None:
if self._iter:
try:
# iter is not None check prevents rare cases
# when the case iterable is used twice
while True:
chunk = await self._iter.__anext__()
await writer.write(chunk)
except StopAsyncIteration:
self._iter = None


class StreamReaderPayload(AsyncIterablePayload):

def __init__(self, value, *args, **kwargs):
def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None:
super().__init__(value.iter_any(), *args, **kwargs)


Expand Down
Loading