Skip to content

Commit

Permalink
[Tables] mypy (#19001)
Browse files Browse the repository at this point in the history
* type hints

* mypy fixes

* adding small fix

* fixed more async issues

* adding pylint changes

* more lint fixes

* caught another one

* fixes for py3.7

* removal
  • Loading branch information
seankane-msft authored Jun 2, 2021
1 parent 11dc9ea commit 816b025
Show file tree
Hide file tree
Showing 18 changed files with 171 additions and 158 deletions.
1 change: 1 addition & 0 deletions eng/tox/mypy_hard_failure_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@
"azure-ai-formrecognizer",
"azure-ai-metricsadvisor",
"azure-eventgrid",
"azure-data-tables",
]
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
# --------------------------------------------------------------------------

import logging
import sys
from typing import Union
from typing import TYPE_CHECKING

try:
from urllib.parse import urlparse
Expand All @@ -19,7 +18,7 @@
try:
from azure.core.pipeline.transport import AsyncHttpTransport
except ImportError:
AsyncHttpTransport = None
AsyncHttpTransport = None # type: ignore

try:
from yarl import URL
Expand All @@ -34,8 +33,9 @@
_wrap_exception,
)

if sys.version_info > (3, 5):
from typing import Awaitable # pylint: disable=ungrouped-imports
if TYPE_CHECKING:
from azure.core.pipeline import PipelineRequest # pylint: disable=ungrouped-imports


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -112,7 +112,7 @@ def _add_authorization_header(self, request, string_to_sign):
raise _wrap_exception(ex, AzureSigningError)

def on_request(self, request):
# type: (PipelineRequest) -> Union[None, Awaitable[None]]
# type: (PipelineRequest) -> None
self.sign_request(request)

def sign_request(self, request):
Expand Down
23 changes: 12 additions & 11 deletions sdk/tables/azure-data-tables/azure/data/tables/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# license information.
# --------------------------------------------------------------------------

from typing import Dict, Optional, Any, List, Mapping
from typing import Dict, Optional, Any, List, Mapping, Union
from uuid import uuid4
try:
from urllib.parse import parse_qs, quote, urlparse
Expand Down Expand Up @@ -51,6 +51,7 @@
)
from ._sdk_moniker import SDK_MONIKER


_SUPPORTED_API_VERSIONS = ["2019-02-02", "2019-07-07"]


Expand Down Expand Up @@ -133,9 +134,9 @@ def __init__(
LocationMode.PRIMARY: primary_hostname,
LocationMode.SECONDARY: secondary_hostname,
}
self._credential_policy = None
self._configure_credential(self.credential)
self._policies = self._configure_policies(hosts=self._hosts, **kwargs)
self._credential_policy = None # type: ignore
self._configure_credential(self.credential) # type: ignore
self._policies = self._configure_policies(hosts=self._hosts, **kwargs) # type: ignore
if self._cosmos_endpoint:
self._policies.insert(0, CosmosPatchTransformPolicy())

