diff --git a/CHANGES/3049.feature b/CHANGES/3049.feature new file mode 100644 index 00000000000..e0b0095ee9f --- /dev/null +++ b/CHANGES/3049.feature @@ -0,0 +1 @@ +Add type hints \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in index a24028ff6ce..05084efddb9 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -8,6 +8,7 @@ graft docs graft examples graft tests recursive-include vendor * +global-include aiohttp *.pyi global-exclude *.pyc global-exclude *.pyd global-exclude *.so diff --git a/Makefile b/Makefile index 09fd4f807dc..c2474fd6717 100644 --- a/Makefile +++ b/Makefile @@ -28,7 +28,12 @@ flake: .flake check_changes: @./tools/check_changes.py -.develop: .install-deps $(shell find aiohttp -type f) .flake check_changes +mypy: .flake + if python -c "import sys; sys.exit(sys.implementation.name!='cpython')"; then \ + mypy aiohttp tests; \ + fi + +.develop: .install-deps $(shell find aiohttp -type f) .flake check_changes mypy @pip install -e . @touch .develop @@ -112,7 +117,4 @@ install: @pip install -U pip @pip install -Ur requirements/dev.txt -mypy: - mypy aiohttp tests - .PHONY: all build flake test vtest cov clean doc diff --git a/aiohttp/__init__.py b/aiohttp/__init__.py index 27aca1f7c76..c7e5b34627b 100644 --- a/aiohttp/__init__.py +++ b/aiohttp/__init__.py @@ -4,6 +4,7 @@ from . import hdrs # noqa from .client import * # noqa +from .client import ClientSession, ServerFingerprintMismatch # noqa from .cookiejar import * # noqa from .formdata import * # noqa from .helpers import * # noqa @@ -21,7 +22,7 @@ from .worker import GunicornWebWorker, GunicornUVLoopWebWorker # noqa workers = ('GunicornWebWorker', 'GunicornUVLoopWebWorker') except ImportError: # pragma: no cover - workers = () + workers = () # type: ignore __all__ = (client.__all__ + # noqa diff --git a/aiohttp/_helpers.pyi b/aiohttp/_helpers.pyi new file mode 100644 index 00000000000..59608e15889 --- /dev/null +++ b/aiohttp/_helpers.pyi @@ -0,0 +1,8 @@ +from typing import Any + +class reify: + def __init__(self, wrapped: Any) -> None: ... + + def __get__(self, inst: Any, owner: Any) -> Any: ... + + def __set__(self, inst: Any, value: Any) -> None: ... diff --git a/aiohttp/_http_parser.pyx b/aiohttp/_http_parser.pyx index f16e4220c47..1bdacb53dc4 100644 --- a/aiohttp/_http_parser.pyx +++ b/aiohttp/_http_parser.pyx @@ -308,7 +308,7 @@ cdef class HttpParser: return messages, False, b'' -cdef class HttpRequestParserC(HttpParser): +cdef class HttpRequestParser(HttpParser): def __init__(self, protocol, loop, timer=None, size_t max_line_size=8190, size_t max_headers=32768, @@ -335,7 +335,7 @@ cdef class HttpRequestParserC(HttpParser): self._buf.clear() -cdef class HttpResponseParserC(HttpParser): +cdef class HttpResponseParser(HttpParser): def __init__(self, protocol, loop, timer=None, size_t max_line_size=8190, size_t max_headers=32768, diff --git a/aiohttp/client.py b/aiohttp/client.py index 28f8ef5125b..28d0f114266 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -92,7 +92,8 @@ def __init__(self, *, connector=None, loop=None, cookies=None, request_class=ClientRequest, response_class=ClientResponse, ws_response_class=ClientWebSocketResponse, version=http.HttpVersion11, - cookie_jar=None, connector_owner=True, raise_for_status=False, + cookie_jar=None, connector_owner=True, + raise_for_status=False, read_timeout=sentinel, conn_timeout=None, timeout=sentinel, auto_decompress=True, trust_env=False, diff --git a/aiohttp/client_exceptions.py b/aiohttp/client_exceptions.py index 9eacc4e318d..207aa13108a 100644 --- a/aiohttp/client_exceptions.py +++ b/aiohttp/client_exceptions.py @@ -7,7 +7,7 @@ try: import ssl except ImportError: # pragma: no cover - ssl = None + ssl = None # type: ignore __all__ = ( @@ -203,24 +203,24 @@ class ClientSSLError(ClientConnectorError): if ssl is not None: - certificate_errors = (ssl.CertificateError,) - certificate_errors_bases = (ClientSSLError, ssl.CertificateError,) + cert_errors = (ssl.CertificateError,) + cert_errors_bases = (ClientSSLError, ssl.CertificateError,) ssl_errors = (ssl.SSLError,) ssl_error_bases = (ClientSSLError, ssl.SSLError) else: # pragma: no cover - certificate_errors = tuple() - certificate_errors_bases = (ClientSSLError, ValueError,) + cert_errors = tuple() + cert_errors_bases = (ClientSSLError, ValueError,) ssl_errors = tuple() ssl_error_bases = (ClientSSLError,) -class ClientConnectorSSLError(*ssl_error_bases): +class ClientConnectorSSLError(*ssl_error_bases): # type: ignore """Response ssl error.""" -class ClientConnectorCertificateError(*certificate_errors_bases): +class ClientConnectorCertificateError(*cert_errors_bases): # type: ignore """Response certificate error.""" def __init__(self, connection_key, certificate_error): diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 6fe8f16ac63..6d1ff7bfc2f 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -27,7 +27,7 @@ try: import ssl except ImportError: # pragma: no cover - ssl = None + ssl = None # type: ignore try: import cchardet as chardet diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 24507c16027..aa4dfd7a467 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -18,7 +18,7 @@ ClientConnectorError, ClientConnectorSSLError, ClientHttpProxyError, ClientProxyConnectionError, - ServerFingerprintMismatch, certificate_errors, + ServerFingerprintMismatch, cert_errors, ssl_errors) from .client_proto import ResponseHandler from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params @@ -30,7 +30,7 @@ try: import ssl except ImportError: # pragma: no cover - ssl = None + ssl = None # type: ignore __all__ = ('BaseConnector', 'TCPConnector', 'UnixConnector') @@ -820,7 +820,7 @@ async def _wrap_create_connection(self, *args, try: with CeilTimeout(timeout.sock_connect): return await self._loop.create_connection(*args, **kwargs) - except certificate_errors as exc: + except cert_errors as exc: raise ClientConnectorCertificateError( req.connection_key, exc) from exc except ssl_errors as exc: diff --git a/aiohttp/frozenlist.py b/aiohttp/frozenlist.py index 59caad11f83..2aaea64739e 100644 --- a/aiohttp/frozenlist.py +++ b/aiohttp/frozenlist.py @@ -4,15 +4,8 @@ from .helpers import NO_EXTENSIONS -if not NO_EXTENSIONS: - try: - from aiohttp._frozenlist import FrozenList - except ImportError: # pragma: no cover - FrozenList = None - - @total_ordering -class PyFrozenList(MutableSequence): +class FrozenList(MutableSequence): __slots__ = ('_frozen', '_items') @@ -69,5 +62,11 @@ def __repr__(self): self._items) -if NO_EXTENSIONS or FrozenList is None: - FrozenList = PyFrozenList +PyFrozenList = FrozenList + +try: + from aiohttp._frozenlist import FrozenList as CFrozenList # type: ignore + if not NO_EXTENSIONS: + FrozenList = CFrozenList # type: ignore +except ImportError: # pragma: no cover + pass diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index 8670e568fb3..32927e4515e 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -14,10 +14,11 @@ import time import weakref from collections import namedtuple -from collections.abc import Mapping +from collections.abc import Mapping as ABCMapping from contextlib import suppress from math import ceil from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple # noqa from urllib.parse import quote from urllib.request import getproxies @@ -41,8 +42,8 @@ idna_ssl.patch_match_hostname() -sentinel = object() -NO_EXTENSIONS = bool(os.environ.get('AIOHTTP_NO_EXTENSIONS')) +sentinel = object() # type: Any +NO_EXTENSIONS = bool(os.environ.get('AIOHTTP_NO_EXTENSIONS')) # type: bool # N.B. sys.flags.dev_mode is available on Python 3.7+, use getattr # for compatibility with older versions @@ -59,10 +60,10 @@ coroutines = asyncio.coroutines -old_debug = coroutines._DEBUG +old_debug = coroutines._DEBUG # type: ignore # prevent "coroutine noop was never awaited" warning. -coroutines._DEBUG = False +coroutines._DEBUG = False # type: ignore @asyncio.coroutine @@ -70,13 +71,15 @@ def noop(*args, **kwargs): return -coroutines._DEBUG = old_debug +coroutines._DEBUG = old_debug # type: ignore class BasicAuth(namedtuple('BasicAuth', ['login', 'password', 'encoding'])): """Http basic authentication helper.""" - def __new__(cls, login, password='', encoding='latin1'): + def __new__(cls, login: str, + password: str='', + encoding: str='latin1') -> 'BasicAuth': if login is None: raise ValueError('None is not allowed as login value') @@ -90,7 +93,7 @@ def __new__(cls, login, password='', encoding='latin1'): return super().__new__(cls, login, password, encoding) @classmethod - def decode(cls, auth_header, encoding='latin1'): + def decode(cls, auth_header: str, encoding: str='latin1') -> 'BasicAuth': """Create a BasicAuth object from an Authorization HTTP header.""" split = auth_header.strip().split(' ') if len(split) == 2: @@ -110,7 +113,8 @@ def decode(cls, auth_header, encoding='latin1'): return cls(username, password, encoding=encoding) @classmethod - def from_url(cls, url, *, encoding='latin1'): + def from_url(cls, url: URL, + *, encoding: str='latin1') -> Optional['BasicAuth']: """Create BasicAuth from url.""" if not isinstance(url, URL): raise TypeError("url should be yarl.URL instance") @@ -118,13 +122,13 @@ def from_url(cls, url, *, encoding='latin1'): return None return cls(url.user, url.password or '', encoding=encoding) - def encode(self): + def encode(self) -> str: """Encode credentials.""" creds = ('%s:%s' % (self.login, self.password)).encode(self.encoding) return 'Basic %s' % base64.b64encode(creds).decode(self.encoding) -def strip_auth_from_url(url): +def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]: auth = BasicAuth.from_url(url) if auth is None: return url, None @@ -293,6 +297,9 @@ def content_disposition_header(disptype, quote_fields=True, **params): return value +KeyMethod = namedtuple('KeyMethod', 'key method') + + class AccessLogger(AbstractAccessLogger): """Helper object to log access. @@ -336,9 +343,7 @@ class AccessLogger(AbstractAccessLogger): LOG_FORMAT = '%a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i"' FORMAT_RE = re.compile(r'%(\{([A-Za-z0-9\-_]+)\}([ioe])|[atPrsbOD]|Tf?)') CLEANUP_RE = re.compile(r'(%[^s])') - _FORMAT_CACHE = {} - - KeyMethod = namedtuple('KeyMethod', 'key method') + _FORMAT_CACHE = {} # type: Dict[str, Tuple[str, List[KeyMethod]]] def __init__(self, logger, log_format=LOG_FORMAT): """Initialise the logger. @@ -390,7 +395,7 @@ def compile_format(self, log_format): m = getattr(AccessLogger, '_format_%s' % atom[2]) m = functools.partial(m, atom[1]) - methods.append(self.KeyMethod(format_key, m)) + methods.append(KeyMethod(format_key, m)) log_format = self.FORMAT_RE.sub(r'%s', log_format) log_format = self.CLEANUP_RE.sub(r'%\1', log_format) @@ -515,7 +520,7 @@ def __set__(self, inst, value): try: from ._helpers import reify as reify_c if not NO_EXTENSIONS: - reify = reify_c + reify = reify_c # type: ignore except ImportError: pass @@ -716,25 +721,25 @@ def _parse_content_type(self, raw): self._content_type, self._content_dict = cgi.parse_header(raw) @property - def content_type(self, *, _CONTENT_TYPE=hdrs.CONTENT_TYPE): + def content_type(self): """The value of content part for Content-Type HTTP header.""" - raw = self._headers.get(_CONTENT_TYPE) + raw = self._headers.get(hdrs.CONTENT_TYPE) if self._stored_content_type != raw: self._parse_content_type(raw) return self._content_type @property - def charset(self, *, _CONTENT_TYPE=hdrs.CONTENT_TYPE): + def charset(self): """The value of charset part for Content-Type HTTP header.""" - raw = self._headers.get(_CONTENT_TYPE) + raw = self._headers.get(hdrs.CONTENT_TYPE) if self._stored_content_type != raw: self._parse_content_type(raw) return self._content_dict.get('charset') @property - def content_length(self, *, _CONTENT_LENGTH=hdrs.CONTENT_LENGTH): + def content_length(self): """The value of Content-Length HTTP header.""" - content_length = self._headers.get(_CONTENT_LENGTH) + content_length = self._headers.get(hdrs.CONTENT_LENGTH) if content_length: return int(content_length) @@ -750,7 +755,7 @@ def set_exception(fut, exc): fut.set_exception(exc) -class ChainMapProxy(Mapping): +class ChainMapProxy(ABCMapping): __slots__ = ('_maps',) def __init__(self, maps): diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index c35fcf53fa5..3683129c90b 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -352,7 +352,7 @@ def parse_headers(self, lines): return headers, raw_headers, close_conn, encoding, upgrade, chunked -class HttpRequestParserPy(HttpParser): +class HttpRequestParser(HttpParser): """Read request status line. Exception .http_exceptions.BadStatusLine could be raised in case of any errors in status line. Returns RawRequestMessage. @@ -400,7 +400,7 @@ def parse_message(self, lines): close, compression, upgrade, chunked, URL(path)) -class HttpResponseParserPy(HttpParser): +class HttpResponseParser(HttpParser): """Read response status line and headers. BadStatusLine could be raised in case of any errors in status line. @@ -674,12 +674,13 @@ def end_http_chunk_receiving(self): self.out.end_http_chunk_receiving() -HttpRequestParser = HttpRequestParserPy -HttpResponseParser = HttpResponseParserPy +HttpRequestParserPy = HttpRequestParser +HttpResponseParserPy = HttpResponseParser try: - from ._http_parser import HttpRequestParserC, HttpResponseParserC + from ._http_parser import (HttpRequestParser as HttpRequestParserC, # type: ignore # noqa + HttpResponseParser as HttpResponseParserC) if not NO_EXTENSIONS: # pragma: no cover - HttpRequestParser = HttpRequestParserC - HttpResponseParser = HttpResponseParserC + HttpRequestParser = HttpRequestParserC # type: ignore + HttpResponseParser = HttpResponseParserC # type: ignore except ImportError: # pragma: no cover pass diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index 4c8983c744a..dd32b12b4f6 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -138,7 +138,7 @@ def _websocket_mask_python(mask, data): _websocket_mask = _websocket_mask_python else: try: - from ._websocket import _websocket_mask_cython + from ._websocket import _websocket_mask_cython # type: ignore _websocket_mask = _websocket_mask_cython except ImportError: # pragma: no cover _websocket_mask = _websocket_mask_python diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py index 4bf4a67aebb..53f8bd6c062 100644 --- a/aiohttp/http_writer.py +++ b/aiohttp/http_writer.py @@ -1,9 +1,12 @@ """Http related parsers and protocol.""" +import asyncio import collections import zlib +from typing import Any, Awaitable, Callable, Optional, Union # noqa from .abc import AbstractStreamWriter +from .base_protocol import BaseProtocol from .helpers import NO_EXTENSIONS @@ -14,9 +17,16 @@ HttpVersion11 = HttpVersion(1, 1) +_T_Data = Union[bytes, bytearray, memoryview] +_T_OnChunkSent = Optional[Callable[[_T_Data], Awaitable[None]]] + + class StreamWriter(AbstractStreamWriter): - def __init__(self, protocol, loop, on_chunk_sent=None): + def __init__(self, + protocol: BaseProtocol, + loop: asyncio.AbstractEventLoop, + on_chunk_sent: _T_OnChunkSent = None) -> None: self._protocol = protocol self._transport = protocol.transport @@ -27,28 +37,28 @@ def __init__(self, protocol, loop, on_chunk_sent=None): self.output_size = 0 self._eof = False - self._compress = None + self._compress = None # type: Any self._drain_waiter = None - self._on_chunk_sent = on_chunk_sent + self._on_chunk_sent = on_chunk_sent # type: _T_OnChunkSent @property - def transport(self): + def transport(self) -> asyncio.Transport: return self._transport @property - def protocol(self): + def protocol(self) -> BaseProtocol: return self._protocol - def enable_chunking(self): + def enable_chunking(self) -> None: self.chunked = True - def enable_compression(self, encoding='deflate'): + def enable_compression(self, encoding: str='deflate') -> None: zlib_mode = (16 + zlib.MAX_WBITS if encoding == 'gzip' else -zlib.MAX_WBITS) self._compress = zlib.compressobj(wbits=zlib_mode) - def _write(self, chunk): + def _write(self, chunk) -> None: size = len(chunk) self.buffer_size += size self.output_size += size @@ -57,7 +67,7 @@ def _write(self, chunk): raise ConnectionResetError('Cannot write to closing transport') self._transport.write(chunk) - async def write(self, chunk, *, drain=True, LIMIT=0x10000): + async def write(self, chunk, *, drain=True, LIMIT=0x10000) -> None: """Writes chunk of data to a stream. write_eof() indicates end of stream. @@ -93,13 +103,13 @@ async def write(self, chunk, *, drain=True, LIMIT=0x10000): self.buffer_size = 0 await self.drain() - async def write_headers(self, status_line, headers): + async def write_headers(self, status_line, headers) -> None: """Write request/response status and headers.""" # status + headers buf = _serialize_headers(status_line, headers) self._write(buf) - async def write_eof(self, chunk=b''): + async def write_eof(self, chunk=b'') -> None: if self._eof: return @@ -130,7 +140,7 @@ async def write_eof(self, chunk=b''): self._eof = True self._transport = None - async def drain(self): + async def drain(self) -> None: """Flush the write buffer. The intended use is to write @@ -151,7 +161,8 @@ def _py_serialize_headers(status_line, headers): _serialize_headers = _py_serialize_headers try: - from ._http_writer import _serialize_headers as _c_serialize_headers + import aiohttp._http_writer as _http_writer # type: ignore + _c_serialize_headers = _http_writer._serialize_headers if not NO_EXTENSIONS: # pragma: no cover _serialize_headers = _c_serialize_headers except ImportError: diff --git a/aiohttp/py.typed b/aiohttp/py.typed new file mode 100644 index 00000000000..20a74394fc0 --- /dev/null +++ b/aiohttp/py.typed @@ -0,0 +1 @@ +Marker \ No newline at end of file diff --git a/aiohttp/streams.py b/aiohttp/streams.py index ac243ca279c..6f9931c46b9 100644 --- a/aiohttp/streams.py +++ b/aiohttp/streams.py @@ -1,8 +1,11 @@ import asyncio import collections +from typing import List # noqa +from typing import Awaitable, Callable, Optional, Tuple from .helpers import set_exception, set_result from .log import internal_logger +from .typedefs import Byteish __all__ = ( @@ -18,13 +21,13 @@ class EofStream(Exception): class AsyncStreamIterator: - def __init__(self, read_func): + def __init__(self, read_func: Callable[[], Awaitable[bytes]]) -> None: self.read_func = read_func - def __aiter__(self): + def __aiter__(self) -> 'AsyncStreamIterator': return self - async def __anext__(self): + async def __anext__(self) -> bytes: try: rv = await self.read_func() except EofStream: @@ -35,7 +38,7 @@ async def __anext__(self): class ChunkTupleAsyncStreamIterator(AsyncStreamIterator): - async def __anext__(self): + async def __anext__(self) -> bytes: rv = await self.read_func() if rv == (b'', False): raise StopAsyncIteration # NOQA @@ -44,32 +47,32 @@ async def __anext__(self): class AsyncStreamReaderMixin: - def __aiter__(self): - return AsyncStreamIterator(self.readline) + def __aiter__(self) -> AsyncStreamIterator: + return AsyncStreamIterator(self.readline) # type: ignore - def iter_chunked(self, n): + def iter_chunked(self, n: int) -> AsyncStreamIterator: """Returns an asynchronous iterator that yields chunks of size n. Python-3.5 available for Python 3.5+ only """ - return AsyncStreamIterator(lambda: self.read(n)) + return AsyncStreamIterator(lambda: self.read(n)) # type: ignore - def iter_any(self): + def iter_any(self) -> AsyncStreamIterator: """Returns an asynchronous iterator that yields all the available data as soon as it is received Python-3.5 available for Python 3.5+ only """ - return AsyncStreamIterator(self.readany) + return AsyncStreamIterator(self.readany) # type: ignore - def iter_chunks(self): + def iter_chunks(self) -> ChunkTupleAsyncStreamIterator: """Returns an asynchronous iterator that yields chunks of data as they are received by the server. The yielded objects are tuples of (bytes, bool) as returned by the StreamReader.readchunk method. Python-3.5 available for Python 3.5+ only """ - return ChunkTupleAsyncStreamIterator(self.readchunk) + return ChunkTupleAsyncStreamIterator(self.readchunk) # type: ignore class StreamReader(AsyncStreamReaderMixin): @@ -122,10 +125,10 @@ def __repr__(self): info.append('e=%r' % self._exception) return '<%s>' % ' '.join(info) - def exception(self): + def exception(self) -> Optional[BaseException]: return self._exception - def set_exception(self, exc): + def set_exception(self, exc: BaseException) -> None: self._exception = exc self._eof_callbacks.clear() @@ -139,7 +142,7 @@ def set_exception(self, exc): set_exception(waiter, exc) self._eof_waiter = None - def on_eof(self, callback): + def on_eof(self, callback: Callable[[], None]) -> None: if self._eof: try: callback() @@ -148,7 +151,7 @@ def on_eof(self, callback): else: self._eof_callbacks.append(callback) - def feed_eof(self): + def feed_eof(self) -> None: self._eof = True waiter = self._waiter @@ -169,15 +172,15 @@ def feed_eof(self): self._eof_callbacks.clear() - def is_eof(self): + def is_eof(self) -> bool: """Return True if 'feed_eof' was called.""" return self._eof - def at_eof(self): + def at_eof(self) -> bool: """Return True if the buffer is empty and 'feed_eof' was called.""" return self._eof and not self._buffer - async def wait_eof(self): + async def wait_eof(self) -> None: if self._eof: return @@ -188,7 +191,7 @@ async def wait_eof(self): finally: self._eof_waiter = None - def unread_data(self, data): + def unread_data(self, data: Byteish) -> None: """ rollback reading some data from stream, inserting it to buffer head. """ if not data: @@ -203,7 +206,7 @@ def unread_data(self, data): self._eof_counter = 0 # TODO: size is ignored, remove the param later - def feed_data(self, data, size=0): + def feed_data(self, data: Byteish, size: int=0) -> None: assert not self._eof, 'feed_data after feed_eof' if not data: @@ -222,11 +225,11 @@ def feed_data(self, data, size=0): not self._protocol._reading_paused): self._protocol.pause_reading() - def begin_http_chunk_receiving(self): + def begin_http_chunk_receiving(self) -> None: if self._http_chunk_splits is None: self._http_chunk_splits = [] - def end_http_chunk_receiving(self): + def end_http_chunk_receiving(self) -> None: if self._http_chunk_splits is None: raise RuntimeError("Called end_chunk_receiving without calling " "begin_chunk_receiving first") @@ -253,7 +256,7 @@ async def _wait(self, func_name): finally: self._waiter = None - async def readline(self): + async def readline(self) -> bytes: if self._exception is not None: raise self._exception @@ -283,7 +286,7 @@ async def readline(self): return b''.join(line) - async def read(self, n=-1): + async def read(self, n: int=-1) -> bytes: if self._exception is not None: raise self._exception @@ -320,7 +323,7 @@ async def read(self, n=-1): return self._read_nowait(n) - async def readany(self): + async def readany(self) -> bytes: if self._exception is not None: raise self._exception @@ -329,7 +332,7 @@ async def readany(self): return self._read_nowait(-1) - async def readchunk(self): + async def readchunk(self) -> Tuple[bytes, bool]: """Returns a tuple of (data, end_of_http_chunk). When chunked transfer encoding is used, end_of_http_chunk is a boolean indicating if the end of the data corresponds to the end of a HTTP chunk , otherwise it is @@ -359,11 +362,11 @@ async def readchunk(self): else: return (self._read_nowait_chunk(-1), False) - async def readexactly(self, n): + async def readexactly(self, n: int) -> bytes: if self._exception is not None: raise self._exception - blocks = [] + blocks = [] # type: List[bytes] while n > 0: block = await self.read(n) if not block: @@ -375,7 +378,7 @@ async def readexactly(self, n): return b''.join(blocks) - def read_nowait(self, n=-1): + def read_nowait(self, n: int=-1) -> bytes: # default was changed to be consistent with .read(-1) # # I believe the most users don't know about the method and @@ -427,49 +430,49 @@ def _read_nowait(self, n): class EmptyStreamReader(AsyncStreamReaderMixin): - def exception(self): + def exception(self) -> Optional[BaseException]: return None - def set_exception(self, exc): + def set_exception(self, exc: BaseException) -> None: pass - def on_eof(self, callback): + def on_eof(self, callback: Callable[[], None]): try: callback() except Exception: internal_logger.exception('Exception in eof callback') - def feed_eof(self): + def feed_eof(self) -> None: pass - def is_eof(self): + def is_eof(self) -> bool: return True - def at_eof(self): + def at_eof(self) -> bool: return True - async def wait_eof(self): + async def wait_eof(self) -> None: return - def feed_data(self, data): + def feed_data(self, data: Byteish, n: int=0) -> None: pass - async def readline(self): + async def readline(self) -> bytes: return b'' - async def read(self, n=-1): + async def read(self, n: int=-1) -> bytes: return b'' - async def readany(self): + async def readany(self) -> bytes: return b'' - async def readchunk(self): + async def readchunk(self) -> Tuple[bytes, bool]: return (b'', False) - async def readexactly(self, n): + async def readexactly(self, n: int) -> bytes: raise asyncio.streams.IncompleteReadError(b'', n) - def read_nowait(self): + def read_nowait(self) -> bytes: return b'' @@ -487,19 +490,19 @@ def __init__(self, *, loop=None): self._size = 0 self._buffer = collections.deque() - def __len__(self): + def __len__(self) -> int: return len(self._buffer) - def is_eof(self): + def is_eof(self) -> bool: return self._eof - def at_eof(self): + def at_eof(self) -> bool: return self._eof and not self._buffer - def exception(self): + def exception(self) -> Optional[BaseException]: return self._exception - def set_exception(self, exc): + def set_exception(self, exc: BaseException) -> None: self._eof = True self._exception = exc @@ -508,7 +511,7 @@ def set_exception(self, exc): set_exception(waiter, exc) self._waiter = None - def feed_data(self, data, size=0): + def feed_data(self, data: Byteish, size: int=0) -> None: self._size += size self._buffer.append((data, size)) @@ -517,7 +520,7 @@ def feed_data(self, data, size=0): self._waiter = None set_result(waiter, True) - def feed_eof(self): + def feed_eof(self) -> None: self._eof = True waiter = self._waiter @@ -525,7 +528,7 @@ def feed_eof(self): self._waiter = None set_result(waiter, False) - async def read(self): + async def read(self) -> bytes: if not self._buffer and not self._eof: assert not self._waiter self._waiter = self._loop.create_future() @@ -545,7 +548,7 @@ async def read(self): else: raise EofStream - def __aiter__(self): + def __aiter__(self) -> AsyncStreamIterator: return AsyncStreamIterator(self.read) @@ -554,19 +557,20 @@ class FlowControlDataQueue(DataQueue): It is a destination for parsed data.""" - def __init__(self, protocol, *, limit=DEFAULT_LIMIT, loop=None): + def __init__(self, protocol, *, limit: int=DEFAULT_LIMIT, loop: + Optional[asyncio.AbstractEventLoop]=None) -> None: super().__init__(loop=loop) self._protocol = protocol self._limit = limit * 2 - def feed_data(self, data, size): + def feed_data(self, data: Byteish, size: int=0) -> None: super().feed_data(data, size) if self._size > self._limit and not self._protocol._reading_paused: self._protocol.pause_reading() - async def read(self): + async def read(self) -> bytes: try: return await super().read() finally: diff --git a/aiohttp/tcp_helpers.py b/aiohttp/tcp_helpers.py index 3a016901c9d..d703dc357a2 100644 --- a/aiohttp/tcp_helpers.py +++ b/aiohttp/tcp_helpers.py @@ -1,31 +1,34 @@ """Helper methods to tune a TCP connection""" +import asyncio import socket from contextlib import suppress +from typing import Optional # noqa __all__ = ('tcp_keepalive', 'tcp_nodelay', 'tcp_cork') if hasattr(socket, 'TCP_CORK'): # pragma: no cover - CORK = socket.TCP_CORK + CORK = socket.TCP_CORK # type: Optional[int] elif hasattr(socket, 'TCP_NOPUSH'): # pragma: no cover - CORK = socket.TCP_NOPUSH + CORK = socket.TCP_NOPUSH # type: ignore else: # pragma: no cover CORK = None if hasattr(socket, 'SO_KEEPALIVE'): - def tcp_keepalive(transport): + def tcp_keepalive(transport: asyncio.Transport) -> None: sock = transport.get_extra_info('socket') if sock is not None: sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) else: - def tcp_keepalive(transport): # pragma: no cover + def tcp_keepalive( + transport: asyncio.Transport) -> None: # pragma: no cover pass -def tcp_nodelay(transport, value): +def tcp_nodelay(transport: asyncio.Transport, value: bool) -> None: sock = transport.get_extra_info('socket') if sock is None: @@ -42,7 +45,7 @@ def tcp_nodelay(transport, value): socket.IPPROTO_TCP, socket.TCP_NODELAY, value) -def tcp_cork(transport, value): +def tcp_cork(transport: asyncio.Transport, value: bool) -> None: sock = transport.get_extra_info('socket') if CORK is None: diff --git a/aiohttp/typedefs.py b/aiohttp/typedefs.py new file mode 100644 index 00000000000..3a620e49e3f --- /dev/null +++ b/aiohttp/typedefs.py @@ -0,0 +1,12 @@ +from typing import Any, Callable, Mapping, Tuple, Union # noqa + +from multidict import CIMultiDict, CIMultiDictProxy +from yarl import URL + + +# type helpers +Byteish = Union[bytes, bytearray, memoryview] +JSONDecoder = Callable[[str], Any] +LooseHeaders = Union[Mapping, CIMultiDict, CIMultiDictProxy] +RawHeaders = Tuple[Tuple[bytes, bytes], ...] +StrOrURL = Union[str, URL] diff --git a/aiohttp/web_exceptions.py b/aiohttp/web_exceptions.py index 2127c90614b..4d948f4ee55 100644 --- a/aiohttp/web_exceptions.py +++ b/aiohttp/web_exceptions.py @@ -71,7 +71,7 @@ class HTTPException(Response, Exception): # You should set in subclasses: # status = 200 - status_code = None + status_code = -1 empty_body = False def __init__(self, *, headers=None, reason=None, diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index 6c0c454d9ad..78213665bf3 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -1,6 +1,5 @@ import asyncio import asyncio.streams -import http.server import traceback import warnings from collections import deque @@ -9,10 +8,11 @@ import yarl -from . import helpers, http +from . import helpers from .base_protocol import BaseProtocol from .helpers import CeilTimeout -from .http import HttpProcessingError, HttpRequestParser, StreamWriter +from .http import (HttpProcessingError, HttpRequestParser, HttpVersion10, + RawRequestMessage, StreamWriter) from .log import access_logger, server_logger from .streams import EMPTY_PAYLOAD from .tcp_helpers import tcp_cork, tcp_keepalive, tcp_nodelay @@ -23,8 +23,8 @@ __all__ = ('RequestHandler', 'RequestPayloadError', 'PayloadAccessError') -ERROR = http.RawRequestMessage( - 'UNKNOWN', '/', http.HttpVersion10, {}, +ERROR = RawRequestMessage( + 'UNKNOWN', '/', HttpVersion10, {}, {}, True, False, False, False, yarl.URL('/')) diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index 907d3b00e31..09b518e4f0f 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -12,6 +12,7 @@ from email.utils import parsedate from http.cookies import SimpleCookie from types import MappingProxyType +from typing import Any, Dict, Mapping, Optional, Tuple, cast # noqa from urllib.parse import parse_qsl import attr @@ -20,13 +21,18 @@ from . import hdrs, multipart from .helpers import DEBUG, ChainMapProxy, HeadersMixin, reify, sentinel -from .streams import EmptyStreamReader +from .streams import EmptyStreamReader, StreamReader +from .typedefs import JSONDecoder, LooseHeaders, RawHeaders, StrOrURL from .web_exceptions import HTTPRequestEntityTooLarge +from .web_urldispatcher import UrlMappingMatchInfo __all__ = ('BaseRequest', 'FileField', 'Request') +DEFAULT_JSON_DECODER = json.loads + + @attr.s(frozen=True, slots=True) class FileField: name = attr.ib(type=str) @@ -108,9 +114,10 @@ def __init__(self, message, payload, protocol, payload_writer, task, if remote is not None: self._cache['remote'] = remote - def clone(self, *, method=sentinel, rel_url=sentinel, - headers=sentinel, scheme=sentinel, host=sentinel, - remote=sentinel): + def clone(self, *, method: str=sentinel, rel_url: StrOrURL=sentinel, + headers: LooseHeaders=sentinel, scheme: str=sentinel, + host: str=sentinel, + remote: str=sentinel) -> 'BaseRequest': """Clone itself with replacement some attributes. Creates and returns a new instance of Request object. If no parameters @@ -123,13 +130,13 @@ def clone(self, *, method=sentinel, rel_url=sentinel, raise RuntimeError("Cannot clone request " "after reading it's content") - dct = {} + dct = {} # type: Dict[str, Any] if method is not sentinel: dct['method'] = method if rel_url is not sentinel: - rel_url = URL(rel_url) - dct['url'] = rel_url - dct['path'] = str(rel_url) + new_url = URL(rel_url) + dct['url'] = new_url + dct['path'] = str(new_url) if headers is not sentinel: # a copy semantic dct['headers'] = CIMultiDictProxy(CIMultiDict(headers)) @@ -183,11 +190,11 @@ def message(self): return self._message @reify - def rel_url(self): + def rel_url(self) -> URL: return self._rel_url @reify - def loop(self): + def loop(self) -> asyncio.AbstractEventLoop: return self._loop # MutableMapping API @@ -210,7 +217,7 @@ def __iter__(self): ######## @reify - def secure(self): + def secure(self) -> bool: """A bool indicating if the request is handled with SSL.""" return self.scheme == 'https' @@ -277,7 +284,7 @@ def forwarded(self): return tuple(elems) @reify - def scheme(self): + def scheme(self) -> str: """A string representing the scheme of the request. Hostname is resolved in this order: @@ -293,7 +300,7 @@ def scheme(self): return 'http' @reify - def method(self): + def method(self) -> str: """Read only property for getting HTTP method. The value is upper-cased str like 'GET', 'POST', 'PUT' etc. @@ -301,7 +308,7 @@ def method(self): return self._method @reify - def version(self): + def version(self) -> Tuple[int, int]: """Read only property for getting HTTP version of request. Returns aiohttp.protocol.HttpVersion instance. @@ -309,7 +316,7 @@ def version(self): return self._version @reify - def host(self): + def host(self) -> str: """Hostname of the request. Hostname is resolved in this order: @@ -325,7 +332,7 @@ def host(self): return socket.getfqdn() @reify - def remote(self): + def remote(self) -> Optional[str]: """Remote IP of client initiated HTTP request. The IP is resolved in this order: @@ -342,12 +349,12 @@ def remote(self): return peername @reify - def url(self): + def url(self) -> URL: url = URL.build(scheme=self.scheme, host=self.host) return url.join(self._rel_url) @reify - def path(self): + def path(self) -> str: """The URL including *PATH INFO* without the host or scheme. E.g., ``/app/blog`` @@ -355,7 +362,7 @@ def path(self): return self._rel_url.path @reify - def path_qs(self): + def path_qs(self) -> str: """The URL including PATH_INFO and the query string. E.g, /app/blog?id=10 @@ -363,7 +370,7 @@ def path_qs(self): return str(self._rel_url) @reify - def raw_path(self): + def raw_path(self) -> str: """ The URL including raw *PATH INFO* without the host or scheme. Warning, the path is unquoted and may contains non valid URL characters @@ -372,12 +379,12 @@ def raw_path(self): return self._message.path @reify - def query(self): + def query(self) -> MultiDict: """A multidict with all the variables in the query string.""" return self._rel_url.query @reify - def query_string(self): + def query_string(self) -> str: """The query string in the URL. E.g., id=10 @@ -385,17 +392,17 @@ def query_string(self): return self._rel_url.query_string @reify - def headers(self): + def headers(self) -> CIMultiDictProxy: """A case-insensitive multidict proxy with all headers.""" return self._headers @reify - def raw_headers(self): + def raw_headers(self) -> RawHeaders: """A sequence of pars for all headers.""" return self._message.raw_headers @staticmethod - def _http_date(_date_str): + def _http_date(_date_str) -> Optional[datetime.datetime]: """Process a date string, return a datetime object """ if _date_str is not None: @@ -406,54 +413,53 @@ def _http_date(_date_str): return None @reify - def if_modified_since(self, _IF_MODIFIED_SINCE=hdrs.IF_MODIFIED_SINCE): + def if_modified_since(self) -> Optional[datetime.datetime]: """The value of If-Modified-Since HTTP header, or None. This header is represented as a `datetime` object. """ - return self._http_date(self.headers.get(_IF_MODIFIED_SINCE)) + return self._http_date(self.headers.get(hdrs.IF_MODIFIED_SINCE)) @reify - def if_unmodified_since(self, - _IF_UNMODIFIED_SINCE=hdrs.IF_UNMODIFIED_SINCE): + def if_unmodified_since(self) -> Optional[datetime.datetime]: """The value of If-Unmodified-Since HTTP header, or None. This header is represented as a `datetime` object. """ - return self._http_date(self.headers.get(_IF_UNMODIFIED_SINCE)) + return self._http_date(self.headers.get(hdrs.IF_UNMODIFIED_SINCE)) @reify - def if_range(self, _IF_RANGE=hdrs.IF_RANGE): + def if_range(self) -> Optional[datetime.datetime]: """The value of If-Range HTTP header, or None. This header is represented as a `datetime` object. """ - return self._http_date(self.headers.get(_IF_RANGE)) + return self._http_date(self.headers.get(hdrs.IF_RANGE)) @reify - def keep_alive(self): + def keep_alive(self) -> bool: """Is keepalive enabled by client?""" return not self._message.should_close @reify - def cookies(self): + def cookies(self) -> Mapping[str, str]: """Return request cookies. A read-only dictionary-like object. """ raw = self.headers.get(hdrs.COOKIE, '') - parsed = SimpleCookie(raw) + parsed = SimpleCookie(raw) # type: ignore return MappingProxyType( {key: val.value for key, val in parsed.items()}) @reify - def http_range(self, *, _RANGE=hdrs.RANGE): + def http_range(self): """The content of Range HTTP header. Return a slice instance. """ - rng = self._headers.get(_RANGE) + rng = self._headers.get(hdrs.RANGE) start, end = None, None if rng is not None: try: @@ -483,12 +489,12 @@ def http_range(self, *, _RANGE=hdrs.RANGE): return slice(start, end, 1) @reify - def content(self): + def content(self) -> StreamReader: """Return raw payload stream.""" return self._payload @property - def has_body(self): + def has_body(self) -> bool: """Return True if request's HTTP BODY can be read, False otherwise.""" warnings.warn( "Deprecated, use .can_read_body #2005", @@ -496,16 +502,16 @@ def has_body(self): return not self._payload.at_eof() @property - def can_read_body(self): + def can_read_body(self) -> bool: """Return True if request's HTTP BODY can be read, False otherwise.""" return not self._payload.at_eof() @reify - def body_exists(self): + def body_exists(self) -> bool: """Return True if request has HTTP BODY, False otherwise.""" return type(self._payload) is not EmptyStreamReader - async def release(self): + async def release(self) -> None: """Release request. Eat unread part of HTTP BODY if present. @@ -513,7 +519,7 @@ async def release(self): while not self._payload.at_eof(): await self._payload.readany() - async def read(self): + async def read(self) -> bytes: """Read request body if present. Returns bytes object with full request content. @@ -531,13 +537,13 @@ async def read(self): self._read_bytes = bytes(body) return self._read_bytes - async def text(self): + async def text(self) -> str: """Return BODY as text using encoding from .charset.""" bytes_body = await self.read() encoding = self.charset or 'utf-8' return bytes_body.decode(encoding) - async def json(self, *, loads=json.loads): + async def json(self, *, loads: JSONDecoder=DEFAULT_JSON_DECODER) -> Any: """Return BODY as JSON.""" body = await self.text() return loads(body) @@ -546,7 +552,7 @@ async def multipart(self, *, reader=multipart.MultipartReader): """Return async iterator to process BODY as multipart.""" return reader(self._headers, self._payload) - async def post(self): + async def post(self) -> MultiDictProxy: """Return POST parameters.""" if self._post is not None: return self._post @@ -561,7 +567,7 @@ async def post(self): self._post = MultiDictProxy(MultiDict()) return self._post - out = MultiDict() + out = MultiDict() # type: MultiDict if content_type == 'multipart/form-data': multipart = await self.multipart() @@ -587,7 +593,8 @@ async def post(self): tmp.seek(0) ff = FileField(field.name, field.filename, - tmp, content_type, field.headers) + cast(io.BufferedReader, tmp), + content_type, field.headers) out.add(field.name, ff) else: value = await field.read(decode=True) @@ -631,12 +638,14 @@ class Request(BaseRequest): ATTRS = BaseRequest.ATTRS | frozenset(['_match_info']) - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) # matchdict, route_name, handler # or information about traversal lookup - self._match_info = None # initialized after route resolving + + # initialized after route resolving + self._match_info = None # type: Optional[UrlMappingMatchInfo] if DEBUG: def __setattr__(self, name, val): @@ -648,31 +657,39 @@ def __setattr__(self, name, val): stacklevel=2) super().__setattr__(name, val) - def clone(self, *, method=sentinel, rel_url=sentinel, - headers=sentinel, scheme=sentinel, host=sentinel, - remote=sentinel): + def clone(self, *, method: str=sentinel, rel_url: + StrOrURL=sentinel, headers: LooseHeaders=sentinel, + scheme: str=sentinel, host: str=sentinel, remote: + str=sentinel) -> 'Request': ret = super().clone(method=method, rel_url=rel_url, headers=headers, scheme=scheme, host=host, remote=remote) - ret._match_info = self._match_info - return ret + new_ret = cast(Request, ret) + new_ret._match_info = self._match_info + return new_ret @reify - def match_info(self): + def match_info(self) -> Optional[UrlMappingMatchInfo]: """Result of route resolving.""" return self._match_info @property def app(self): """Application instance.""" - return self._match_info.current_app + match_info = self._match_info + if match_info is None: + return None + return match_info.current_app @property - def config_dict(self): - lst = self._match_info.apps + def config_dict(self) -> ChainMapProxy: + match_info = self._match_info + if match_info is None: + return ChainMapProxy([]) + lst = match_info.apps app = self.app idx = lst.index(app) sublist = list(reversed(lst[:idx + 1])) diff --git a/aiohttp/web_response.py b/aiohttp/web_response.py index ad9ea0c2a51..f3c2a5311a8 100644 --- a/aiohttp/web_response.py +++ b/aiohttp/web_response.py @@ -237,12 +237,12 @@ def charset(self, value): self._generate_content_type_header() @property - def last_modified(self, _LAST_MODIFIED=hdrs.LAST_MODIFIED): + def last_modified(self): """The value of Last-Modified HTTP header, or None. This header is represented as a `datetime` object. """ - httpdate = self.headers.get(_LAST_MODIFIED) + httpdate = self.headers.get(hdrs.LAST_MODIFIED) if httpdate is not None: timetuple = parsedate(httpdate) if timetuple is not None: diff --git a/aiohttp/worker.py b/aiohttp/worker.py index 0f70d262472..a4b32c48752 100644 --- a/aiohttp/worker.py +++ b/aiohttp/worker.py @@ -18,7 +18,7 @@ try: import ssl except ImportError: # pragma: no cover - ssl = None + ssl = None # type: ignore __all__ = ('GunicornWebWorker', diff --git a/setup.cfg b/setup.cfg index 50c48cca1c5..82c35206470 100644 --- a/setup.cfg +++ b/setup.cfg @@ -49,3 +49,30 @@ ignore_missing_imports = true [mypy-async_generator] ignore_missing_imports = true + + +[mypy-aiodns] +ignore_missing_imports = true + + +[mypy-gunicorn.config] +ignore_missing_imports = true + +[mypy-gunicorn.workers] +ignore_missing_imports = true + + +[mypy-brotli] +ignore_missing_imports = true + + +[mypy-chardet] +ignore_missing_imports = true + + +[mypy-cchardet] +ignore_missing_imports = true + + +[mypy-idna_ssl] +ignore_missing_imports = true diff --git a/tests/conftest.py b/tests/conftest.py index ba454ada3d5..763a254ae05 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,8 @@ +import pathlib +import shutil import tempfile import pytest -from py import path pytest_plugins = ['aiohttp.pytest_plugin', 'pytester'] @@ -12,6 +13,6 @@ def shorttmpdir(): """Provides a temporary directory with a shorter file system path than the tmpdir fixture. """ - tmpdir = path.local(tempfile.mkdtemp()) + tmpdir = pathlib.Path(tempfile.mkdtemp()) yield tmpdir - tmpdir.remove(rec=1) + shutil.rmtree(tmpdir, ignore_errors=True) diff --git a/tests/test_frozenlist.py b/tests/test_frozenlist.py index 1654788cdea..289f3e35f07 100644 --- a/tests/test_frozenlist.py +++ b/tests/test_frozenlist.py @@ -6,7 +6,7 @@ class FrozenListMixin: - FrozenList = None + FrozenList = NotImplemented SKIP_METHODS = {'__abstractmethods__', '__slots__'} diff --git a/tests/test_helpers.py b/tests/test_helpers.py index c3014d004c6..a20a3bb9daf 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -7,6 +7,7 @@ from unittest import mock import pytest +from multidict import MultiDict from yarl import URL from aiohttp import helpers @@ -19,25 +20,23 @@ # ------------------- parse_mimetype ---------------------------------- @pytest.mark.parametrize('mimetype, expected', [ - ('', helpers.MimeType('', '', '', {})), - ('*', helpers.MimeType('*', '*', '', {})), - ('application/json', helpers.MimeType('application', 'json', '', {})), - ( - 'application/json; charset=utf-8', - helpers.MimeType('application', 'json', '', {'charset': 'utf-8'}) - ), - ( - '''application/json; charset=utf-8;''', - helpers.MimeType('application', 'json', '', {'charset': 'utf-8'}) - ), - ( - 'ApPlIcAtIoN/JSON;ChaRseT="UTF-8"', - helpers.MimeType('application', 'json', '', {'charset': 'UTF-8'}) - ), + ('', helpers.MimeType('', '', '', MultiDict())), + ('*', helpers.MimeType('*', '*', '', MultiDict())), + ('application/json', + helpers.MimeType('application', 'json', '', MultiDict())), + ('application/json; charset=utf-8', + helpers.MimeType('application', 'json', '', + MultiDict({'charset': 'utf-8'}))), + ('''application/json; charset=utf-8;''', + helpers.MimeType('application', 'json', '', + MultiDict({'charset': 'utf-8'}))), + ('ApPlIcAtIoN/JSON;ChaRseT="UTF-8"', + helpers.MimeType('application', 'json', '', + MultiDict({'charset': 'UTF-8'}))), ('application/rss+xml', - helpers.MimeType('application', 'rss', 'xml', {})), + helpers.MimeType('application', 'rss', 'xml', MultiDict())), ('text/plain;base64', - helpers.MimeType('text', 'plain', '', {'base64': ''})) + helpers.MimeType('text', 'plain', '', MultiDict({'base64': ''}))) ]) def test_parse_mimetype(mimetype, expected): result = helpers.parse_mimetype(mimetype) diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py index c5f25871b68..5422dd2390d 100644 --- a/tests/test_http_parser.py +++ b/tests/test_http_parser.py @@ -24,8 +24,8 @@ try: from aiohttp import _http_parser - REQUEST_PARSERS.append(_http_parser.HttpRequestParserC) - RESPONSE_PARSERS.append(_http_parser.HttpResponseParserC) + REQUEST_PARSERS.append(_http_parser.HttpRequestParser) + RESPONSE_PARSERS.append(_http_parser.HttpResponseParser) except ImportError: # pragma: no cover pass diff --git a/tests/test_run_app.py b/tests/test_run_app.py index d25fbe09a75..ce95fd8bf73 100644 --- a/tests/test_run_app.py +++ b/tests/test_run_app.py @@ -21,7 +21,7 @@ if _has_unix_domain_socks: _abstract_path_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) try: - _abstract_path_sock.bind(b"\x00" + uuid4().hex.encode('ascii')) + _abstract_path_sock.bind(b"\x00" + uuid4().hex.encode('ascii')) # type: ignore # noqa except FileNotFoundError: _abstract_path_failed = True else: @@ -137,7 +137,7 @@ def test_run_app_close_loop(patched_loop): ] mock_socket = mock.Mock(getsockname=lambda: ('mock-socket', 123)) mixed_bindings_tests = ( - ( + ( # type: ignore "Nothing Specified", {}, [mock.call(mock.ANY, '0.0.0.0', 8080, ssl=None, backlog=128, @@ -354,7 +354,7 @@ def test_run_app_custom_backlog_unix(patched_loop): def test_run_app_http_unix_socket(patched_loop, shorttmpdir): app = web.Application() - sock_path = str(shorttmpdir.join('socket.sock')) + sock_path = str(shorttmpdir / 'socket.sock') printer = mock.Mock(wraps=stopper(patched_loop)) web.run_app(app, path=sock_path, print=printer) @@ -367,7 +367,7 @@ def test_run_app_http_unix_socket(patched_loop, shorttmpdir): def test_run_app_https_unix_socket(patched_loop, shorttmpdir): app = web.Application() - sock_path = str(shorttmpdir.join('socket.sock')) + sock_path = str(shorttmpdir / 'socket.sock') ssl_context = ssl.create_default_context() printer = mock.Mock(wraps=stopper(patched_loop)) web.run_app(app, path=sock_path, ssl_context=ssl_context, print=printer) diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index b0545836e2b..0e973bb4df0 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -20,7 +20,7 @@ try: import ssl except ImportError: - ssl = False + ssl = None # type: ignore @pytest.fixture diff --git a/tests/test_web_sendfile_functional.py b/tests/test_web_sendfile_functional.py index 203d44dfdf2..5ce8890a553 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -12,7 +12,7 @@ try: import ssl except ImportError: - ssl = False + ssl = None # type: ignore @pytest.fixture(params=['sendfile', 'fallback'], ids=['sendfile', 'fallback']) diff --git a/tests/test_worker.py b/tests/test_worker.py index be268aaf27e..d385d03e1fb 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -43,13 +43,13 @@ def __init__(self): self.wsgi = web.Application() -class AsyncioWorker(BaseTestWorker, base_worker.GunicornWebWorker): +class AsyncioWorker(BaseTestWorker, base_worker.GunicornWebWorker): # type: ignore # noqa pass PARAMS = [AsyncioWorker] if uvloop is not None: - class UvloopWorker(BaseTestWorker, base_worker.GunicornUVLoopWebWorker): + class UvloopWorker(BaseTestWorker, base_worker.GunicornUVLoopWebWorker): # type: ignore # noqa pass PARAMS.append(UvloopWorker)