From fea5244a41a8923fdb751432b570b2bb4802f8e3 Mon Sep 17 00:00:00 2001 From: Albert Tugushev Date: Thu, 15 Oct 2020 22:40:48 +0700 Subject: [PATCH] Fix @reify type hints (#4736) * Fix @reify type hints * Fix issues with type hints * Add a changelog Co-authored-by: Andrew Svetlov --- CHANGES/4736.bugfix | 2 ++ aiohttp/client.py | 25 +++++++++++++------------ aiohttp/client_exceptions.py | 5 ++--- aiohttp/helpers.py | 19 +++++++++++++++---- aiohttp/http_websocket.py | 3 ++- aiohttp/web_request.py | 6 ++++-- aiohttp/web_urldispatcher.py | 6 ++++-- 7 files changed, 42 insertions(+), 24 deletions(-) create mode 100644 CHANGES/4736.bugfix diff --git a/CHANGES/4736.bugfix b/CHANGES/4736.bugfix new file mode 100644 index 00000000000..8c562571d6b --- /dev/null +++ b/CHANGES/4736.bugfix @@ -0,0 +1,2 @@ +Improve typing annotations for ``web.Request``, ``aiohttp.ClientResponse`` and +``multipart`` module. diff --git a/aiohttp/client.py b/aiohttp/client.py index 2ed03aec021..0e4bd86bb39 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -559,25 +559,25 @@ async def _request( resp.release() try: - r_url = URL( + parsed_url = URL( r_url, encoded=not self._requote_redirect_url) except ValueError: raise InvalidURL(r_url) - scheme = r_url.scheme + scheme = parsed_url.scheme if scheme not in ('http', 'https', ''): resp.close() raise ValueError( 'Can redirect only to http or https') elif not scheme: - r_url = url.join(r_url) + parsed_url = url.join(parsed_url) - if url.origin() != r_url.origin(): + if url.origin() != parsed_url.origin(): auth = None headers.pop(hdrs.AUTHORIZATION, None) - url = r_url + url = parsed_url params = None resp.release() continue @@ -757,10 +757,10 @@ async def _ws_connect( headers=resp.headers) # key calculation - key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '') + r_key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '') match = base64.b64encode( hashlib.sha1(sec_key + WS_KEY).digest()).decode() - if key != match: + if r_key != match: raise WSServerHandshakeError( resp.request_info, resp.history, @@ -800,15 +800,16 @@ async def _ws_connect( conn = resp.connection assert conn is not None - proto = conn.protocol - assert proto is not None + conn_proto = conn.protocol + assert conn_proto is not None transport = conn.transport assert transport is not None reader = FlowControlDataQueue( - proto, limit=2 ** 16, loop=self._loop) # type: FlowControlDataQueue[WSMessage] # noqa - proto.set_parser(WebSocketReader(reader, max_msg_size), reader) + conn_proto, limit=2 ** 16, loop=self._loop) # type: FlowControlDataQueue[WSMessage] # noqa + conn_proto.set_parser( + WebSocketReader(reader, max_msg_size), reader) writer = WebSocketWriter( - proto, transport, use_mask=True, + conn_proto, transport, use_mask=True, compress=compress, notakeover=notakeover) except BaseException: resp.close() diff --git a/aiohttp/client_exceptions.py b/aiohttp/client_exceptions.py index 55e9501cdc4..eb53eb8443d 100644 --- a/aiohttp/client_exceptions.py +++ b/aiohttp/client_exceptions.py @@ -4,7 +4,7 @@ import warnings from typing import TYPE_CHECKING, Any, Optional, Tuple, Union -from .typedefs import _CIMultiDict +from .typedefs import LooseHeaders try: import ssl @@ -23,7 +23,6 @@ else: RequestInfo = ClientResponse = ConnectionKey = None - __all__ = ( 'ClientError', @@ -57,7 +56,7 @@ def __init__(self, request_info: RequestInfo, code: Optional[int]=None, status: Optional[int]=None, message: str='', - headers: Optional[_CIMultiDict]=None) -> None: + headers: Optional[LooseHeaders]=None) -> None: self.request_info = request_info if code is not None: if status is not None: diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index 87727d81f06..d13240f0805 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -25,6 +25,7 @@ Callable, Dict, Generator, + Generic, Iterable, Iterator, List, @@ -65,6 +66,11 @@ except ImportError: from typing_extensions import ContextManager +if PY_38: + from typing import Protocol +else: + from typing_extensions import Protocol # type: ignore + def all_tasks( loop: Optional[asyncio.AbstractEventLoop] = None @@ -78,6 +84,7 @@ def all_tasks( _T = TypeVar('_T') +_S = TypeVar('_S') sentinel = object() # type: Any @@ -360,7 +367,11 @@ def content_disposition_header(disptype: str, return value -class reify: +class _TSelf(Protocol): + _cache: Dict[str, Any] + + +class reify(Generic[_T]): """Use as a class method decorator. It operates almost exactly like the Python `@property` decorator, but it puts the result of the method it decorates into the instance dict after the first call, @@ -369,12 +380,12 @@ class reify: """ - def __init__(self, wrapped: Callable[..., Any]) -> None: + def __init__(self, wrapped: Callable[..., _T]) -> None: self.wrapped = wrapped self.__doc__ = wrapped.__doc__ self.name = wrapped.__name__ - def __get__(self, inst: Any, owner: Any) -> Any: + def __get__(self, inst: _TSelf, owner: Optional[Type[Any]] = None) -> _T: try: try: return inst._cache[self.name] @@ -387,7 +398,7 @@ def __get__(self, inst: Any, owner: Any) -> Any: return self raise - def __set__(self, inst: Any, value: Any) -> None: + def __set__(self, inst: _TSelf, value: _T) -> None: raise AttributeError("reified property is read-only") diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index b8bf826d223..8877fb6aa44 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -162,7 +162,8 @@ def _websocket_mask_python(mask: bytes, data: bytearray) -> None: _WS_EXT_RE_SPLIT = re.compile(r'permessage-deflate([^,]+)?') -def ws_ext_parse(extstr: str, isserver: bool=False) -> Tuple[int, bool]: +def ws_ext_parse(extstr: Optional[str], + isserver: bool=False) -> Tuple[int, bool]: if not extstr: return 0, False diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index a931bb84510..2dad0f2faa6 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -32,6 +32,7 @@ from .abc import AbstractStreamWriter from .helpers import DEBUG, ChainMapProxy, HeadersMixin, reify, sentinel from .http_parser import RawRequestMessage +from .http_writer import HttpVersion from .multipart import BodyPartReader, MultipartReader from .streams import EmptyStreamReader, StreamReader from .typedefs import ( @@ -342,7 +343,7 @@ def method(self) -> str: return self._method @reify - def version(self) -> Tuple[int, int]: + def version(self) -> HttpVersion: """Read only property for getting HTTP version of request. Returns aiohttp.protocol.HttpVersion instance. @@ -433,7 +434,7 @@ def raw_headers(self) -> RawHeaders: return self._message.raw_headers @staticmethod - def _http_date(_date_str: str) -> Optional[datetime.datetime]: + def _http_date(_date_str: Optional[str]) -> Optional[datetime.datetime]: """Process a date string, return a datetime object """ if _date_str is not None: @@ -614,6 +615,7 @@ async def post(self) -> 'MultiDictProxy[Union[str, bytes, FileField]]': field_ct = field.headers.get(hdrs.CONTENT_TYPE) if isinstance(field, BodyPartReader): + assert field.name is not None if field.filename and field_ct: # store file in temp file tmp = tempfile.TemporaryFile() diff --git a/aiohttp/web_urldispatcher.py b/aiohttp/web_urldispatcher.py index 70ee92751ae..499788cbde2 100644 --- a/aiohttp/web_urldispatcher.py +++ b/aiohttp/web_urldispatcher.py @@ -282,7 +282,7 @@ async def _default_expect_handler(request: Request) -> None: Just send "100 Continue" to client. raise HTTPExpectationFailed if value of header is not "100-continue" """ - expect = request.headers.get(hdrs.EXPECT) + expect = request.headers.get(hdrs.EXPECT, "") if request.version == HttpVersion11: if expect.lower() == "100-continue": await request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n") @@ -767,7 +767,9 @@ def validation(self, domain: str) -> str: async def match(self, request: Request) -> bool: host = request.headers.get(hdrs.HOST) - return host and self.match_domain(host) + if not host: + return False + return self.match_domain(host) def match_domain(self, host: str) -> bool: return host.lower() == self._domain