diff --git a/eng/tox/mypy_hard_failure_packages.py b/eng/tox/mypy_hard_failure_packages.py index 7c5f0a33de2b..c3966c3bf110 100644 --- a/eng/tox/mypy_hard_failure_packages.py +++ b/eng/tox/mypy_hard_failure_packages.py @@ -11,5 +11,6 @@ "azure-servicebus", "azure-ai-textanalytics", "azure-ai-formrecognizer", - "azure-ai-metricsadvisor" + "azure-ai-metricsadvisor", + "azure-eventgrid", ] diff --git a/sdk/eventgrid/azure-eventgrid/azure/__init__.py b/sdk/eventgrid/azure-eventgrid/azure/__init__.py index 69e3be50dac4..0c36c2076ba0 100644 --- a/sdk/eventgrid/azure-eventgrid/azure/__init__.py +++ b/sdk/eventgrid/azure-eventgrid/azure/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) +__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore diff --git a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_consumer.py b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_consumer.py index acc80d774810..532ee81570ed 100644 --- a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_consumer.py +++ b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_consumer.py @@ -6,13 +6,13 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import TYPE_CHECKING +from typing import cast, TYPE_CHECKING import logging from ._models import CloudEvent, EventGridEvent if TYPE_CHECKING: # pylint: disable=unused-import,ungrouped-imports - from typing import Any + from typing import Any, Union _LOGGER = logging.getLogger(__name__) @@ -58,7 +58,7 @@ def decode_eventgrid_event(self, eventgrid_event, **kwargs): # pylint: disable=n eventgrid_event = EventGridEvent._from_json(eventgrid_event, encode) # pylint: disable=protected-access deserialized_event = EventGridEvent.deserialize(eventgrid_event) EventGridEvent._deserialize_data(deserialized_event, deserialized_event.event_type) # pylint: disable=protected-access - return deserialized_event + return cast(EventGridEvent, deserialized_event) except Exception as err: _LOGGER.error('Error: cannot deserialize event. Event does not have a valid format. \ Event must be a string, dict, or bytes following the CloudEvent schema.') diff --git a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_helpers.py b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_helpers.py index 2bd67abd56e3..c9773a40336e 100644 --- a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_helpers.py +++ b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_helpers.py @@ -5,6 +5,7 @@ import hashlib import hmac import base64 +from typing import TYPE_CHECKING, Any try: from urllib.parse import quote except ImportError: @@ -16,8 +17,11 @@ from ._signature_credential_policy import EventGridSharedAccessSignatureCredentialPolicy from . import _constants as constants +if TYPE_CHECKING: + from datetime import datetime + def generate_shared_access_signature(topic_hostname, shared_access_key, expiration_date_utc, **kwargs): - # type: (str, str, datetime.Datetime, Any) -> str + # type: (str, str, datetime, Any) -> str """ Helper method to generate shared access signature given hostname, key, and expiration date. :param str topic_hostname: The topic endpoint to send the events to. Similar to .-1.eventgrid.azure.net @@ -82,7 +86,7 @@ def _get_authentication_policy(credential): return authentication_policy def _is_cloud_event(event): - # type: dict -> bool + # type: (Any) -> bool required = ('id', 'source', 'specversion', 'type') try: return all([_ in event for _ in required]) and event['specversion'] == "1.0" diff --git a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_models.py b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_models.py index a304f488fa51..304b4aa19677 100644 --- a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_models.py +++ b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_models.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- # pylint:disable=protected-access +from typing import Union, Any, Dict import datetime as dt import uuid import json @@ -87,6 +88,7 @@ def __init__(self, source, type, **kwargs): # pylint: disable=redefined-builtin @classmethod def _from_generated(cls, cloud_event, **kwargs): + # type: (Union[str, Dict, bytes], Any) -> CloudEvent generated = InternalCloudEvent.deserialize(cloud_event) if generated.additional_properties: extensions = dict(generated.additional_properties) diff --git a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_policies.py b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_policies.py index b786f280cdb0..5de3c5025249 100644 --- a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_policies.py +++ b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_policies.py @@ -4,11 +4,14 @@ # license information. # -------------------------------------------------------------------------- import json +from typing import TYPE_CHECKING import logging from azure.core.pipeline.policies import SansIOHTTPPolicy _LOGGER = logging.getLogger(__name__) +if TYPE_CHECKING: + from azure.core.pipeline import PipelineRequest class CloudEventDistributedTracingPolicy(SansIOHTTPPolicy): """CloudEventDistributedTracingPolicy is a policy which adds distributed tracing informatiom diff --git a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_publisher_client.py b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_publisher_client.py index 806924e603d2..b7648b587e61 100644 --- a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_publisher_client.py +++ b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_publisher_client.py @@ -5,7 +5,7 @@ # license information. # -------------------------------------------------------------------------- -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast, Dict, List, Any, Union from azure.core.tracing.decorator import distributed_trace from azure.core.pipeline.policies import ( @@ -27,10 +27,12 @@ from ._generated._event_grid_publisher_client import EventGridPublisherClient as EventGridPublisherClientImpl from ._policies import CloudEventDistributedTracingPolicy from ._version import VERSION +from ._generated.models import CloudEvent as InternalCloudEvent, EventGridEvent as InternalEventGridEvent if TYPE_CHECKING: # pylint: disable=unused-import,ungrouped-imports - from typing import Any, Union, Dict, List + from azure.core.credentials import AzureKeyCredential + from ._shared_access_signature_credential import EventGridSharedAccessSignatureCredential SendType = Union[ CloudEvent, EventGridEvent, @@ -42,6 +44,13 @@ List[Dict] ] +ListEventType = Union[ + List[CloudEvent], + List[EventGridEvent], + List[CustomEvent], + List[Dict] +] + class EventGridPublisherClient(object): """EventGrid Python Publisher Client. @@ -79,7 +88,7 @@ def _policies(credential, **kwargs): CustomHookPolicy(**kwargs), NetworkTraceLoggingPolicy(**kwargs), DistributedTracingPolicy(**kwargs), - CloudEventDistributedTracingPolicy(**kwargs), + CloudEventDistributedTracingPolicy(), HttpLoggingPolicy(**kwargs) ] return policies @@ -98,20 +107,24 @@ def send(self, events, **kwargs): :raises: :class:`ValueError`, when events do not follow specified SendType. """ if not isinstance(events, list): - events = [events] + events = cast(ListEventType, [events]) if all(isinstance(e, CloudEvent) for e in events) or all(_is_cloud_event(e) for e in events): try: - events = [e._to_generated(**kwargs) for e in events] # pylint: disable=protected-access + events = [cast(CloudEvent, e)._to_generated(**kwargs) for e in events] # pylint: disable=protected-access except AttributeError: pass # means it's a dictionary kwargs.setdefault("content_type", "application/cloudevents-batch+json; charset=utf-8") - self._client.publish_cloud_event_events(self._topic_hostname, events, **kwargs) + self._client.publish_cloud_event_events( + self._topic_hostname, + cast(List[InternalCloudEvent], events), + **kwargs + ) elif all(isinstance(e, EventGridEvent) for e in events) or all(isinstance(e, dict) for e in events): kwargs.setdefault("content_type", "application/json; charset=utf-8") - self._client.publish_events(self._topic_hostname, events, **kwargs) + self._client.publish_events(self._topic_hostname, cast(List[InternalEventGridEvent], events), **kwargs) elif all(isinstance(e, CustomEvent) for e in events): - serialized_events = [dict(e) for e in events] - self._client.publish_custom_event_events(self._topic_hostname, serialized_events, **kwargs) + serialized_events = [dict(e) for e in events] # type: ignore + self._client.publish_custom_event_events(self._topic_hostname, cast(List, serialized_events), **kwargs) else: raise ValueError("Event schema is not correct.") diff --git a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_signature_credential_policy.py b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_signature_credential_policy.py index 3ae58fbf4a08..d210fe70b626 100644 --- a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_signature_credential_policy.py +++ b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_signature_credential_policy.py @@ -4,10 +4,15 @@ # license information. # ------------------------------------------------------------------------- +from typing import Any, TYPE_CHECKING import six from azure.core.pipeline.policies import SansIOHTTPPolicy +if TYPE_CHECKING: + from ._shared_access_signature_credential import EventGridSharedAccessSignatureCredential + + class EventGridSharedAccessSignatureCredentialPolicy(SansIOHTTPPolicy): """Adds a token header for the provided credential. :param credential: The credential used to authenticate requests. diff --git a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/aio/_publisher_client_async.py b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/aio/_publisher_client_async.py index 218dea4edfb0..b2644f7aaa43 100644 --- a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/aio/_publisher_client_async.py +++ b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/aio/_publisher_client_async.py @@ -6,7 +6,7 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import Any, Union, List, Dict +from typing import Any, Union, List, Dict, cast from azure.core.credentials import AzureKeyCredential from azure.core.tracing.decorator_async import distributed_trace_async from azure.core.pipeline.policies import ( @@ -26,19 +26,27 @@ from .._models import CloudEvent, EventGridEvent, CustomEvent from .._helpers import _get_topic_hostname_only_fqdn, _get_authentication_policy, _is_cloud_event from .._generated.aio import EventGridPublisherClient as EventGridPublisherClientAsync +from .._generated.models import CloudEvent as InternalCloudEvent, EventGridEvent as InternalEventGridEvent from .._shared_access_signature_credential import EventGridSharedAccessSignatureCredential from .._version import VERSION SendType = Union[ - CloudEvent, - EventGridEvent, - CustomEvent, - Dict, - List[CloudEvent], - List[EventGridEvent], - List[CustomEvent], - List[Dict] - ] + CloudEvent, + EventGridEvent, + CustomEvent, + Dict, + List[CloudEvent], + List[EventGridEvent], + List[CustomEvent], + List[Dict] +] + +ListEventType = Union[ + List[CloudEvent], + List[EventGridEvent], + List[CustomEvent], + List[Dict] +] class EventGridPublisherClient(): """Asynchronous EventGrid Python Publisher Client. @@ -101,20 +109,34 @@ async def send( :raises: :class:`ValueError`, when events do not follow specified SendType. """ if not isinstance(events, list): - events = [events] + events = cast(ListEventType, [events]) if all(isinstance(e, CloudEvent) for e in events) or all(_is_cloud_event(e) for e in events): try: - events = [e._to_generated(**kwargs) for e in events] # pylint: disable=protected-access + events = [ + cast(CloudEvent, e)._to_generated(**kwargs) for e in events # pylint: disable=protected-access + ] except AttributeError: pass # means it's a dictionary kwargs.setdefault("content_type", "application/cloudevents-batch+json; charset=utf-8") - await self._client.publish_cloud_event_events(self._topic_hostname, events, **kwargs) + await self._client.publish_cloud_event_events( + self._topic_hostname, + cast(List[InternalCloudEvent], events), + **kwargs + ) elif all(isinstance(e, EventGridEvent) for e in events) or all(isinstance(e, dict) for e in events): kwargs.setdefault("content_type", "application/json; charset=utf-8") - await self._client.publish_events(self._topic_hostname, events, **kwargs) + await self._client.publish_events( + self._topic_hostname, + cast(List[InternalEventGridEvent], events), + **kwargs + ) elif all(isinstance(e, CustomEvent) for e in events): - serialized_events = [dict(e) for e in events] - await self._client.publish_custom_event_events(self._topic_hostname, serialized_events, **kwargs) + serialized_events = [dict(e) for e in events] # type: ignore + await self._client.publish_custom_event_events( + self._topic_hostname, + cast(List, serialized_events), + **kwargs + ) else: raise ValueError("Event schema is not correct.") diff --git a/sdk/eventgrid/azure-eventgrid/mypy.ini b/sdk/eventgrid/azure-eventgrid/mypy.ini new file mode 100644 index 000000000000..b8d3b2b62839 --- /dev/null +++ b/sdk/eventgrid/azure-eventgrid/mypy.ini @@ -0,0 +1,13 @@ +[mypy] +python_version = 3.7 +warn_return_any = True +warn_unused_configs = True +ignore_missing_imports = True + +# Per-module options: + +[mypy-azure.eventgrid._generated.*] +ignore_errors = True + +[mypy-azure.core.*] +ignore_errors = True