Expand Down Expand Up @@ -203,7 +204,7 @@ class TablesBaseClient(AccountHostsMixin):
def __init__(
self,
endpoint, # type: str
credential=None, # type: str
credential=None, # type: Union[AzureNamedKeyCredential, AzureSasCredential]
**kwargs # type: Any
):
# type: (...) -> None
Expand Down Expand Up @@ -242,15 +243,15 @@ def _configure_policies(self, **kwargs):
def _configure_credential(self, credential):
# type: (Any) -> None
if hasattr(credential, "get_token"):
self._credential_policy = BearerTokenCredentialPolicy(
self._credential_policy = BearerTokenCredentialPolicy( # type: ignore
credential, STORAGE_OAUTH_SCOPE
)
elif isinstance(credential, SharedKeyCredentialPolicy):
self._credential_policy = credential
self._credential_policy = credential # type: ignore
elif isinstance(credential, AzureSasCredential):
self._credential_policy = AzureSasCredentialPolicy(credential)
self._credential_policy = AzureSasCredentialPolicy(credential) # type: ignore
elif isinstance(credential, AzureNamedKeyCredential):
self._credential_policy = SharedKeyCredentialPolicy(credential)
self._credential_policy = SharedKeyCredentialPolicy(credential) # type: ignore
elif credential is not None:
raise TypeError("Unsupported credential: {}".format(credential))

Expand All @@ -260,9 +261,9 @@ def _batch_send(self, *reqs, **kwargs):
# Pop it here, so requests doesn't feel bad about additional kwarg
policies = [StorageHeadersPolicy()]

changeset = HttpRequest("POST", None)
changeset = HttpRequest("POST", None) # type: ignore
changeset.set_multipart_mixed(
*reqs, policies=policies, boundary="changeset_{}".format(uuid4())
*reqs, policies=policies, boundary="changeset_{}".format(uuid4()) # type: ignore
)
request = self._client._client.post( # pylint: disable=protected-access
url="https://{}/$batch".format(self._primary_hostname),
Expand Down
16 changes: 9 additions & 7 deletions sdk/tables/azure-data-tables/azure/data/tables/_deserialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from typing import Dict, Optional, Any
from typing import Union, Dict, Any, Optional

from uuid import UUID
import logging
import datetime
Expand Down Expand Up @@ -51,7 +52,7 @@ def _from_entity_binary(value):


def _from_entity_int32(value):
# type: (str) -> EntityProperty
# type: (str) -> int
return int(value)


Expand Down Expand Up @@ -96,11 +97,12 @@ def _from_entity_guid(value):


def _from_entity_str(value):
# type: (str) -> EntityProperty
if isinstance(six.binary_type):
# type: (Union[str, bytes]) -> str
if isinstance(value, six.binary_type):
return value.decode('utf-8')
return value


_EDM_TYPES = [
EdmType.BINARY,
EdmType.INT64,
Expand Down Expand Up @@ -180,7 +182,7 @@ def _convert_to_entity(entry_element):

# Add type for String
try:
if isinstance(value, unicode) and mtype is None:
if isinstance(value, unicode) and mtype is None: # type: ignore
mtype = EdmType.STRING
except NameError:
if isinstance(value, str) and mtype is None:
Expand Down Expand Up @@ -256,7 +258,7 @@ def _return_context_and_deserialized(


def _trim_service_metadata(metadata, content=None):
# type: (Dict[str,str], Optional[Dict[str, Any]]) -> None
# type: (Dict[str, str], Optional[Dict[str, Any]]) -> Dict[str, Any]
result = {
"date": metadata.pop("date", None),
"etag": metadata.pop("etag", None),
Expand All @@ -265,5 +267,5 @@ def _trim_service_metadata(metadata, content=None):
preference = metadata.pop('preference_applied', None)
if preference:
result["preference_applied"] = preference
result["content"] = content
result["content"] = content # type: ignore
return result
4 changes: 2 additions & 2 deletions sdk/tables/azure-data-tables/azure/data/tables/_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ class TableEntity(dict):
An Entity dictionary with additional metadata
"""
_metadata = None
_metadata = {} # type: Dict[str, Any]

@property
def metadata(self):
# type: () -> Dict[str, Any]
"""Resets metadata to be a part of the entity
:return Dict of entity metadata
:rtype Dict[str, Any]
:rtype: Dict[str, Any]
"""
return self._metadata

Expand Down
29 changes: 14 additions & 15 deletions sdk/tables/azure-data-tables/azure/data/tables/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# license information.
# --------------------------------------------------------------------------
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Dict, List

from azure.core.exceptions import HttpResponseError
from azure.core.paging import PageIterator
Expand All @@ -28,6 +28,7 @@

if TYPE_CHECKING:
from ._generated.models import TableQueryResponse
from ._generated.models import TableServiceProperties as GenTableServiceProperties


class TableServiceStats(GenTableServiceStats):
Expand Down Expand Up @@ -250,8 +251,8 @@ class CorsRule(GeneratedCorsRule):

def __init__( # pylint: disable=super-init-not-called
self,
allowed_origins, # type: list[str]
allowed_methods, # type: list[str]
allowed_origins, # type: List[str]
allowed_methods, # type: List[str]
**kwargs # type: Any
):
# type: (...)-> None
Expand Down Expand Up @@ -407,7 +408,7 @@ def __add__(self, other):
return TableSasPermissions(_str=str(self) + str(other))

def __str__(self):
# type: () -> TableSasPermissions
# type: () -> str
return (
("r" if self.read else "")
+ ("a" if self.add else "")
Expand Down Expand Up @@ -446,23 +447,19 @@ def from_string(
return parsed


TableSasPermissions.READ = TableSasPermissions(**dict(read=True))
TableSasPermissions.ADD = TableSasPermissions(**dict(add=True))
TableSasPermissions.UPDATE = TableSasPermissions(**dict(update=True))
TableSasPermissions.DELETE = TableSasPermissions(**dict(delete=True))


def service_stats_deserialize(generated):
# type: (GenTableServiceStats) -> Dict[str, Any]
"""Deserialize a ServiceStats objects into a dict."""
return {
"geo_replication": {
"status": generated.geo_replication.status,
"last_sync_time": generated.geo_replication.last_sync_time,
"status": generated.geo_replication.status, # type: ignore
"last_sync_time": generated.geo_replication.last_sync_time, # type: ignore
}
}


def service_properties_deserialize(generated):
# type: (GenTableServiceProperties) -> Dict[str, Any]
"""Deserialize a ServiceProperties objects into a dict."""
return {
"analytics_logging": TableAnalyticsLogging._from_generated(generated.logging), # pylint: disable=protected-access
Expand All @@ -473,7 +470,8 @@ def service_properties_deserialize(generated):
generated.minute_metrics
),
"cors": [
CorsRule._from_generated(cors) for cors in generated.cors # pylint: disable=protected-access
CorsRule._from_generated(cors) # pylint: disable=protected-access
for cors in generated.cors # type: ignore
],
}

Expand All @@ -493,10 +491,11 @@ def __init__(self, name, **kwargs): # pylint: disable=unused-argument
"""
self.name = name

# TODO: TableQueryResponse is not the correct type
@classmethod
def _from_generated(cls, generated, **kwargs):
# type: (TableQueryResponse, Dict[str, Any]) -> TableItem
return cls(generated.table_name, **kwargs)
return cls(generated.table_name, **kwargs) # type: ignore


class TablePayloadFormat(object):
Expand Down Expand Up @@ -643,7 +642,7 @@ def __str__(self):

@classmethod
def from_string(cls, permission, **kwargs):
# type: (str, Dict[str]) -> AccountSasPermissions
# type: (str, Dict[str, Any]) -> AccountSasPermissions
"""Create AccountSasPermissions from a string.
To specify read, write, delete, etc. permissions you need only to
Expand Down
4 changes: 2 additions & 2 deletions sdk/tables/azure-data-tables/azure/data/tables/_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def set_next_host_location(settings, request):
class StorageHeadersPolicy(HeadersPolicy):

def on_request(self, request):
# type: (PipelineRequest, Any) -> None
# type: (PipelineRequest) -> None
super(StorageHeadersPolicy, self).on_request(request)

# Add required date headers
Expand Down Expand Up @@ -236,6 +236,6 @@ class CosmosPatchTransformPolicy(SansIOHTTPPolicy):
"""Policy to transform PATCH requests into POST requests with the "X-HTTP-Method":"MERGE" header set."""

def on_request(self, request):
# type: (PipelineRequest) -> Union[None, Awaitable[None]]
# type: (PipelineRequest) -> None
if request.http_request.method == "PATCH":
_transform_patch_to_cosmos_post(request.http_request)
6 changes: 3 additions & 3 deletions sdk/tables/azure-data-tables/azure/data/tables/_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ def _to_entity_none(value): # pylint: disable=unused-argument
try:
_PYTHON_TO_ENTITY_CONVERSIONS.update(
{
unicode: _to_entity_str,
unicode: _to_entity_str, # type: ignore
str: _to_entity_binary,
long: _to_entity_int32,
long: _to_entity_int32, # type: ignore
}
)
except NameError:
Expand Down Expand Up @@ -198,7 +198,7 @@ def _add_entity_properties(source):

if isinstance(value, Enum):
try:
conv = _PYTHON_TO_ENTITY_CONVERSIONS.get(unicode)
conv = _PYTHON_TO_ENTITY_CONVERSIONS.get(unicode) # type: ignore
except NameError:
conv = _PYTHON_TO_ENTITY_CONVERSIONS.get(str)
mtype, value = conv(value)
Expand Down
Loading

0 comments on commit 816b025

Please sign in to comment.