Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix @reify type hints #4736

Merged
merged 4 commits into from
Oct 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES/4736.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Improve typing annotations for ``web.Request``, ``aiohttp.ClientResponse`` and
``multipart`` module.
25 changes: 13 additions & 12 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,25 +522,25 @@ async def _request(
resp.release()

try:
r_url = URL(
parsed_url = URL(
r_url, encoded=not self._requote_redirect_url)

except ValueError:
raise InvalidURL(r_url)

scheme = r_url.scheme
scheme = parsed_url.scheme
if scheme not in ('http', 'https', ''):
resp.close()
raise ValueError(
'Can redirect only to http or https')
elif not scheme:
r_url = url.join(r_url)
parsed_url = url.join(parsed_url)

if url.origin() != r_url.origin():
if url.origin() != parsed_url.origin():
auth = None
headers.pop(hdrs.AUTHORIZATION, None)

url = r_url
url = parsed_url
params = None
resp.release()
continue
Expand Down Expand Up @@ -737,10 +737,10 @@ async def _ws_connect(
headers=resp.headers)

# key calculation
key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '')
r_key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '')
match = base64.b64encode(
hashlib.sha1(sec_key + WS_KEY).digest()).decode()
if key != match:
if r_key != match:
raise WSServerHandshakeError(
resp.request_info,
resp.history,
Expand Down Expand Up @@ -780,15 +780,16 @@ async def _ws_connect(

conn = resp.connection
assert conn is not None
proto = conn.protocol
assert proto is not None
conn_proto = conn.protocol
assert conn_proto is not None
transport = conn.transport
assert transport is not None
reader = FlowControlDataQueue(
proto, limit=2 ** 16, loop=self._loop) # type: FlowControlDataQueue[WSMessage] # noqa
proto.set_parser(WebSocketReader(reader, max_msg_size), reader)
conn_proto, limit=2 ** 16, loop=self._loop) # type: FlowControlDataQueue[WSMessage] # noqa
conn_proto.set_parser(
WebSocketReader(reader, max_msg_size), reader)
writer = WebSocketWriter(
proto, transport, use_mask=True,
conn_proto, transport, use_mask=True,
compress=compress, notakeover=notakeover)
except BaseException:
resp.close()
Expand Down
5 changes: 2 additions & 3 deletions aiohttp/client_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union

from .typedefs import _CIMultiDict
from .typedefs import LooseHeaders

try:
import ssl
Expand All @@ -22,7 +22,6 @@
else:
RequestInfo = ClientResponse = ConnectionKey = None


__all__ = (
'ClientError',

Expand Down Expand Up @@ -55,7 +54,7 @@ def __init__(self, request_info: RequestInfo,
history: Tuple[ClientResponse, ...], *,
status: Optional[int]=None,
message: str='',
headers: Optional[_CIMultiDict]=None) -> None:
headers: Optional[LooseHeaders]=None) -> None:
self.request_info = request_info
if status is not None:
self.status = status
Expand Down
19 changes: 15 additions & 4 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Callable,
Dict,
Generator,
Generic,
Iterable,
Iterator,
List,
Expand Down Expand Up @@ -66,6 +67,11 @@
except ImportError:
from typing_extensions import ContextManager

if PY_38:
from typing import Protocol
else:
from typing_extensions import Protocol # type: ignore


def all_tasks(
loop: Optional[asyncio.AbstractEventLoop] = None
Expand All @@ -79,6 +85,7 @@ def all_tasks(


_T = TypeVar('_T')
_S = TypeVar('_S')


sentinel = object() # type: Any
Expand Down Expand Up @@ -382,7 +389,11 @@ def is_expected_content_type(response_content_type: str,
return expected_content_type in response_content_type


class reify:
class _TSelf(Protocol):
_cache: Dict[str, Any]


class reify(Generic[_T]):
"""Use as a class method decorator. It operates almost exactly like
the Python `@property` decorator, but it puts the result of the
method it decorates into the instance dict after the first call,
Expand All @@ -391,12 +402,12 @@ class reify:

"""

def __init__(self, wrapped: Callable[..., Any]) -> None:
def __init__(self, wrapped: Callable[..., _T]) -> None:
self.wrapped = wrapped
self.__doc__ = wrapped.__doc__
self.name = wrapped.__name__

def __get__(self, inst: Any, owner: Any) -> Any:
def __get__(self, inst: _TSelf, owner: Optional[Type[Any]] = None) -> _T:
try:
try:
return inst._cache[self.name]
Expand All @@ -409,7 +420,7 @@ def __get__(self, inst: Any, owner: Any) -> Any:
return self
raise

def __set__(self, inst: Any, value: Any) -> None:
def __set__(self, inst: _TSelf, value: _T) -> None:
raise AttributeError("reified property is read-only")


Expand Down
3 changes: 2 additions & 1 deletion aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
_WS_EXT_RE_SPLIT = re.compile(r'permessage-deflate([^,]+)?')


def ws_ext_parse(extstr: str, isserver: bool=False) -> Tuple[int, bool]:
def ws_ext_parse(extstr: Optional[str],
isserver: bool=False) -> Tuple[int, bool]:
if not extstr:
return 0, False

Expand Down
6 changes: 4 additions & 2 deletions aiohttp/web_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
set_result,
)
from .http_parser import RawRequestMessage
from .http_writer import HttpVersion
from .multipart import BodyPartReader, MultipartReader
from .streams import EmptyStreamReader, StreamReader
from .typedefs import (
Expand Down Expand Up @@ -343,7 +344,7 @@ def method(self) -> str:
return self._method

@reify
def version(self) -> Tuple[int, int]:
def version(self) -> HttpVersion:
"""Read only property for getting HTTP version of request.

Returns aiohttp.protocol.HttpVersion instance.
Expand Down Expand Up @@ -434,7 +435,7 @@ def raw_headers(self) -> RawHeaders:
return self._message.raw_headers

@staticmethod
def _http_date(_date_str: str) -> Optional[datetime.datetime]:
def _http_date(_date_str: Optional[str]) -> Optional[datetime.datetime]:
"""Process a date string, return a datetime object
"""
if _date_str is not None:
Expand Down Expand Up @@ -618,6 +619,7 @@ async def post(self) -> 'MultiDictProxy[Union[str, bytes, FileField]]':
field_ct = field.headers.get(hdrs.CONTENT_TYPE)

if isinstance(field, BodyPartReader):
assert field.name is not None
if field.filename and field_ct:
# store file in temp file
tmp = tempfile.TemporaryFile()
Expand Down
6 changes: 4 additions & 2 deletions aiohttp/web_urldispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ async def _default_expect_handler(request: Request) -> None:
Just send "100 Continue" to client.
raise HTTPExpectationFailed if value of header is not "100-continue"
"""
expect = request.headers.get(hdrs.EXPECT)
expect = request.headers.get(hdrs.EXPECT, "")
if request.version == HttpVersion11:
if expect.lower() == "100-continue":
await request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n")
Expand Down Expand Up @@ -749,7 +749,9 @@ def validation(self, domain: str) -> str:

async def match(self, request: Request) -> bool:
host = request.headers.get(hdrs.HOST)
return host and self.match_domain(host)
if not host:
return False
return self.match_domain(host)

def match_domain(self, host: str) -> bool:
return host.lower() == self._domain
Expand Down