diff --git a/sdk/servicebus/azure-servicebus/CHANGELOG.md b/sdk/servicebus/azure-servicebus/CHANGELOG.md index 77e523c803eb..482fab2a28c7 100644 --- a/sdk/servicebus/azure-servicebus/CHANGELOG.md +++ b/sdk/servicebus/azure-servicebus/CHANGELOG.md @@ -4,16 +4,16 @@ **New Features** -* Updated the following methods so that lists and single instances of dict representations are accepted for corresponding strongly-typed object arguments (PR #14807, thanks @bradleydamato): - - `update_queue`, `update_topic`, `update_subscription`, and `update_rule` on `ServiceBusAdministrationClient` accept dict representations of `QueueProperties`, `TopicProperties`, `SubscriptionProperties`, and `RuleProperties`, respectively. - - `send_messages` and `schedule_messages` on both sync and async versions of `ServiceBusSender` accept a list of or single instance of dict representations of `ServiceBusMessage`. - - `add_message` on `ServiceBusMessageBatch` now accepts a dict representation of `ServiceBusMessage`. - - Note: This is ongoing work and is the first step in supporting the above as respresentation of type `typing.Mapping`. +* Updated the following methods so that lists and single instances of Mapping representations are accepted for corresponding strongly-typed object arguments (PR #14807, thanks @bradleydamato): + - `update_queue`, `update_topic`, `update_subscription`, and `update_rule` on `ServiceBusAdministrationClient` accept Mapping representations of `QueueProperties`, `TopicProperties`, `SubscriptionProperties`, and `RuleProperties`, respectively. + - `send_messages` and `schedule_messages` on both sync and async versions of `ServiceBusSender` accept a list of or single instance of Mapping representations of `ServiceBusMessage`. + - `add_message` on `ServiceBusMessageBatch` now accepts a Mapping representation of `ServiceBusMessage`. **BugFixes** * Operations failing due to `uamqp.errors.LinkForceDetach` caused by no activity on the connection for 10 minutes will now be retried internally except for the session receiver case. * `uamqp.errors.AMQPConnectionError` errors with condition code `amqp:unknown-error` are now categorized into `ServiceBusConnectionError` instead of the general `ServiceBusError`. +* The `update_*` methods on `ServiceBusManagementClient` will now raise a `TypeError` rather than an `AttributeError` in the case of unsupported input type. ## 7.0.1 (2021-01-12) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/message.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/message.py index a9de43b03a26..15e308b7b6fe 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/message.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/message.py @@ -9,7 +9,7 @@ import uuid import logging import copy -from typing import Optional, List, Union, Iterable, TYPE_CHECKING, Any +from typing import Optional, List, Union, Iterable, TYPE_CHECKING, Any, Mapping, cast import six @@ -537,15 +537,8 @@ def __len__(self): def _from_list(self, messages, parent_span=None): # type: (Iterable[ServiceBusMessage], AbstractSpan) -> None - for each in messages: - if not isinstance(each, (ServiceBusMessage, dict)): - raise TypeError( - "Only ServiceBusMessage or an iterable object containing ServiceBusMessage " - "objects are accepted. Received instead: {}".format( - each.__class__.__name__ - ) - ) - self._add(each, parent_span) + for message in messages: + self._add(message, parent_span) @property def max_size_in_bytes(self): @@ -566,7 +559,7 @@ def size_in_bytes(self): return self._size def add_message(self, message): - # type: (ServiceBusMessage) -> None + # type: (Union[ServiceBusMessage, Mapping[str, Any]]) -> None """Try to add a single Message to the batch. The total size of an added message is the sum of its body, properties, etc. @@ -581,12 +574,12 @@ def add_message(self, message): return self._add(message) - def _add(self, message, parent_span=None): - # type: (ServiceBusMessage, AbstractSpan) -> None + def _add(self, add_message, parent_span=None): + # type: (Union[ServiceBusMessage, Mapping[str, Any]], AbstractSpan) -> None """Actual add implementation. The shim exists to hide the internal parameters such as parent_span.""" - - message = create_messages_from_dicts_if_needed(message, ServiceBusMessage) # type: ignore + message = create_messages_from_dicts_if_needed(add_message, ServiceBusMessage) message = transform_messages_to_sendable_if_needed(message) + message = cast(ServiceBusMessage, message) trace_message( message, parent_span ) # parent_span is e.g. if built as part of a send operation. diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/utils.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/utils.py index 7c895aa43dcb..243f35c432eb 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/utils.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/utils.py @@ -19,7 +19,8 @@ Optional, Type, TYPE_CHECKING, - Union + Union, + cast ) from contextlib import contextmanager from msrest.serialization import UTC @@ -59,19 +60,10 @@ from .receiver_mixins import ReceiverMixin from .._servicebus_session import BaseSession - # pylint: disable=unused-import, ungrouped-imports - DictMessageType = Union[ - Mapping, + MessagesType = Union[ + Mapping[str, Any], ServiceBusMessage, - List[Mapping[str, Any]], - List[ServiceBusMessage], - ServiceBusMessageBatch - ] - - DictMessageReturnType = Union[ - ServiceBusMessage, - List[ServiceBusMessage], - ServiceBusMessageBatch + List[Union[Mapping[str, Any], ServiceBusMessage]] ] _log = logging.getLogger(__name__) @@ -222,20 +214,37 @@ def transform_messages_to_sendable_if_needed(messages): except AttributeError: return messages + +def _single_message_from_dict(message, message_type): + # type: (Union[ServiceBusMessage, Mapping[str, Any]], Type[ServiceBusMessage]) -> ServiceBusMessage + if isinstance(message, message_type): + return message + try: + return message_type(**cast(Mapping[str, Any], message)) + except TypeError: + raise TypeError( + "Only ServiceBusMessage instances or Mappings representing messages are supported. " + "Received instead: {}".format( + message.__class__.__name__ + ) + ) + + def create_messages_from_dicts_if_needed(messages, message_type): - # type: (DictMessageType, type) -> DictMessageReturnType + # type: (MessagesType, Type[ServiceBusMessage]) -> Union[ServiceBusMessage, List[ServiceBusMessage]] """ - This method is used to convert dict representations - of messages to a list of ServiceBusMessage objects or ServiceBusBatchMessage. - :param DictMessageType messages: A list or single instance of messages of type ServiceBusMessages or - dict representations of type ServiceBusMessage. Also accepts ServiceBusBatchMessage. - :rtype: DictMessageReturnType + This method is used to convert dict representations of one or more messages to + one or more ServiceBusMessage objects. + + :param Messages messages: A list or single instance of messages of type ServiceBusMessage or + dict representations of type ServiceBusMessage. + :param Type[ServiceBusMessage] message_type: The class type to return the messages as. + :rtype: Union[ServiceBusMessage, List[ServiceBusMessage]] """ if isinstance(messages, list): - return [(message_type(**message) if isinstance(message, dict) else message) for message in messages] + return [_single_message_from_dict(m, message_type) for m in messages] + return _single_message_from_dict(messages, message_type) - return_messages = message_type(**messages) if isinstance(messages, dict) else messages - return return_messages def strip_protocol_from_uri(uri): # type: (str) -> str diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_sender.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_sender.py index fee2b3d4ed91..5c18faa98b90 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_sender.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_sender.py @@ -5,7 +5,7 @@ import logging import time import uuid -from typing import Any, TYPE_CHECKING, Union, List, Optional +from typing import Any, TYPE_CHECKING, Union, List, Optional, Mapping, cast import uamqp from uamqp import SendClient, types @@ -42,6 +42,16 @@ import datetime from azure.core.credentials import TokenCredential + MessageTypes = Union[ + Mapping[str, Any], + ServiceBusMessage, + List[Union[Mapping[str, Any], ServiceBusMessage]] + ] + MessageObjTypes = Union[ + ServiceBusMessage, + ServiceBusMessageBatch, + List[ServiceBusMessage]] + _LOGGER = logging.getLogger(__name__) @@ -248,7 +258,7 @@ def _send(self, message, timeout=None, last_exception=None): self._set_msg_timeout(default_timeout, None) def schedule_messages(self, messages, schedule_time_utc, **kwargs): - # type: (Union[ServiceBusMessage, List[ServiceBusMessage]], datetime.datetime, Any) -> List[int] + # type: (MessageTypes, datetime.datetime, Any) -> List[int] """Send Message or multiple Messages to be enqueued at a specific time. Returns a list of the sequence numbers of the enqueued messages. @@ -272,21 +282,21 @@ def schedule_messages(self, messages, schedule_time_utc, **kwargs): # pylint: disable=protected-access self._check_live() - messages = create_messages_from_dicts_if_needed(messages, ServiceBusMessage) # type: ignore + obj_messages = create_messages_from_dicts_if_needed(messages, ServiceBusMessage) timeout = kwargs.pop("timeout", None) if timeout is not None and timeout <= 0: raise ValueError("The timeout must be greater than 0.") with send_trace_context_manager(span_name=SPAN_NAME_SCHEDULE) as send_span: - if isinstance(messages, ServiceBusMessage): + if isinstance(obj_messages, ServiceBusMessage): request_body = self._build_schedule_request( - schedule_time_utc, send_span, messages + schedule_time_utc, send_span, obj_messages ) else: - if len(messages) == 0: + if len(obj_messages) == 0: return [] # No-op on empty list. request_body = self._build_schedule_request( - schedule_time_utc, send_span, *messages + schedule_time_utc, send_span, *obj_messages ) if send_span: self._add_span_request_attributes(send_span) @@ -338,7 +348,7 @@ def cancel_scheduled_messages(self, sequence_numbers, **kwargs): ) def send_messages(self, message, **kwargs): - # type: (Union[ServiceBusMessage, ServiceBusMessageBatch, List[ServiceBusMessage]], Any) -> None + # type: (Union[MessageTypes, ServiceBusMessageBatch], Any) -> None """Sends message and blocks until acknowledgement is received or operation times out. If a list of messages was provided, attempts to send them as a single batch, throwing a @@ -368,48 +378,44 @@ def send_messages(self, message, **kwargs): :caption: Send message. """ - self._check_live() - message = create_messages_from_dicts_if_needed(message, ServiceBusMessage) timeout = kwargs.pop("timeout", None) if timeout is not None and timeout <= 0: raise ValueError("The timeout must be greater than 0.") with send_trace_context_manager() as send_span: - # Ensure message is sendable (not a ReceivedMessage), and if needed (a list) is batched. Adds tracing. - message = transform_messages_to_sendable_if_needed(message) - try: - for each_message in iter(message): # type: ignore # Ignore type (and below) as it will except if wrong. - add_link_to_send(each_message, send_span) - batch = self.create_message_batch() - batch._from_list(message, send_span) # type: ignore # pylint: disable=protected-access - message = batch - except TypeError: # Message was not a list or generator. Do needed tracing. - if isinstance(message, ServiceBusMessageBatch): - for ( - batch_message - ) in message.message._body_gen: # pylint: disable=protected-access - add_link_to_send(batch_message, send_span) - elif isinstance(message, ServiceBusMessage): - trace_message(message, send_span) - add_link_to_send(message, send_span) + if isinstance(message, ServiceBusMessageBatch): + for ( + batch_message + ) in message.message._body_gen: # pylint: disable=protected-access + add_link_to_send(batch_message, send_span) + obj_message = message # type: MessageObjTypes + else: + obj_message = create_messages_from_dicts_if_needed(message, ServiceBusMessage) + # Ensure message is sendable (not a ReceivedMessage), and if needed (a list) is batched. Adds tracing. + obj_message = transform_messages_to_sendable_if_needed(obj_message) + try: + # Ignore type (and below) as it will except if wrong. + for each_message in iter(obj_message): # type: ignore + add_link_to_send(each_message, send_span) + batch = self.create_message_batch() + batch._from_list(obj_message, send_span) # type: ignore # pylint: disable=protected-access + obj_message = batch + except TypeError: # Message was not a list or generator. Do needed tracing. + trace_message(cast(ServiceBusMessage, obj_message), send_span) + add_link_to_send(obj_message, send_span) if ( - isinstance(message, ServiceBusMessageBatch) and len(message) == 0 + isinstance(obj_message, ServiceBusMessageBatch) and len(obj_message) == 0 ): # pylint: disable=len-as-condition return # Short circuit noop if an empty list or batch is provided. - if not isinstance(message, (ServiceBusMessageBatch, ServiceBusMessage)): - raise TypeError( - "Can only send azure.servicebus. " - "or lists of ServiceBusMessage." - ) if send_span: self._add_span_request_attributes(send_span) self._do_retryable_operation( self._send, - message=message, + message=obj_message, timeout=timeout, operation_requires_timeout=True, require_last_exception=True, diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_sender_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_sender_async.py index 1e71b1255557..6881ed24b10f 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_sender_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_sender_async.py @@ -5,7 +5,7 @@ import logging import asyncio import datetime -from typing import Any, TYPE_CHECKING, Union, List, Optional +from typing import Any, TYPE_CHECKING, Union, List, Optional, Mapping, cast import uamqp from uamqp import SendClientAsync, types @@ -32,6 +32,17 @@ if TYPE_CHECKING: from azure.core.credentials import TokenCredential + +MessageTypes = Union[ + Mapping[str, Any], + ServiceBusMessage, + List[Union[Mapping[str, Any], ServiceBusMessage]] +] +MessageObjTypes = Union[ + ServiceBusMessage, + ServiceBusMessageBatch, + List[ServiceBusMessage]] + _LOGGER = logging.getLogger(__name__) @@ -182,7 +193,7 @@ async def _send(self, message, timeout=None, last_exception=None): async def schedule_messages( self, - messages: Union[ServiceBusMessage, List[ServiceBusMessage]], + messages: MessageTypes, schedule_time_utc: datetime.datetime, **kwargs: Any ) -> List[int]: @@ -209,20 +220,20 @@ async def schedule_messages( # pylint: disable=protected-access self._check_live() - messages = create_messages_from_dicts_if_needed(messages, ServiceBusMessage) # type: ignore + obj_messages = create_messages_from_dicts_if_needed(messages, ServiceBusMessage) timeout = kwargs.pop("timeout", None) if timeout is not None and timeout <= 0: raise ValueError("The timeout must be greater than 0.") with send_trace_context_manager(span_name=SPAN_NAME_SCHEDULE) as send_span: - if isinstance(messages, ServiceBusMessage): + if isinstance(obj_messages, ServiceBusMessage): request_body = self._build_schedule_request( - schedule_time_utc, send_span, messages + schedule_time_utc, send_span, obj_messages ) else: - if len(messages) == 0: + if len(obj_messages) == 0: return [] # No-op on empty list. request_body = self._build_schedule_request( - schedule_time_utc, send_span, *messages + schedule_time_utc, send_span, *obj_messages ) if send_span: await self._add_span_request_attributes(send_span) @@ -276,9 +287,7 @@ async def cancel_scheduled_messages( async def send_messages( self, - message: Union[ - ServiceBusMessage, ServiceBusMessageBatch, List[ServiceBusMessage] - ], + message: Union[MessageTypes, ServiceBusMessageBatch], **kwargs: Any ) -> None: """Sends message and blocks until acknowledgement is received or operation times out. @@ -312,44 +321,41 @@ async def send_messages( """ self._check_live() - message = create_messages_from_dicts_if_needed(message, ServiceBusMessage) timeout = kwargs.pop("timeout", None) if timeout is not None and timeout <= 0: raise ValueError("The timeout must be greater than 0.") with send_trace_context_manager() as send_span: - message = transform_messages_to_sendable_if_needed(message) - try: - for each_message in iter(message): # type: ignore # Ignore type (and below) as it will except if wrong. - add_link_to_send(each_message, send_span) - batch = await self.create_message_batch() - batch._from_list(message, send_span) # type: ignore # pylint: disable=protected-access - message = batch - except TypeError: # Message was not a list or generator. - if isinstance(message, ServiceBusMessageBatch): - for ( - batch_message - ) in message.message._body_gen: # pylint: disable=protected-access - add_link_to_send(batch_message, send_span) - elif isinstance(message, ServiceBusMessage): - trace_message(message, send_span) - add_link_to_send(message, send_span) + if isinstance(message, ServiceBusMessageBatch): + for ( + batch_message + ) in message.message._body_gen: # pylint: disable=protected-access + add_link_to_send(batch_message, send_span) + obj_message = message # type: MessageObjTypes + else: + obj_message = create_messages_from_dicts_if_needed(message, ServiceBusMessage) + obj_message = transform_messages_to_sendable_if_needed(obj_message) + try: + # Ignore type (and below) as it will except if wrong. + for each_message in iter(obj_message): # type: ignore + add_link_to_send(each_message, send_span) + batch = await self.create_message_batch() + batch._from_list(obj_message, send_span) # type: ignore # pylint: disable=protected-access + obj_message = batch + except TypeError: # Message was not a list or generator. + trace_message(cast(ServiceBusMessage, obj_message), send_span) + add_link_to_send(obj_message, send_span) if ( - isinstance(message, ServiceBusMessageBatch) and len(message) == 0 + isinstance(obj_message, ServiceBusMessageBatch) and len(obj_message) == 0 ): # pylint: disable=len-as-condition return # Short circuit noop if an empty list or batch is provided. - if not isinstance(message, (ServiceBusMessageBatch, ServiceBusMessage)): - raise TypeError( - "Can only send azure.servicebus. " - "or lists of ServiceBusMessage." - ) if send_span: await self._add_span_request_attributes(send_span) await self._do_retryable_operation( self._send, - message=message, + message=obj_message, timeout=timeout, operation_requires_timeout=True, require_last_exception=True, diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/management/_management_client_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/management/_management_client_async.py index 8f774c947846..cc7e7e1e4c27 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/management/_management_client_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/management/_management_client_async.py @@ -6,7 +6,7 @@ # pylint:disable=specify-parameter-names-in-call # pylint:disable=too-many-lines import functools -from typing import TYPE_CHECKING, Any, Union, cast +from typing import TYPE_CHECKING, Any, Union, cast, Mapping from xml.etree.ElementTree import ElementTree from azure.core.async_paging import AsyncItemPaged @@ -399,7 +399,7 @@ async def create_queue(self, queue_name: str, **kwargs) -> QueueProperties: ) return result - async def update_queue(self, queue: QueueProperties, **kwargs) -> None: + async def update_queue(self, queue: Union[QueueProperties, Mapping[str, Any]], **kwargs) -> None: """Update a queue. Before calling this method, you should use `get_queue`, `create_queue` or `list_queues` to get a @@ -412,7 +412,7 @@ async def update_queue(self, queue: QueueProperties, **kwargs) -> None: :rtype: None """ - queue = create_properties_from_dict_if_needed(queue, QueueProperties) # type: ignore + queue = create_properties_from_dict_if_needed(queue, QueueProperties) to_update = queue._to_internal_entity() to_update.default_message_time_to_live = avoid_timedelta_overflow( @@ -626,7 +626,7 @@ async def create_topic(self, topic_name: str, **kwargs) -> TopicProperties: ) return result - async def update_topic(self, topic: TopicProperties, **kwargs) -> None: + async def update_topic(self, topic: Union[TopicProperties, Mapping[str, Any]], **kwargs) -> None: """Update a topic. Before calling this method, you should use `get_topic`, `create_topic` or `list_topics` to get a @@ -639,7 +639,7 @@ async def update_topic(self, topic: TopicProperties, **kwargs) -> None: :rtype: None """ - topic = create_properties_from_dict_if_needed(topic, TopicProperties) # type: ignore + topic = create_properties_from_dict_if_needed(topic, TopicProperties) to_update = topic._to_internal_entity() to_update.default_message_time_to_live = avoid_timedelta_overflow( @@ -872,7 +872,7 @@ async def create_subscription( return result async def update_subscription( - self, topic_name: str, subscription: SubscriptionProperties, **kwargs + self, topic_name: str, subscription: Union[SubscriptionProperties, Mapping[str, Any]], **kwargs ) -> None: """Update a subscription. @@ -887,7 +887,7 @@ async def update_subscription( _validate_entity_name_type(topic_name, display_name="topic_name") - subscription = create_properties_from_dict_if_needed(subscription, SubscriptionProperties) # type: ignore + subscription = create_properties_from_dict_if_needed(subscription, SubscriptionProperties) to_update = subscription._to_internal_entity() to_update.default_message_time_to_live = avoid_timedelta_overflow( @@ -1068,7 +1068,7 @@ async def create_rule( return result async def update_rule( - self, topic_name: str, subscription_name: str, rule: RuleProperties, **kwargs + self, topic_name: str, subscription_name: str, rule: Union[RuleProperties, Mapping[str, Any]], **kwargs ) -> None: """Update a rule. @@ -1085,7 +1085,7 @@ async def update_rule( """ _validate_topic_and_subscription_types(topic_name, subscription_name) - rule = create_properties_from_dict_if_needed(rule, RuleProperties) # type: ignore + rule = create_properties_from_dict_if_needed(rule, RuleProperties) to_update = rule._to_internal_entity() create_entity_body = CreateRuleBody( diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/management/_management_client.py b/sdk/servicebus/azure-servicebus/azure/servicebus/management/_management_client.py index eb4500c6916c..38a679ef84f9 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/management/_management_client.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/management/_management_client.py @@ -6,7 +6,7 @@ # pylint:disable=specify-parameter-names-in-call # pylint:disable=too-many-lines import functools -from typing import TYPE_CHECKING, Dict, Any, Union, cast +from typing import TYPE_CHECKING, Dict, Any, Union, cast, Mapping from xml.etree.ElementTree import ElementTree from azure.core.paging import ItemPaged @@ -392,7 +392,7 @@ def create_queue(self, queue_name, **kwargs): return result def update_queue(self, queue, **kwargs): - # type: (QueueProperties, Any) -> None + # type: (Union[QueueProperties, Mapping], Any) -> None """Update a queue. Before calling this method, you should use `get_queue`, `create_queue` or `list_queues` to get a @@ -405,7 +405,7 @@ def update_queue(self, queue, **kwargs): :rtype: None """ - queue = create_properties_from_dict_if_needed(queue, QueueProperties) # type: ignore + queue = create_properties_from_dict_if_needed(queue, QueueProperties) to_update = queue._to_internal_entity() to_update.default_message_time_to_live = avoid_timedelta_overflow( @@ -621,7 +621,7 @@ def create_topic(self, topic_name, **kwargs): return result def update_topic(self, topic, **kwargs): - # type: (TopicProperties, Any) -> None + # type: (Union[TopicProperties, Mapping[str, Any]], Any) -> None """Update a topic. Before calling this method, you should use `get_topic`, `create_topic` or `list_topics` to get a @@ -634,7 +634,7 @@ def update_topic(self, topic, **kwargs): :rtype: None """ - topic = create_properties_from_dict_if_needed(topic, TopicProperties) # type: ignore + topic = create_properties_from_dict_if_needed(topic, TopicProperties) to_update = topic._to_internal_entity() to_update.default_message_time_to_live = ( @@ -876,7 +876,7 @@ def create_subscription(self, topic_name, subscription_name, **kwargs): return result def update_subscription(self, topic_name, subscription, **kwargs): - # type: (str, SubscriptionProperties, Any) -> None + # type: (str, Union[SubscriptionProperties, Mapping[str, Any]], Any) -> None """Update a subscription. Before calling this method, you should use `get_subscription`, `update_subscription` or `list_subscription` @@ -1065,7 +1065,7 @@ def create_rule(self, topic_name, subscription_name, rule_name, **kwargs): return result def update_rule(self, topic_name, subscription_name, rule, **kwargs): - # type: (str, str, RuleProperties, Any) -> None + # type: (str, str, Union[RuleProperties, Mapping[str, Any]], Any) -> None """Update a rule. Before calling this method, you should use `get_rule`, `create_rule` or `list_rules` to get a `RuleProperties` @@ -1082,7 +1082,7 @@ def update_rule(self, topic_name, subscription_name, rule, **kwargs): """ _validate_topic_and_subscription_types(topic_name, subscription_name) - rule = create_properties_from_dict_if_needed(rule, RuleProperties) # type: ignore + rule = create_properties_from_dict_if_needed(rule, RuleProperties) to_update = rule._to_internal_entity() create_entity_body = CreateRuleBody( diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/management/_utils.py b/sdk/servicebus/azure-servicebus/azure/servicebus/management/_utils.py index 4d106036104f..9a1f65741e27 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/management/_utils.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/management/_utils.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- from datetime import datetime, timedelta -from typing import TYPE_CHECKING, cast, Union, Mapping +from typing import TYPE_CHECKING, cast, Union, Mapping, Type, Any from xml.etree.ElementTree import ElementTree, SubElement, QName import isodate import six @@ -12,22 +12,17 @@ from ._handle_response_error import _handle_response_error if TYPE_CHECKING: # pylint: disable=unused-import, ungrouped-imports + from typing import TypeVar from ._models import QueueProperties, TopicProperties, \ SubscriptionProperties, RuleProperties, InternalQueueDescription, InternalTopicDescription, \ InternalSubscriptionDescription, InternalRuleDescription - DictPropertiesType = Union[ - QueueProperties, - TopicProperties, - SubscriptionProperties, - RuleProperties, - Mapping - ] - DictPropertiesReturnType = Union[ + PropertiesType = TypeVar( + 'PropertiesType', QueueProperties, TopicProperties, SubscriptionProperties, RuleProperties - ] + ) # Refer to the async version of this module under ..\aio\management\_utils.py for detailed explanation. @@ -326,14 +321,24 @@ def _validate_topic_subscription_and_rule_types( ) def create_properties_from_dict_if_needed(properties, sb_resource_type): - # type: (DictPropertiesType, type) -> DictPropertiesReturnType + # type: (Union[PropertiesType, Mapping[str, Any]], Type[PropertiesType]) -> PropertiesType """ This method is used to create a properties object given the resource properties type and its corresponding dict representation. :param properties: A properties object or its dict representation. - :type properties: DictPropertiesType + :type properties: Mapping or PropertiesType :param type sb_resource_type: The type of properties object. - :rtype: DictPropertiesReturnType + :rtype: PropertiesType """ - return_properties = sb_resource_type(**properties) if isinstance(properties, dict) else properties - return return_properties + if isinstance(properties, sb_resource_type): + return properties + try: + return sb_resource_type(**cast(Mapping[str, Any], properties)) + except TypeError as e: + if "required keyword arguments" in str(e): + raise e + raise TypeError( + "Update input must be an instance of {}, or a mapping representing one.".format( + sb_resource_type.__name__ + ) + ) diff --git a/sdk/servicebus/azure-servicebus/tests/async_tests/mgmt_tests/test_mgmt_queues_async.py b/sdk/servicebus/azure-servicebus/tests/async_tests/mgmt_tests/test_mgmt_queues_async.py index 6422dd046823..7b81199792d8 100644 --- a/sdk/servicebus/azure-servicebus/tests/async_tests/mgmt_tests/test_mgmt_queues_async.py +++ b/sdk/servicebus/azure-servicebus/tests/async_tests/mgmt_tests/test_mgmt_queues_async.py @@ -329,11 +329,11 @@ async def test_async_mgmt_queue_update_invalid(self, servicebus_namespace_connec queue_description = await mgmt_service.create_queue(queue_name) try: # handle a null update properly. - with pytest.raises(AttributeError): + with pytest.raises(TypeError): await mgmt_service.update_queue(None) # handle an invalid type update properly. - with pytest.raises(AttributeError): + with pytest.raises(TypeError): await mgmt_service.update_queue(Exception("test")) # change a setting we can't change; should fail. diff --git a/sdk/servicebus/azure-servicebus/tests/async_tests/mgmt_tests/test_mgmt_rules_async.py b/sdk/servicebus/azure-servicebus/tests/async_tests/mgmt_tests/test_mgmt_rules_async.py index 5c57c152da4d..9112775ccba6 100644 --- a/sdk/servicebus/azure-servicebus/tests/async_tests/mgmt_tests/test_mgmt_rules_async.py +++ b/sdk/servicebus/azure-servicebus/tests/async_tests/mgmt_tests/test_mgmt_rules_async.py @@ -171,11 +171,11 @@ async def test_async_mgmt_rule_update_invalid(self, servicebus_namespace_connect rule_desc = await mgmt_service.get_rule(topic_name, subscription_name, rule_name) # handle a null update properly. - with pytest.raises(AttributeError): + with pytest.raises(TypeError): await mgmt_service.update_rule(topic_name, subscription_name, None) # handle an invalid type update properly. - with pytest.raises(AttributeError): + with pytest.raises(TypeError): await mgmt_service.update_rule(topic_name, subscription_name, Exception("test")) # change the name to a topic that doesn't exist; should fail. diff --git a/sdk/servicebus/azure-servicebus/tests/async_tests/mgmt_tests/test_mgmt_subscriptions_async.py b/sdk/servicebus/azure-servicebus/tests/async_tests/mgmt_tests/test_mgmt_subscriptions_async.py index 4fcafe344665..57568d833fd4 100644 --- a/sdk/servicebus/azure-servicebus/tests/async_tests/mgmt_tests/test_mgmt_subscriptions_async.py +++ b/sdk/servicebus/azure-servicebus/tests/async_tests/mgmt_tests/test_mgmt_subscriptions_async.py @@ -156,11 +156,11 @@ async def test_async_mgmt_subscription_update_invalid(self, servicebus_namespace subscription_description = await mgmt_service.create_subscription(topic_name, subscription_name) # handle a null update properly. - with pytest.raises(AttributeError): + with pytest.raises(TypeError): await mgmt_service.update_subscription(topic_name, None) # handle an invalid type update properly. - with pytest.raises(AttributeError): + with pytest.raises(TypeError): await mgmt_service.update_subscription(topic_name, Exception("test")) # change the name to a topic that doesn't exist; should fail. diff --git a/sdk/servicebus/azure-servicebus/tests/async_tests/mgmt_tests/test_mgmt_topics_async.py b/sdk/servicebus/azure-servicebus/tests/async_tests/mgmt_tests/test_mgmt_topics_async.py index b67b5df5dd5b..3e5cd988ed39 100644 --- a/sdk/servicebus/azure-servicebus/tests/async_tests/mgmt_tests/test_mgmt_topics_async.py +++ b/sdk/servicebus/azure-servicebus/tests/async_tests/mgmt_tests/test_mgmt_topics_async.py @@ -136,11 +136,11 @@ async def test_async_mgmt_topic_update_invalid(self, servicebus_namespace_connec topic_description = await mgmt_service.create_topic(topic_name) # handle a null update properly. - with pytest.raises(AttributeError): + with pytest.raises(TypeError): await mgmt_service.update_topic(None) # handle an invalid type update properly. - with pytest.raises(AttributeError): + with pytest.raises(TypeError): await mgmt_service.update_topic(Exception("test")) # change the name to a topic that doesn't exist; should fail. diff --git a/sdk/servicebus/azure-servicebus/tests/async_tests/test_queues_async.py b/sdk/servicebus/azure-servicebus/tests/async_tests/test_queues_async.py index bab61acd00a7..55dd120cf5e1 100644 --- a/sdk/servicebus/azure-servicebus/tests/async_tests/test_queues_async.py +++ b/sdk/servicebus/azure-servicebus/tests/async_tests/test_queues_async.py @@ -31,6 +31,7 @@ ) from azure.servicebus._common.constants import ServiceBusReceiveMode, ServiceBusSubQueue from azure.servicebus._common.utils import utc_now +from azure.servicebus.management._models import DictMixin from azure.servicebus.exceptions import ( ServiceBusConnectionError, ServiceBusError, @@ -1785,6 +1786,53 @@ async def test_queue_async_send_dict_messages(self, servicebus_namespace_connect received_messages.append(message) assert len(received_messages) == 6 + @pytest.mark.liveTest + @pytest.mark.live_test_only + @CachedResourceGroupPreparer(name_prefix='servicebustest') + @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') + @ServiceBusQueuePreparer(name_prefix='servicebustest') + async def test_queue_async_send_mapping_messages(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + class MappingMessage(DictMixin): + def __init__(self, content): + self.body = content + self.message_id = 'foo' + + class BadMappingMessage(DictMixin): + def __init__(self): + self.message_id = 'foo' + + async with ServiceBusClient.from_connection_string( + servicebus_namespace_connection_string, logging_enable=False) as sb_client: + + async with sb_client.get_queue_sender(servicebus_queue.name) as sender: + + message_dict = MappingMessage("Message") + message2_dict = MappingMessage("Message2") + message3_dict = BadMappingMessage() + list_message_dicts = [message_dict, message2_dict] + + # send single dict + await sender.send_messages(message_dict) + + # send list of dicts + await sender.send_messages(list_message_dicts) + + # send bad dict + with pytest.raises(TypeError): + await sender.send_messages(message3_dict) + + # create and send BatchMessage with dicts + batch_message = await sender.create_message_batch() + batch_message._from_list(list_message_dicts) # pylint: disable=protected-access + batch_message.add_message(message_dict) + await sender.send_messages(batch_message) + + received_messages = [] + async with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=5) as receiver: + async for message in receiver: + received_messages.append(message) + assert len(received_messages) == 6 + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedResourceGroupPreparer(name_prefix='servicebustest') diff --git a/sdk/servicebus/azure-servicebus/tests/mgmt_tests/test_mgmt_queues.py b/sdk/servicebus/azure-servicebus/tests/mgmt_tests/test_mgmt_queues.py index a7c96031a227..45c49852bf9d 100644 --- a/sdk/servicebus/azure-servicebus/tests/mgmt_tests/test_mgmt_queues.py +++ b/sdk/servicebus/azure-servicebus/tests/mgmt_tests/test_mgmt_queues.py @@ -349,11 +349,11 @@ def test_mgmt_queue_update_invalid(self, servicebus_namespace_connection_string, queue_description = mgmt_service.create_queue(queue_name) try: # handle a null update properly. - with pytest.raises(AttributeError): + with pytest.raises(TypeError): mgmt_service.update_queue(None) # handle an invalid type update properly. - with pytest.raises(AttributeError): + with pytest.raises(TypeError): mgmt_service.update_queue(Exception("test")) # change a setting we can't change; should fail. diff --git a/sdk/servicebus/azure-servicebus/tests/mgmt_tests/test_mgmt_rules.py b/sdk/servicebus/azure-servicebus/tests/mgmt_tests/test_mgmt_rules.py index 0e4d00fdcb35..33bfa6c88965 100644 --- a/sdk/servicebus/azure-servicebus/tests/mgmt_tests/test_mgmt_rules.py +++ b/sdk/servicebus/azure-servicebus/tests/mgmt_tests/test_mgmt_rules.py @@ -188,11 +188,11 @@ def test_mgmt_rule_update_invalid(self, servicebus_namespace_connection_string, rule_desc = mgmt_service.get_rule(topic_name, subscription_name, rule_name) # handle a null update properly. - with pytest.raises(AttributeError): + with pytest.raises(TypeError): mgmt_service.update_rule(topic_name, subscription_name, None) # handle an invalid type update properly. - with pytest.raises(AttributeError): + with pytest.raises(TypeError): mgmt_service.update_rule(topic_name, subscription_name, Exception("test")) # change the name to a topic that doesn't exist; should fail. diff --git a/sdk/servicebus/azure-servicebus/tests/mgmt_tests/test_mgmt_subscriptions.py b/sdk/servicebus/azure-servicebus/tests/mgmt_tests/test_mgmt_subscriptions.py index a5833a7a5eee..5afff87b481c 100644 --- a/sdk/servicebus/azure-servicebus/tests/mgmt_tests/test_mgmt_subscriptions.py +++ b/sdk/servicebus/azure-servicebus/tests/mgmt_tests/test_mgmt_subscriptions.py @@ -155,11 +155,11 @@ def test_mgmt_subscription_update_invalid(self, servicebus_namespace_connection_ subscription_description = mgmt_service.create_subscription(topic_name, subscription_name) # handle a null update properly. - with pytest.raises(AttributeError): + with pytest.raises(TypeError): mgmt_service.update_subscription(topic_name, None) # handle an invalid type update properly. - with pytest.raises(AttributeError): + with pytest.raises(TypeError): mgmt_service.update_subscription(topic_name, Exception("test")) # change the name to a topic that doesn't exist; should fail. diff --git a/sdk/servicebus/azure-servicebus/tests/mgmt_tests/test_mgmt_topics.py b/sdk/servicebus/azure-servicebus/tests/mgmt_tests/test_mgmt_topics.py index 292c31a6ff0a..ee001a94bd9d 100644 --- a/sdk/servicebus/azure-servicebus/tests/mgmt_tests/test_mgmt_topics.py +++ b/sdk/servicebus/azure-servicebus/tests/mgmt_tests/test_mgmt_topics.py @@ -136,11 +136,11 @@ def test_mgmt_topic_update_invalid(self, servicebus_namespace_connection_string, topic_description = mgmt_service.create_topic(topic_name) # handle a null update properly. - with pytest.raises(AttributeError): + with pytest.raises(TypeError): mgmt_service.update_topic(None) # handle an invalid type update properly. - with pytest.raises(AttributeError): + with pytest.raises(TypeError): mgmt_service.update_topic(Exception("test")) # change the name to a topic that doesn't exist; should fail. diff --git a/sdk/servicebus/azure-servicebus/tests/test_queues.py b/sdk/servicebus/azure-servicebus/tests/test_queues.py index 820a63de50ba..c26ad18f27c3 100644 --- a/sdk/servicebus/azure-servicebus/tests/test_queues.py +++ b/sdk/servicebus/azure-servicebus/tests/test_queues.py @@ -34,6 +34,7 @@ _X_OPT_SCHEDULED_ENQUEUE_TIME ) from azure.servicebus._common.utils import utc_now +from azure.servicebus.management._models import DictMixin from azure.servicebus.exceptions import ( ServiceBusConnectionError, ServiceBusError, @@ -2216,6 +2217,53 @@ def test_queue_send_dict_messages(self, servicebus_namespace_connection_string, received_messages.append(message) assert len(received_messages) == 6 + @pytest.mark.liveTest + @pytest.mark.live_test_only + @CachedResourceGroupPreparer(name_prefix='servicebustest') + @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') + @ServiceBusQueuePreparer(name_prefix='servicebustest') + def test_queue_send_mapping_messages(self, servicebus_namespace_connection_string, servicebus_queue, **kwargs): + class MappingMessage(DictMixin): + def __init__(self, content): + self.body = content + self.message_id = 'foo' + + class BadMappingMessage(DictMixin): + def __init__(self): + self.message_id = 'foo' + + with ServiceBusClient.from_connection_string( + servicebus_namespace_connection_string, logging_enable=False) as sb_client: + + with sb_client.get_queue_sender(servicebus_queue.name) as sender: + + message_dict = MappingMessage("Message") + message2_dict = MappingMessage("Message2") + message3_dict = BadMappingMessage() + list_message_dicts = [message_dict, message2_dict] + + # send single dict + sender.send_messages(message_dict) + + # send list of dicts + sender.send_messages(list_message_dicts) + + # send bad dict + with pytest.raises(TypeError): + sender.send_messages(message3_dict) + + # create and send BatchMessage with dicts + batch_message = sender.create_message_batch() + batch_message._from_list(list_message_dicts) # pylint: disable=protected-access + batch_message.add_message(message_dict) + sender.send_messages(batch_message) + + received_messages = [] + with sb_client.get_queue_receiver(servicebus_queue.name, max_wait_time=5) as receiver: + for message in receiver: + received_messages.append(message) + assert len(received_messages) == 6 + @pytest.mark.liveTest @pytest.mark.live_test_only @CachedResourceGroupPreparer(name_prefix='servicebustest')