Skip to content

Commit

Permalink
Improve mypy typing for azure core (#10653)
Browse files Browse the repository at this point in the history
* first commit

* more changes

* few changes

* lint

* comments

* more changes

* fix test

* lint

* mypy

* comments

* changes

* async polling method

* async
  • Loading branch information
rakshith91 authored Apr 28, 2020
1 parent 102339a commit dbcd00f
Show file tree
Hide file tree
Showing 21 changed files with 87 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
TYPE_CHECKING = False

if TYPE_CHECKING:
from typing import Dict, Optional, Union, Callable, Sequence
from typing import Any, Mapping, Dict, Optional, Union, Callable, Sequence

from azure.core.pipeline.transport import HttpRequest, HttpResponse
AttributeValue = Union[
Expand Down
2 changes: 1 addition & 1 deletion sdk/core/azure-core/azure/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore
__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore
6 changes: 4 additions & 2 deletions sdk/core/azure-core/azure/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def raise_with_traceback(exception, *args, **kwargs):
message = kwargs.pop("message", "")
exc_type, exc_value, exc_traceback = sys.exc_info()
# If not called inside a "except", exc_type will be None. Assume it will not happen
exc_msg = "{}, {}: {}".format(message, exc_type.__name__, exc_value) # type: ignore
if exc_type is None:
raise ValueError("raise_with_traceback can only be used in except clauses")
exc_msg = "{}, {}: {}".format(message, exc_type.__name__, exc_value)
error = exception(exc_msg, *args, **kwargs)
try:
raise error.with_traceback(exc_traceback)
Expand Down Expand Up @@ -204,7 +206,7 @@ def __init__(self, message, *args, **kwargs):
self.exc_type = (
self.exc_type.__name__ if self.exc_type else type(self.inner_exception)
)
self.exc_msg = "{}, {}: {}".format(message, self.exc_type, self.exc_value) # type: ignore
self.exc_msg = "{}, {}: {}".format(message, self.exc_type, self.exc_value)
self.message = str(message)
super(AzureError, self).__init__(self.message, *args)

Expand Down
2 changes: 1 addition & 1 deletion sdk/core/azure-core/azure/core/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
try:
from contextlib import ( # pylint: disable=unused-import
AbstractContextManager,
) # type: ignore
)
except ImportError: # Python <= 3.5

class AbstractContextManager(object): # type: ignore
Expand Down
16 changes: 12 additions & 4 deletions sdk/core/azure-core/azure/core/pipeline/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class Pipeline(AbstractContextManager, Generic[HTTPRequestType, HTTPResponseType
def __init__(self, transport, policies=None):
# type: (HttpTransportType, PoliciesType) -> None
self._impl_policies = [] # type: List[HTTPPolicy]
self._transport = transport # type: ignore
self._transport = transport

for policy in policies or []:
if isinstance(policy, SansIOHTTPPolicy):
Expand All @@ -154,7 +154,7 @@ def _prepare_multipart_mixed_request(request):
Does nothing if "set_multipart_mixed" was never called.
"""
multipart_mixed_info = request.multipart_mixed_info # type: ignore
multipart_mixed_info = request.multipart_mixed_info # type: ignore
if not multipart_mixed_info:
return

Expand All @@ -177,6 +177,15 @@ def prepare_requests(req):
_ for _ in executor.map(prepare_requests, requests)
]

def _prepare_multipart(self, request):
# type: (HTTPRequestType) -> None
# This code is fine as long as HTTPRequestType is actually
# azure.core.pipeline.transport.HTTPRequest, bu we don't check it in here
# since we didn't see (yet) pipeline usage where it's not this actual instance
# class used
self._prepare_multipart_mixed_request(request)
request.prepare_multipart_body() # type: ignore

def run(self, request, **kwargs):
# type: (HTTPRequestType, Any) -> PipelineResponse
"""Runs the HTTP Request through the chained policies.
Expand All @@ -186,8 +195,7 @@ def run(self, request, **kwargs):
:return: The PipelineResponse object
:rtype: ~azure.core.pipeline.PipelineResponse
"""
self._prepare_multipart_mixed_request(request)
request.prepare_multipart_body() # type: ignore
self._prepare_multipart(request)
context = PipelineContext(self._transport, **kwargs)
pipeline_request = PipelineRequest(
request, context
Expand Down
20 changes: 13 additions & 7 deletions sdk/core/azure-core/azure/core/pipeline/_base_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
AsyncPoliciesType = List[Union[AsyncHTTPPolicy, SansIOHTTPPolicy]]

try:
from contextlib import AbstractAsyncContextManager # type: ignore
from contextlib import AbstractAsyncContextManager
except ImportError: # Python <= 3.7

class AbstractAsyncContextManager(object): # type: ignore
Expand Down Expand Up @@ -160,13 +160,12 @@ async def __aenter__(self) -> "AsyncPipeline":
async def __aexit__(self, *exc_details): # pylint: disable=arguments-differ
await self._transport.__aexit__(*exc_details)

async def _prepare_multipart_mixed_request(self, request):
# type: (HTTPRequestType) -> None
async def _prepare_multipart_mixed_request(self, request: HTTPRequestType) -> None:
"""Will execute the multipart policies.
Does nothing if "set_multipart_mixed" was never called.
"""
multipart_mixed_info = request.multipart_mixed_info # type: ignore
multipart_mixed_info = request.multipart_mixed_info # type: ignore
if not multipart_mixed_info:
return

Expand All @@ -186,6 +185,14 @@ async def prepare_requests(req):

await asyncio.gather(*[prepare_requests(req) for req in requests])

async def _prepare_multipart(self, request: HTTPRequestType) -> None:
# This code is fine as long as HTTPRequestType is actually
# azure.core.pipeline.transport.HTTPRequest, bu we don't check it in here
# since we didn't see (yet) pipeline usage where it's not this actual instance
# class used
await self._prepare_multipart_mixed_request(request)
request.prepare_multipart_body() # type: ignore

async def run(self, request: HTTPRequestType, **kwargs: Any):
"""Runs the HTTP Request through the chained policies.
Expand All @@ -194,13 +201,12 @@ async def run(self, request: HTTPRequestType, **kwargs: Any):
:return: The PipelineResponse object.
:rtype: ~azure.core.pipeline.PipelineResponse
"""
await self._prepare_multipart_mixed_request(request)
request.prepare_multipart_body() # type: ignore
await self._prepare_multipart(request)
context = PipelineContext(self._transport, **kwargs)
pipeline_request = PipelineRequest(request, context)
first_node = (
self._impl_policies[0]
if self._impl_policies
else _AsyncTransportRunner(self._transport)
)
return await first_node.send(pipeline_request) # type: ignore
return await first_node.send(pipeline_request)
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def on_request(self, request):

if self._need_new_token:
self._token = self._credential.get_token(*self._scopes)
self._update_headers(request.http_request.headers, self._token.token) # type: ignore
self._update_headers(request.http_request.headers, self._token.token)


class AzureKeyCredentialPolicy(SansIOHTTPPolicy):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@ async def on_request(self, request: PipelineRequest):
with self._lock:
if self._need_new_token:
self._token = await self._credential.get_token(*self._scopes) # type: ignore
self._update_headers(request.http_request.headers, self._token.token) # type: ignore
self._update_headers(request.http_request.headers, self._token.token)
4 changes: 2 additions & 2 deletions sdk/core/azure-core/azure/core/pipeline/policies/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

_LOGGER = logging.getLogger(__name__)

class HTTPPolicy(ABC, Generic[HTTPRequestType, HTTPResponseType]): # type: ignore
class HTTPPolicy(ABC, Generic[HTTPRequestType, HTTPResponseType]):
"""An HTTP policy ABC.
Use with a synchronous pipeline.
Expand All @@ -61,7 +61,7 @@ class HTTPPolicy(ABC, Generic[HTTPRequestType, HTTPResponseType]): # type: igno
"""

def __init__(self):
self.next = None
self.next = None # type: Union[HTTPPolicy, HttpTransport]

@abc.abstractmethod
def send(self, request):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,27 @@ def __init__(self, **kwargs): # pylint: disable=unused-argument,super-init-not-c
self._request_callback = kwargs.get('raw_request_hook')
self._response_callback = kwargs.get('raw_response_hook')

def on_request(self, request): # type: ignore # pylint: disable=arguments-differ
def on_request(self, request): # pylint: disable=arguments-differ
# type: (PipelineRequest) -> None
"""This is executed before sending the request to the next policy.
:param request: The PipelineRequest object.
:type request: ~azure.core.pipeline.PipelineRequest
"""
request_callback = request.context.options.pop('raw_request_hook', None) # type: ignore
request_callback = request.context.options.pop('raw_request_hook', None)
if request_callback:
request.context["raw_request_hook"] = request_callback
request_callback(request)
elif self._request_callback:
self._request_callback(request)

response_callback = request.context.options.pop('raw_response_hook', None) # type: ignore
response_callback = request.context.options.pop('raw_response_hook', None)
if response_callback:
request.context["raw_response_hook"] = response_callback



def on_response(self, request, response): # type: ignore # pylint: disable=arguments-differ
def on_response(self, request, response): # pylint: disable=arguments-differ
# type: (PipelineRequest, PipelineResponse) -> None
"""This is executed after the request comes back from the policy.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ._redirect import RedirectPolicy


class AsyncRedirectPolicy(RedirectPolicy, AsyncHTTPPolicy): # type: ignore
class AsyncRedirectPolicy(RedirectPolicy, AsyncHTTPPolicy):
"""An async redirect policy.
An async redirect policy in the pipeline can be configured directly or per operation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@



class AsyncRetryPolicy(RetryPolicy, AsyncHTTPPolicy): # type: ignore
class AsyncRetryPolicy(RetryPolicy, AsyncHTTPPolicy):
"""Async flavor of the retry policy.
The async retry policy in the pipeline can be configured directly, or tweaked on a per-call basis.
Expand Down
10 changes: 5 additions & 5 deletions sdk/core/azure-core/azure/core/pipeline/policies/_universal.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ def on_request(self, request):
elif self._request_id is None:
return
elif self._request_id is not _Unset:
request_id = self._request_id # type: ignore
request_id = self._request_id
elif self._auto_request_id:
request_id = str(uuid.uuid1()) # type: ignore
request_id = str(uuid.uuid1())
if request_id is not unset:
header = {"x-ms-client-request-id": request_id}
request.http_request.headers.update(header)
Expand Down Expand Up @@ -286,10 +286,10 @@ def on_request(self, request):
_LOGGER.debug("Request body:")

# We don't want to log the binary data of a file upload.
if isinstance(http_request.body, types.GeneratorType): # type: ignore
if isinstance(http_request.body, types.GeneratorType):
_LOGGER.debug("File upload")
else:
_LOGGER.debug(str(http_request.body)) # type: ignore
_LOGGER.debug(str(http_request.body))
except Exception as err: # pylint: disable=broad-except
_LOGGER.debug("Failed to log request: %r", err)

Expand Down Expand Up @@ -498,7 +498,7 @@ def deserialize_from_text(
try:
if isinstance(data, unicode): # type: ignore
# If I'm Python 2.7 and unicode XML will scream if I try a "fromstring" on unicode string
data_as_str = data_as_str.encode(encoding="utf-8") # type: ignore
data_as_str = cast(str, data_as_str.encode(encoding="utf-8"))
except NameError:
pass
return ET.fromstring(data_as_str)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import abc
from collections.abc import AsyncIterator

from typing import AsyncIterator as AsyncIteratorType, Generic, TypeVar
from typing import AsyncIterator as AsyncIteratorType, TypeVar, Generic
from ._base import (
_HttpResponseBase,
_HttpClientTransportResponse,
Expand All @@ -52,12 +52,10 @@ async def __aexit__(self, exc_type, exc_value, traceback):
"""Raise any exception triggered within the runtime context."""
return None


AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType")
HTTPResponseType = TypeVar("HTTPResponseType")
HTTPRequestType = TypeVar("HTTPRequestType")


class _ResponseStopIteration(Exception):
pass

Expand Down Expand Up @@ -163,7 +161,7 @@ class AsyncHttpClientTransportResponse(_HttpClientTransportResponse, AsyncHttpRe
class AsyncHttpTransport(
AbstractAsyncContextManager,
abc.ABC,
Generic[HTTPRequestType, AsyncHTTPResponseType],
Generic[HTTPRequestType, AsyncHTTPResponseType]
):
"""An http sender ABC.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ async def _retrieve_request_data(self, request):
# That's not ideal, but a list is our only choice. Memory not optimal here,
# but providing an async generator to a requests based transport is not optimal too
new_data = []
async for part in request.data: # type: ignore
async for part in request.data:
new_data.append(part)
data_to_send = iter(new_data)
else:
data_to_send = request.data # type: ignore
data_to_send = request.data
return data_to_send

async def __aenter__(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _get_running_loop():


#pylint: disable=too-many-ancestors
class AsyncioRequestsTransport(RequestsAsyncTransportBase): # type: ignore
class AsyncioRequestsTransport(RequestsAsyncTransportBase):
"""Identical implementation as the synchronous RequestsTransport wrapped in a class with
asynchronous methods. Uses the built-in asyncio event loop.
Expand Down Expand Up @@ -103,7 +103,7 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AsyncHttpResponse:
response = await loop.run_in_executor(
None,
functools.partial(
self.session.request, # type: ignore
self.session.request,
request.method,
request.url,
headers=request.headers,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
import functools
import logging
from typing import Any, Callable, Union, Optional, AsyncIterator as AsyncIteratorType
import trio # type: ignore
import urllib3 # type: ignore
import trio
import urllib3

import requests

Expand Down Expand Up @@ -128,7 +128,7 @@ class TrioRequestsTransportResponse(AsyncHttpResponse, RequestsTransportResponse
def stream_download(self, pipeline) -> AsyncIteratorType[bytes]: # type: ignore
"""Generator for streaming response data.
"""
return TrioStreamDownloadGenerator(pipeline, self) # type: ignore
return TrioStreamDownloadGenerator(pipeline, self)


class TrioRequestsTransport(RequestsAsyncTransportBase): # type: ignore
Expand Down Expand Up @@ -174,7 +174,7 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AsyncHttpResponse:
try:
response = await trio.to_thread.run_sync(
functools.partial(
self.session.request, # type: ignore
self.session.request,
request.method,
request.url,
headers=request.headers,
Expand All @@ -189,7 +189,7 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AsyncHttpResponse:
except AttributeError: # trio < 0.12.1
response = await trio.run_sync_in_worker_thread( # pylint: disable=no-member
functools.partial(
self.session.request, # type: ignore
self.session.request,
request.method,
request.url,
headers=request.headers,
Expand Down
17 changes: 11 additions & 6 deletions sdk/core/azure-core/azure/core/polling/_async_poller.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,29 @@
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
from typing import Generic, TypeVar, Any
from ._poller import NoPolling as _NoPolling

class AsyncPollingMethod(object):

PollingReturnType = TypeVar("PollingReturnType")


class AsyncPollingMethod(Generic[PollingReturnType]):
"""ABC class for polling method.
"""
def initialize(self, client, initial_response, deserialization_callback):
def initialize(self, client: Any, initial_response: Any, deserialization_callback: Any) -> None:
raise NotImplementedError("This method needs to be implemented")

async def run(self):
async def run(self) -> None:
raise NotImplementedError("This method needs to be implemented")

def status(self):
def status(self) -> str:
raise NotImplementedError("This method needs to be implemented")

def finished(self):
def finished(self) -> bool:
raise NotImplementedError("This method needs to be implemented")

def resource(self):
def resource(self) -> PollingReturnType:
raise NotImplementedError("This method needs to be implemented")


Expand Down
Loading

0 comments on commit dbcd00f

Please sign in to comment.