Skip to content

Commit

Permalink
Fix @reify type hints (#4736)
Browse files Browse the repository at this point in the history
* Fix @reify type hints

* Fix issues with type hints

* Add a changelog

Co-authored-by: Andrew Svetlov <[email protected]>
  • Loading branch information
atugushev and asvetlov authored Oct 15, 2020
1 parent 29e9c85 commit 634b361
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 24 deletions.
2 changes: 2 additions & 0 deletions CHANGES/4736.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Improve typing annotations for ``web.Request``, ``aiohttp.ClientResponse`` and
``multipart`` module.
25 changes: 13 additions & 12 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,25 +522,25 @@ async def _request(
resp.release()

try:
r_url = URL(
parsed_url = URL(
r_url, encoded=not self._requote_redirect_url)

except ValueError:
raise InvalidURL(r_url)

scheme = r_url.scheme
scheme = parsed_url.scheme
if scheme not in ('http', 'https', ''):
resp.close()
raise ValueError(
'Can redirect only to http or https')
elif not scheme:
r_url = url.join(r_url)
parsed_url = url.join(parsed_url)

if url.origin() != r_url.origin():
if url.origin() != parsed_url.origin():
auth = None
headers.pop(hdrs.AUTHORIZATION, None)

url = r_url
url = parsed_url
params = None
resp.release()
continue
Expand Down Expand Up @@ -737,10 +737,10 @@ async def _ws_connect(
headers=resp.headers)

# key calculation
key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '')
r_key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '')
match = base64.b64encode(
hashlib.sha1(sec_key + WS_KEY).digest()).decode()
if key != match:
if r_key != match:
raise WSServerHandshakeError(
resp.request_info,
resp.history,
Expand Down Expand Up @@ -780,15 +780,16 @@ async def _ws_connect(

conn = resp.connection
assert conn is not None
proto = conn.protocol
assert proto is not None
conn_proto = conn.protocol
assert conn_proto is not None
transport = conn.transport
assert transport is not None
reader = FlowControlDataQueue(
proto, limit=2 ** 16, loop=self._loop) # type: FlowControlDataQueue[WSMessage] # noqa
proto.set_parser(WebSocketReader(reader, max_msg_size), reader)
conn_proto, limit=2 ** 16, loop=self._loop) # type: FlowControlDataQueue[WSMessage] # noqa
conn_proto.set_parser(
WebSocketReader(reader, max_msg_size), reader)
writer = WebSocketWriter(
proto, transport, use_mask=True,
conn_proto, transport, use_mask=True,
compress=compress, notakeover=notakeover)
except BaseException:
resp.close()
Expand Down
5 changes: 2 additions & 3 deletions aiohttp/client_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union

from .typedefs import _CIMultiDict
from .typedefs import LooseHeaders

try:
import ssl
Expand All @@ -22,7 +22,6 @@
else:
RequestInfo = ClientResponse = ConnectionKey = None


__all__ = (
'ClientError',

Expand Down Expand Up @@ -55,7 +54,7 @@ def __init__(self, request_info: RequestInfo,
history: Tuple[ClientResponse, ...], *,
status: Optional[int]=None,
message: str='',
headers: Optional[_CIMultiDict]=None) -> None:
headers: Optional[LooseHeaders]=None) -> None:
self.request_info = request_info
if status is not None:
self.status = status
Expand Down
19 changes: 15 additions & 4 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Callable,
Dict,
Generator,
Generic,
Iterable,
Iterator,
List,
Expand Down Expand Up @@ -66,6 +67,11 @@
except ImportError:
from typing_extensions import ContextManager

if PY_38:
from typing import Protocol
else:
from typing_extensions import Protocol # type: ignore


def all_tasks(
loop: Optional[asyncio.AbstractEventLoop] = None
Expand All @@ -79,6 +85,7 @@ def all_tasks(


_T = TypeVar('_T')
_S = TypeVar('_S')


sentinel = object() # type: Any
Expand Down Expand Up @@ -382,7 +389,11 @@ def is_expected_content_type(response_content_type: str,
return expected_content_type in response_content_type


class reify:
class _TSelf(Protocol):
_cache: Dict[str, Any]


class reify(Generic[_T]):
"""Use as a class method decorator. It operates almost exactly like
the Python `@property` decorator, but it puts the result of the
method it decorates into the instance dict after the first call,
Expand All @@ -391,12 +402,12 @@ class reify:
"""

def __init__(self, wrapped: Callable[..., Any]) -> None:
def __init__(self, wrapped: Callable[..., _T]) -> None:
self.wrapped = wrapped
self.__doc__ = wrapped.__doc__
self.name = wrapped.__name__

def __get__(self, inst: Any, owner: Any) -> Any:
def __get__(self, inst: _TSelf, owner: Optional[Type[Any]] = None) -> _T:
try:
try:
return inst._cache[self.name]
Expand All @@ -409,7 +420,7 @@ def __get__(self, inst: Any, owner: Any) -> Any:
return self
raise

def __set__(self, inst: Any, value: Any) -> None:
def __set__(self, inst: _TSelf, value: _T) -> None:
raise AttributeError("reified property is read-only")


Expand Down
3 changes: 2 additions & 1 deletion aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
_WS_EXT_RE_SPLIT = re.compile(r'permessage-deflate([^,]+)?')


def ws_ext_parse(extstr: str, isserver: bool=False) -> Tuple[int, bool]:
def ws_ext_parse(extstr: Optional[str],
isserver: bool=False) -> Tuple[int, bool]:
if not extstr:
return 0, False

Expand Down
6 changes: 4 additions & 2 deletions aiohttp/web_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
set_result,
)
from .http_parser import RawRequestMessage
from .http_writer import HttpVersion
from .multipart import BodyPartReader, MultipartReader
from .streams import EmptyStreamReader, StreamReader
from .typedefs import (
Expand Down Expand Up @@ -343,7 +344,7 @@ def method(self) -> str:
return self._method

@reify
def version(self) -> Tuple[int, int]:
def version(self) -> HttpVersion:
"""Read only property for getting HTTP version of request.
Returns aiohttp.protocol.HttpVersion instance.
Expand Down Expand Up @@ -434,7 +435,7 @@ def raw_headers(self) -> RawHeaders:
return self._message.raw_headers

@staticmethod
def _http_date(_date_str: str) -> Optional[datetime.datetime]:
def _http_date(_date_str: Optional[str]) -> Optional[datetime.datetime]:
"""Process a date string, return a datetime object
"""
if _date_str is not None:
Expand Down Expand Up @@ -618,6 +619,7 @@ async def post(self) -> 'MultiDictProxy[Union[str, bytes, FileField]]':
field_ct = field.headers.get(hdrs.CONTENT_TYPE)

if isinstance(field, BodyPartReader):
assert field.name is not None
if field.filename and field_ct:
# store file in temp file
tmp = tempfile.TemporaryFile()
Expand Down
6 changes: 4 additions & 2 deletions aiohttp/web_urldispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ async def _default_expect_handler(request: Request) -> None:
Just send "100 Continue" to client.
raise HTTPExpectationFailed if value of header is not "100-continue"
"""
expect = request.headers.get(hdrs.EXPECT)
expect = request.headers.get(hdrs.EXPECT, "")
if request.version == HttpVersion11:
if expect.lower() == "100-continue":
await request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n")
Expand Down Expand Up @@ -749,7 +749,9 @@ def validation(self, domain: str) -> str:

async def match(self, request: Request) -> bool:
host = request.headers.get(hdrs.HOST)
return host and self.match_domain(host)
if not host:
return False
return self.match_domain(host)

def match_domain(self, host: str) -> bool:
return host.lower() == self._domain
Expand Down

0 comments on commit 634b361

Please sign in to comment.