Skip to content

Commit

Permalink
Type annotations (aio-libs#3049)
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov authored Jun 18, 2018
1 parent 17f5a3d commit 7cf5785
Show file tree
Hide file tree
Showing 33 changed files with 328 additions and 234 deletions.
1 change: 1 addition & 0 deletions CHANGES/3049.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ graft docs
graft examples
graft tests
recursive-include vendor *
global-include aiohttp *.pyi
global-exclude *.pyc
global-exclude *.pyd
global-exclude *.so
Expand Down
10 changes: 6 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ flake: .flake
check_changes:
@./tools/check_changes.py

.develop: .install-deps $(shell find aiohttp -type f) .flake check_changes
mypy: .flake
if python -c "import sys; sys.exit(sys.implementation.name!='cpython')"; then \
mypy aiohttp tests; \
fi

.develop: .install-deps $(shell find aiohttp -type f) .flake check_changes mypy
@pip install -e .
@touch .develop

Expand Down Expand Up @@ -112,7 +117,4 @@ install:
@pip install -U pip
@pip install -Ur requirements/dev.txt

mypy:
mypy aiohttp tests

.PHONY: all build flake test vtest cov clean doc
3 changes: 2 additions & 1 deletion aiohttp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from . import hdrs # noqa
from .client import * # noqa
from .client import ClientSession, ServerFingerprintMismatch # noqa
from .cookiejar import * # noqa
from .formdata import * # noqa
from .helpers import * # noqa
Expand All @@ -21,7 +22,7 @@
from .worker import GunicornWebWorker, GunicornUVLoopWebWorker # noqa
workers = ('GunicornWebWorker', 'GunicornUVLoopWebWorker')
except ImportError: # pragma: no cover
workers = ()
workers = () # type: ignore


__all__ = (client.__all__ + # noqa
Expand Down
8 changes: 8 additions & 0 deletions aiohttp/_helpers.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from typing import Any

class reify:
def __init__(self, wrapped: Any) -> None: ...

def __get__(self, inst: Any, owner: Any) -> Any: ...

def __set__(self, inst: Any, value: Any) -> None: ...
4 changes: 2 additions & 2 deletions aiohttp/_http_parser.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ cdef class HttpParser:
return messages, False, b''


cdef class HttpRequestParserC(HttpParser):
cdef class HttpRequestParser(HttpParser):

def __init__(self, protocol, loop, timer=None,
size_t max_line_size=8190, size_t max_headers=32768,
Expand All @@ -335,7 +335,7 @@ cdef class HttpRequestParserC(HttpParser):
self._buf.clear()


cdef class HttpResponseParserC(HttpParser):
cdef class HttpResponseParser(HttpParser):

def __init__(self, protocol, loop, timer=None,
size_t max_line_size=8190, size_t max_headers=32768,
Expand Down
3 changes: 2 additions & 1 deletion aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ def __init__(self, *, connector=None, loop=None, cookies=None,
request_class=ClientRequest, response_class=ClientResponse,
ws_response_class=ClientWebSocketResponse,
version=http.HttpVersion11,
cookie_jar=None, connector_owner=True, raise_for_status=False,
cookie_jar=None, connector_owner=True,
raise_for_status=False,
read_timeout=sentinel, conn_timeout=None,
timeout=sentinel,
auto_decompress=True, trust_env=False,
Expand Down
14 changes: 7 additions & 7 deletions aiohttp/client_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
try:
import ssl
except ImportError: # pragma: no cover
ssl = None
ssl = None # type: ignore


__all__ = (
Expand Down Expand Up @@ -203,24 +203,24 @@ class ClientSSLError(ClientConnectorError):


if ssl is not None:
certificate_errors = (ssl.CertificateError,)
certificate_errors_bases = (ClientSSLError, ssl.CertificateError,)
cert_errors = (ssl.CertificateError,)
cert_errors_bases = (ClientSSLError, ssl.CertificateError,)

ssl_errors = (ssl.SSLError,)
ssl_error_bases = (ClientSSLError, ssl.SSLError)
else: # pragma: no cover
certificate_errors = tuple()
certificate_errors_bases = (ClientSSLError, ValueError,)
cert_errors = tuple()
cert_errors_bases = (ClientSSLError, ValueError,)

ssl_errors = tuple()
ssl_error_bases = (ClientSSLError,)


class ClientConnectorSSLError(*ssl_error_bases):
class ClientConnectorSSLError(*ssl_error_bases): # type: ignore
"""Response ssl error."""


class ClientConnectorCertificateError(*certificate_errors_bases):
class ClientConnectorCertificateError(*cert_errors_bases): # type: ignore
"""Response certificate error."""

def __init__(self, connection_key, certificate_error):
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
try:
import ssl
except ImportError: # pragma: no cover
ssl = None
ssl = None # type: ignore

try:
import cchardet as chardet
Expand Down
6 changes: 3 additions & 3 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
ClientConnectorError, ClientConnectorSSLError,
ClientHttpProxyError,
ClientProxyConnectionError,
ServerFingerprintMismatch, certificate_errors,
ServerFingerprintMismatch, cert_errors,
ssl_errors)
from .client_proto import ResponseHandler
from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params
Expand All @@ -30,7 +30,7 @@
try:
import ssl
except ImportError: # pragma: no cover
ssl = None
ssl = None # type: ignore


__all__ = ('BaseConnector', 'TCPConnector', 'UnixConnector')
Expand Down Expand Up @@ -820,7 +820,7 @@ async def _wrap_create_connection(self, *args,
try:
with CeilTimeout(timeout.sock_connect):
return await self._loop.create_connection(*args, **kwargs)
except certificate_errors as exc:
except cert_errors as exc:
raise ClientConnectorCertificateError(
req.connection_key, exc) from exc
except ssl_errors as exc:
Expand Down
19 changes: 9 additions & 10 deletions aiohttp/frozenlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,8 @@
from .helpers import NO_EXTENSIONS


if not NO_EXTENSIONS:
try:
from aiohttp._frozenlist import FrozenList
except ImportError: # pragma: no cover
FrozenList = None


@total_ordering
class PyFrozenList(MutableSequence):
class FrozenList(MutableSequence):

__slots__ = ('_frozen', '_items')

Expand Down Expand Up @@ -69,5 +62,11 @@ def __repr__(self):
self._items)


if NO_EXTENSIONS or FrozenList is None:
FrozenList = PyFrozenList
PyFrozenList = FrozenList

try:
from aiohttp._frozenlist import FrozenList as CFrozenList # type: ignore
if not NO_EXTENSIONS:
FrozenList = CFrozenList # type: ignore
except ImportError: # pragma: no cover
pass
51 changes: 28 additions & 23 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
import time
import weakref
from collections import namedtuple
from collections.abc import Mapping
from collections.abc import Mapping as ABCMapping
from contextlib import suppress
from math import ceil
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple # noqa
from urllib.parse import quote
from urllib.request import getproxies

Expand All @@ -41,8 +42,8 @@
idna_ssl.patch_match_hostname()


sentinel = object()
NO_EXTENSIONS = bool(os.environ.get('AIOHTTP_NO_EXTENSIONS'))
sentinel = object() # type: Any
NO_EXTENSIONS = bool(os.environ.get('AIOHTTP_NO_EXTENSIONS')) # type: bool

# N.B. sys.flags.dev_mode is available on Python 3.7+, use getattr
# for compatibility with older versions
Expand All @@ -59,24 +60,26 @@


coroutines = asyncio.coroutines
old_debug = coroutines._DEBUG
old_debug = coroutines._DEBUG # type: ignore

# prevent "coroutine noop was never awaited" warning.
coroutines._DEBUG = False
coroutines._DEBUG = False # type: ignore


@asyncio.coroutine
def noop(*args, **kwargs):
return


coroutines._DEBUG = old_debug
coroutines._DEBUG = old_debug # type: ignore


class BasicAuth(namedtuple('BasicAuth', ['login', 'password', 'encoding'])):
"""Http basic authentication helper."""

def __new__(cls, login, password='', encoding='latin1'):
def __new__(cls, login: str,
password: str='',
encoding: str='latin1') -> 'BasicAuth':
if login is None:
raise ValueError('None is not allowed as login value')

Expand All @@ -90,7 +93,7 @@ def __new__(cls, login, password='', encoding='latin1'):
return super().__new__(cls, login, password, encoding)

@classmethod
def decode(cls, auth_header, encoding='latin1'):
def decode(cls, auth_header: str, encoding: str='latin1') -> 'BasicAuth':
"""Create a BasicAuth object from an Authorization HTTP header."""
split = auth_header.strip().split(' ')
if len(split) == 2:
Expand All @@ -110,21 +113,22 @@ def decode(cls, auth_header, encoding='latin1'):
return cls(username, password, encoding=encoding)

@classmethod
def from_url(cls, url, *, encoding='latin1'):
def from_url(cls, url: URL,
*, encoding: str='latin1') -> Optional['BasicAuth']:
"""Create BasicAuth from url."""
if not isinstance(url, URL):
raise TypeError("url should be yarl.URL instance")
if url.user is None:
return None
return cls(url.user, url.password or '', encoding=encoding)

def encode(self):
def encode(self) -> str:
"""Encode credentials."""
creds = ('%s:%s' % (self.login, self.password)).encode(self.encoding)
return 'Basic %s' % base64.b64encode(creds).decode(self.encoding)


def strip_auth_from_url(url):
def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
auth = BasicAuth.from_url(url)
if auth is None:
return url, None
Expand Down Expand Up @@ -293,6 +297,9 @@ def content_disposition_header(disptype, quote_fields=True, **params):
return value


KeyMethod = namedtuple('KeyMethod', 'key method')


class AccessLogger(AbstractAccessLogger):
"""Helper object to log access.
Expand Down Expand Up @@ -336,9 +343,7 @@ class AccessLogger(AbstractAccessLogger):
LOG_FORMAT = '%a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i"'
FORMAT_RE = re.compile(r'%(\{([A-Za-z0-9\-_]+)\}([ioe])|[atPrsbOD]|Tf?)')
CLEANUP_RE = re.compile(r'(%[^s])')
_FORMAT_CACHE = {}

KeyMethod = namedtuple('KeyMethod', 'key method')
_FORMAT_CACHE = {} # type: Dict[str, Tuple[str, List[KeyMethod]]]

def __init__(self, logger, log_format=LOG_FORMAT):
"""Initialise the logger.
Expand Down Expand Up @@ -390,7 +395,7 @@ def compile_format(self, log_format):
m = getattr(AccessLogger, '_format_%s' % atom[2])
m = functools.partial(m, atom[1])

methods.append(self.KeyMethod(format_key, m))
methods.append(KeyMethod(format_key, m))

log_format = self.FORMAT_RE.sub(r'%s', log_format)
log_format = self.CLEANUP_RE.sub(r'%\1', log_format)
Expand Down Expand Up @@ -515,7 +520,7 @@ def __set__(self, inst, value):
try:
from ._helpers import reify as reify_c
if not NO_EXTENSIONS:
reify = reify_c
reify = reify_c # type: ignore
except ImportError:
pass

Expand Down Expand Up @@ -716,25 +721,25 @@ def _parse_content_type(self, raw):
self._content_type, self._content_dict = cgi.parse_header(raw)

@property
def content_type(self, *, _CONTENT_TYPE=hdrs.CONTENT_TYPE):
def content_type(self):
"""The value of content part for Content-Type HTTP header."""
raw = self._headers.get(_CONTENT_TYPE)
raw = self._headers.get(hdrs.CONTENT_TYPE)
if self._stored_content_type != raw:
self._parse_content_type(raw)
return self._content_type

@property
def charset(self, *, _CONTENT_TYPE=hdrs.CONTENT_TYPE):
def charset(self):
"""The value of charset part for Content-Type HTTP header."""
raw = self._headers.get(_CONTENT_TYPE)
raw = self._headers.get(hdrs.CONTENT_TYPE)
if self._stored_content_type != raw:
self._parse_content_type(raw)
return self._content_dict.get('charset')

@property
def content_length(self, *, _CONTENT_LENGTH=hdrs.CONTENT_LENGTH):
def content_length(self):
"""The value of Content-Length HTTP header."""
content_length = self._headers.get(_CONTENT_LENGTH)
content_length = self._headers.get(hdrs.CONTENT_LENGTH)

if content_length:
return int(content_length)
Expand All @@ -750,7 +755,7 @@ def set_exception(fut, exc):
fut.set_exception(exc)


class ChainMapProxy(Mapping):
class ChainMapProxy(ABCMapping):
__slots__ = ('_maps',)

def __init__(self, maps):
Expand Down
Loading

0 comments on commit 7cf5785

Please sign in to comment.