Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Monitor][Query] Improve typing #28175

Merged
merged 4 commits into from
Jan 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/monitor/azure-monitor-query/azure/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore
__path__ = __import__('pkgutil').extend_path(__path__, __name__)
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore
__path__ = __import__('pkgutil').extend_path(__path__, __name__)
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,40 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from typing import Any
import sys
from typing import Any, List, Optional

from ._models import LogsQueryStatus

if sys.version_info >= (3, 9):
from collections.abc import MutableMapping
else:
from typing import MutableMapping # pylint: disable=ungrouped-imports


class LogsQueryError(object):
"""The code and message for an error.
JSON = MutableMapping[str, Any] # pylint: disable=unsubscriptable-object

:ivar code: A machine readable error code.
:vartype code: str
:ivar message: A human readable error message.
:vartype message: str
:ivar details: A list of additional details about the error.
:vartype details: list[JSON]
:ivar status: status for error item when iterating over list of
results. Always "Failure" for an instance of a LogsQueryError.
:vartype status: ~azure.monitor.query.LogsQueryStatus
"""

class LogsQueryError:
"""The code and message for an error."""

code: str
"""A machine readable error code."""
message: str
"""A human readable error message."""
details: Optional[List[JSON]] = None
"""A list of additional details about the error."""
status: LogsQueryStatus
"""Status for error item when iterating over list of results. Always "Failure" for an instance of a
LogsQueryError."""

def __init__(self, **kwargs: Any) -> None:
self.code = kwargs.get("code", None)
self.message = kwargs.get("message", None)
self.code = kwargs.get("code", "")
self.message = kwargs.get("message", "")
self.details = kwargs.get("details", None)
self.status = LogsQueryStatus.FAILURE

def __str__(self):
def __str__(self) -> str:
return str(self.__dict__)

@classmethod
Expand Down
49 changes: 28 additions & 21 deletions sdk/monitor/azure-monitor-query/azure/monitor/query/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# license information.
# --------------------------------------------------------------------------
from datetime import datetime, timedelta
from typing import List, Dict, Any
from typing import List, Dict, Any, Optional

from azure.core.credentials import TokenCredential
from azure.core.exceptions import HttpResponseError
Expand All @@ -16,7 +16,7 @@

