Skip to content

Commit

Permalink
[Monitor][Query] Improve typing (Azure#28175)
Browse files Browse the repository at this point in the history
[Monitor][Query] Improve typing

This enables the mypy, pyright, verifytypes checks in the CI, and also
adds some typing improvements in order to pass the checks..

* Use class attribute style typing
* Class ordering in models file was changed a bit to allow
  for class attribute typing.

Signed-off-by: Paul Van Eck <[email protected]>
  • Loading branch information
pvaneck authored Jan 11, 2023
1 parent c7ec3d1 commit 83fca5e
Show file tree
Hide file tree
Showing 11 changed files with 446 additions and 419 deletions.
2 changes: 1 addition & 1 deletion sdk/monitor/azure-monitor-query/azure/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore
__path__ = __import__('pkgutil').extend_path(__path__, __name__)
2 changes: 1 addition & 1 deletion sdk/monitor/azure-monitor-query/azure/monitor/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore
__path__ = __import__('pkgutil').extend_path(__path__, __name__)
39 changes: 23 additions & 16 deletions sdk/monitor/azure-monitor-query/azure/monitor/query/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,40 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from typing import Any
import sys
from typing import Any, List, Optional

from ._models import LogsQueryStatus

if sys.version_info >= (3, 9):
from collections.abc import MutableMapping
else:
from typing import MutableMapping # pylint: disable=ungrouped-imports


class LogsQueryError(object):
"""The code and message for an error.
JSON = MutableMapping[str, Any] # pylint: disable=unsubscriptable-object

:ivar code: A machine readable error code.
:vartype code: str
:ivar message: A human readable error message.
:vartype message: str
:ivar details: A list of additional details about the error.
:vartype details: list[JSON]
:ivar status: status for error item when iterating over list of
results. Always "Failure" for an instance of a LogsQueryError.
:vartype status: ~azure.monitor.query.LogsQueryStatus
"""

class LogsQueryError:
"""The code and message for an error."""

code: str
"""A machine readable error code."""
message: str
"""A human readable error message."""
details: Optional[List[JSON]] = None
"""A list of additional details about the error."""
status: LogsQueryStatus
"""Status for error item when iterating over list of results. Always "Failure" for an instance of a
LogsQueryError."""

def __init__(self, **kwargs: Any) -> None:
self.code = kwargs.get("code", None)
self.message = kwargs.get("message", None)
self.code = kwargs.get("code", "")
self.message = kwargs.get("message", "")
self.details = kwargs.get("details", None)
self.status = LogsQueryStatus.FAILURE

def __str__(self):
def __str__(self) -> str:
return str(self.__dict__)

@classmethod
Expand Down
49 changes: 28 additions & 21 deletions sdk/monitor/azure-monitor-query/azure/monitor/query/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# license information.
# --------------------------------------------------------------------------
from datetime import datetime, timedelta
from typing import List, Dict, Any
from typing import List, Dict, Any, Optional

from azure.core.credentials import TokenCredential
from azure.core.exceptions import HttpResponseError
Expand All @@ -16,7 +16,7 @@

def get_authentication_policy(
credential: TokenCredential,
audience: str = None
audience: Optional[str] = None
) -> BearerTokenCredentialPolicy:
"""Returns the correct authentication policy"""
if not audience:
Expand All @@ -34,7 +34,7 @@ def get_authentication_policy(

def get_metrics_authentication_policy(
credential: TokenCredential,
audience: str = None
audience: Optional[str] = None
) -> BearerTokenCredentialPolicy:
"""Returns the correct authentication policy"""
if not audience:
Expand All @@ -55,28 +55,34 @@ def order_results(request_order: List, mapping: Dict[str, Any], **kwargs: Any) -
results = []
for item in ordered:
if not item["body"].get("error"):
results.append(
kwargs.get("obj")._from_generated(item["body"]) # pylint: disable=protected-access
)
result_obj = kwargs.get("obj")
if result_obj:
results.append(
result_obj._from_generated(item["body"]) # pylint: disable=protected-access
)
else:
error = item["body"]["error"]
if error.get("code") == "PartialError":
res = kwargs.get("partial_err")._from_generated( # pylint: disable=protected-access
item["body"], kwargs.get("raise_with")
)
results.append(res)
partial_err = kwargs.get("partial_err")
if partial_err:
res = partial_err._from_generated( # pylint: disable=protected-access
item["body"], kwargs.get("raise_with")
)
results.append(res)
else:
results.append(
kwargs.get("err")._from_generated(error) # pylint: disable=protected-access
)
err = kwargs.get("err")
if err:
results.append(
err._from_generated(error) # pylint: disable=protected-access
)
return results


def construct_iso8601(timespan=None):
def construct_iso8601(timespan=None) -> Optional[str]:
if not timespan:
return None
start, end, duration = None, None, None
try:
start, end, duration = None, None, None
if isinstance(timespan[1], datetime): # we treat thi as start_time, end_time
start, end = timespan[0], timespan[1]
elif isinstance(
Expand All @@ -89,25 +95,26 @@ def construct_iso8601(timespan=None):
)
except TypeError:
duration = timespan # it means only duration (timedelta) is provideds
duration_str = ""
if duration:
try:
duration = "PT{}S".format(duration.total_seconds())
duration_str = "PT{}S".format(duration.total_seconds())
except AttributeError:
raise ValueError("timespan must be a timedelta or a tuple.")
iso_str = None
if start is not None:
start = Serializer.serialize_iso(start)
if end is not None:
end = Serializer.serialize_iso(end)
iso_str = start + "/" + end
elif duration is not None:
iso_str = start + "/" + duration
iso_str = f"{start}/{end}"
elif duration_str:
iso_str = f"{start}/{duration_str}"
else: # means that an invalid value None that is provided with start_time
raise ValueError(
"Duration or end_time cannot be None when provided with start_time."
)
else:
iso_str = duration
iso_str = duration_str
return iso_str


Expand All @@ -124,7 +131,7 @@ def native_col_type(col_type, value):
return value


def process_row(col_types, row):
def process_row(col_types, row) -> List[Any]:
return [native_col_type(col_types[ind], val) for ind, val in enumerate(row)]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def query_workspace(
:dedent: 0
:caption: Get a response for a single Log Query
"""
timespan = construct_iso8601(timespan)
timespan_iso = construct_iso8601(timespan)
include_statistics = kwargs.pop("include_statistics", False)
include_visualization = kwargs.pop("include_visualization", False)
server_timeout = kwargs.pop("server_timeout", None)
Expand All @@ -119,7 +119,7 @@ def query_workspace(

body = {
"query": query,
"timespan": timespan,
"timespan": timespan_iso,
"workspaces": additional_workspaces
}

Expand All @@ -131,7 +131,8 @@ def query_workspace(
)
except HttpResponseError as err:
process_error(err, LogsQueryError)
response = None

response: Union[LogsQueryResult, LogsQueryPartialResult]
if not generated_response.get("error"):
response = LogsQueryResult._from_generated( # pylint: disable=protected-access
generated_response
Expand All @@ -140,7 +141,7 @@ def query_workspace(
response = LogsQueryPartialResult._from_generated( # pylint: disable=protected-access
generated_response, LogsQueryError
)
return cast(Union[LogsQueryResult, LogsQueryPartialResult], response)
return response

@distributed_trace
def query_batch(
Expand Down Expand Up @@ -200,5 +201,5 @@ def __enter__(self) -> "LogsQueryClient":
self._client.__enter__() # pylint:disable=no-member
return self

def __exit__(self, *args) -> None:
def __exit__(self, *args: Any) -> None:
self._client.__exit__(*args) # pylint:disable=no-member
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# license information.
# --------------------------------------------------------------------------
# pylint: disable=anomalous-backslash-in-string
from typing import Any, List
from typing import Any, cast, List

from azure.core.credentials import TokenCredential
from azure.core.paging import ItemPaged
Expand Down Expand Up @@ -140,7 +140,7 @@ def list_metric_namespaces(self, resource_uri: str, **kwargs: Any) -> ItemPaged[
start_time = kwargs.pop("start_time", None)
if start_time:
start_time = Serializer.serialize_iso(start_time)
return self._namespace_op.list(
res = self._namespace_op.list(
resource_uri,
start_time=start_time,
cls=kwargs.pop(
Expand All @@ -152,6 +152,7 @@ def list_metric_namespaces(self, resource_uri: str, **kwargs: Any) -> ItemPaged[
),
**kwargs
)
return cast(ItemPaged[MetricNamespace], res)

@distributed_trace
def list_metric_definitions(self, resource_uri: str, **kwargs: Any) -> ItemPaged[MetricDefinition]:
Expand All @@ -166,7 +167,7 @@ def list_metric_definitions(self, resource_uri: str, **kwargs: Any) -> ItemPaged
:raises: ~azure.core.exceptions.HttpResponseError
"""
metric_namespace = kwargs.pop("namespace", None)
return self._definitions_op.list(
res = self._definitions_op.list(
resource_uri,
metricnamespace=metric_namespace,
cls=kwargs.pop(
Expand All @@ -178,6 +179,7 @@ def list_metric_definitions(self, resource_uri: str, **kwargs: Any) -> ItemPaged
),
**kwargs
)
return cast(ItemPaged[MetricDefinition], res)

def close(self) -> None:
"""Close the :class:`~azure.monitor.query.MetricsQueryClient` session."""
Expand Down
Loading

0 comments on commit 83fca5e

Please sign in to comment.