def get_authentication_policy(
credential: TokenCredential,
audience: str = None
audience: Optional[str] = None
) -> BearerTokenCredentialPolicy:
"""Returns the correct authentication policy"""
if not audience:
Expand All @@ -34,7 +34,7 @@ def get_authentication_policy(

def get_metrics_authentication_policy(
credential: TokenCredential,
audience: str = None
audience: Optional[str] = None
) -> BearerTokenCredentialPolicy:
"""Returns the correct authentication policy"""
if not audience:
Expand All @@ -55,28 +55,34 @@ def order_results(request_order: List, mapping: Dict[str, Any], **kwargs: Any) -
results = []
for item in ordered:
if not item["body"].get("error"):
results.append(
kwargs.get("obj")._from_generated(item["body"]) # pylint: disable=protected-access
)
result_obj = kwargs.get("obj")
if result_obj:
results.append(
result_obj._from_generated(item["body"]) # pylint: disable=protected-access
)
else:
error = item["body"]["error"]
if error.get("code") == "PartialError":
res = kwargs.get("partial_err")._from_generated( # pylint: disable=protected-access
item["body"], kwargs.get("raise_with")
)
results.append(res)
partial_err = kwargs.get("partial_err")
if partial_err:
res = partial_err._from_generated( # pylint: disable=protected-access
item["body"], kwargs.get("raise_with")
)
results.append(res)
else:
results.append(
kwargs.get("err")._from_generated(error) # pylint: disable=protected-access
)
err = kwargs.get("err")
if err:
results.append(
err._from_generated(error) # pylint: disable=protected-access
)
return results


def construct_iso8601(timespan=None):
def construct_iso8601(timespan=None) -> Optional[str]:
if not timespan:
return None
start, end, duration = None, None, None
try:
start, end, duration = None, None, None
if isinstance(timespan[1], datetime): # we treat thi as start_time, end_time
start, end = timespan[0], timespan[1]
elif isinstance(
Expand All @@ -89,25 +95,26 @@ def construct_iso8601(timespan=None):
)
except TypeError:
duration = timespan # it means only duration (timedelta) is provideds
duration_str = ""
if duration:
try:
duration = "PT{}S".format(duration.total_seconds())
duration_str = "PT{}S".format(duration.total_seconds())
except AttributeError:
raise ValueError("timespan must be a timedelta or a tuple.")
iso_str = None
if start is not None:
start = Serializer.serialize_iso(start)
if end is not None:
end = Serializer.serialize_iso(end)
iso_str = start + "/" + end
elif duration is not None:
iso_str = start + "/" + duration
iso_str = f"{start}/{end}"
elif duration_str:
iso_str = f"{start}/{duration_str}"
else: # means that an invalid value None that is provided with start_time
raise ValueError(
"Duration or end_time cannot be None when provided with start_time."
)
else:
iso_str = duration
iso_str = duration_str
return iso_str


Expand All @@ -124,7 +131,7 @@ def native_col_type(col_type, value):
return value


def process_row(col_types, row):
def process_row(col_types, row) -> List[Any]:
return [native_col_type(col_types[ind], val) for ind, val in enumerate(row)]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def query_workspace(
:dedent: 0
:caption: Get a response for a single Log Query
"""
timespan = construct_iso8601(timespan)
timespan_iso = construct_iso8601(timespan)
include_statistics = kwargs.pop("include_statistics", False)
include_visualization = kwargs.pop("include_visualization", False)
server_timeout = kwargs.pop("server_timeout", None)
Expand All @@ -119,7 +119,7 @@ def query_workspace(

body = {
"query": query,
"timespan": timespan,
"timespan": timespan_iso,
"workspaces": additional_workspaces
}

Expand All @@ -131,7 +131,8 @@ def query_workspace(
)
except HttpResponseError as err:
process_error(err, LogsQueryError)
response = None

response: Union[LogsQueryResult, LogsQueryPartialResult]
if not generated_response.get("error"):
response = LogsQueryResult._from_generated( # pylint: disable=protected-access
generated_response
Expand All @@ -140,7 +141,7 @@ def query_workspace(
response = LogsQueryPartialResult._from_generated( # pylint: disable=protected-access
generated_response, LogsQueryError
)
return cast(Union[LogsQueryResult, LogsQueryPartialResult], response)
return response

@distributed_trace
def query_batch(
Expand Down Expand Up @@ -200,5 +201,5 @@ def __enter__(self) -> "LogsQueryClient":
self._client.__enter__() # pylint:disable=no-member
return self

def __exit__(self, *args) -> None:
def __exit__(self, *args: Any) -> None:
self._client.__exit__(*args) # pylint:disable=no-member
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# license information.
# --------------------------------------------------------------------------
# pylint: disable=anomalous-backslash-in-string
from typing import Any, List
from typing import Any, cast, List

from azure.core.credentials import TokenCredential
from azure.core.paging import ItemPaged
Expand Down Expand Up @@ -140,7 +140,7 @@ def list_metric_namespaces(self, resource_uri: str, **kwargs: Any) -> ItemPaged[
start_time = kwargs.pop("start_time", None)
if start_time:
start_time = Serializer.serialize_iso(start_time)
return self._namespace_op.list(
res = self._namespace_op.list(
resource_uri,
start_time=start_time,
cls=kwargs.pop(
Expand All @@ -152,6 +152,7 @@ def list_metric_namespaces(self, resource_uri: str, **kwargs: Any) -> ItemPaged[
),
**kwargs
)
return cast(ItemPaged[MetricNamespace], res)

@distributed_trace
def list_metric_definitions(self, resource_uri: str, **kwargs: Any) -> ItemPaged[MetricDefinition]:
Expand All @@ -166,7 +167,7 @@ def list_metric_definitions(self, resource_uri: str, **kwargs: Any) -> ItemPaged
:raises: ~azure.core.exceptions.HttpResponseError
"""
metric_namespace = kwargs.pop("namespace", None)
return self._definitions_op.list(
res = self._definitions_op.list(
resource_uri,
metricnamespace=metric_namespace,
cls=kwargs.pop(
Expand All @@ -178,6 +179,7 @@ def list_metric_definitions(self, resource_uri: str, **kwargs: Any) -> ItemPaged
),
**kwargs
)
return cast(ItemPaged[MetricDefinition], res)

def close(self) -> None:
"""Close the :class:`~azure.monitor.query.MetricsQueryClient` session."""
Expand Down
Loading