diff --git a/sdk/eventhub/azure-eventhub/CHANGELOG.md b/sdk/eventhub/azure-eventhub/CHANGELOG.md index a7d4e0fb075e..57074132f668 100644 --- a/sdk/eventhub/azure-eventhub/CHANGELOG.md +++ b/sdk/eventhub/azure-eventhub/CHANGELOG.md @@ -8,6 +8,8 @@ ### Bugs Fixed +- Fixed a bug that `EventHubProducerClient` could be reopened for sending events instead of encountering with `KeyError` when the client is previously closed (issue #21849). + ### Other Changes ## 5.6.1 (2021-10-06) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py index 39f44773202d..89e9b0cfcfd9 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py @@ -17,7 +17,11 @@ from ._common import EventDataBatch, EventData if TYPE_CHECKING: - from azure.core.credentials import TokenCredential, AzureSasCredential, AzureNamedKeyCredential + from azure.core.credentials import ( + TokenCredential, + AzureSasCredential, + AzureNamedKeyCredential, + ) SendEventTypes = List[Union[EventData, AmqpAnnotatedMessage]] @@ -143,7 +147,10 @@ def _start_producer(self, partition_id, send_timeout): or cast(EventHubProducer, self._producers[partition_id]).closed ): self._producers[partition_id] = self._create_producer( - partition_id=partition_id, send_timeout=send_timeout + partition_id=( + None if partition_id == ALL_PARTITIONS else partition_id + ), + send_timeout=send_timeout, ) def _create_producer(self, partition_id=None, send_timeout=None): @@ -261,14 +268,21 @@ def send_batch(self, event_data_batch, **kwargs): if isinstance(event_data_batch, EventDataBatch): if partition_id or partition_key: - raise TypeError("partition_id and partition_key should be None when sending an EventDataBatch " - "because type EventDataBatch itself may have partition_id or partition_key") + raise TypeError( + "partition_id and partition_key should be None when sending an EventDataBatch " + "because type EventDataBatch itself may have partition_id or partition_key" + ) to_send_batch = event_data_batch else: - to_send_batch = self.create_batch(partition_id=partition_id, partition_key=partition_key) - to_send_batch._load_events(event_data_batch) # pylint:disable=protected-access + to_send_batch = self.create_batch( + partition_id=partition_id, partition_key=partition_key + ) + to_send_batch._load_events( # pylint:disable=protected-access + event_data_batch + ) partition_id = ( - to_send_batch._partition_id or ALL_PARTITIONS # pylint:disable=protected-access + to_send_batch._partition_id # pylint:disable=protected-access + or ALL_PARTITIONS ) if len(to_send_batch) == 0: @@ -400,8 +414,8 @@ def close(self): """ with self._lock: - for producer in self._producers.values(): - if producer: - producer.close() - self._producers = {} + for pid in self._producers: + if self._producers[pid]: + self._producers[pid].close() # type: ignore + self._producers[pid] = None super(EventHubProducerClient, self)._close() diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py index 728902461d60..79077f4e3d20 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py @@ -19,7 +19,7 @@ if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential - from uamqp.constants import TransportType # pylint: disable=ungrouped-imports + from uamqp.constants import TransportType # pylint: disable=ungrouped-imports SendEventTypes = List[Union[EventData, AmqpAnnotatedMessage]] @@ -80,7 +80,9 @@ def __init__( self, fully_qualified_namespace: str, eventhub_name: str, - credential: Union["AsyncTokenCredential", AzureSasCredential, AzureNamedKeyCredential], + credential: Union[ + "AsyncTokenCredential", AzureSasCredential, AzureNamedKeyCredential + ], **kwargs ) -> None: super(EventHubProducerClient, self).__init__( @@ -145,7 +147,10 @@ async def _start_producer( or cast(EventHubProducer, self._producers[partition_id]).closed ): self._producers[partition_id] = self._create_producer( - partition_id=partition_id, send_timeout=send_timeout + partition_id=( + None if partition_id == ALL_PARTITIONS else partition_id + ), + send_timeout=send_timeout, ) def _create_producer( @@ -294,18 +299,25 @@ async def send_batch( if isinstance(event_data_batch, EventDataBatch): if partition_id or partition_key: - raise TypeError("partition_id and partition_key should be None when sending an EventDataBatch " - "because type EventDataBatch itself may have partition_id or partition_key") + raise TypeError( + "partition_id and partition_key should be None when sending an EventDataBatch " + "because type EventDataBatch itself may have partition_id or partition_key" + ) to_send_batch = event_data_batch else: - to_send_batch = await self.create_batch(partition_id=partition_id, partition_key=partition_key) - to_send_batch._load_events(event_data_batch) # pylint:disable=protected-access + to_send_batch = await self.create_batch( + partition_id=partition_id, partition_key=partition_key + ) + to_send_batch._load_events( # pylint:disable=protected-access + event_data_batch + ) if len(to_send_batch) == 0: return partition_id = ( - to_send_batch._partition_id or ALL_PARTITIONS # pylint:disable=protected-access + to_send_batch._partition_id # pylint:disable=protected-access + or ALL_PARTITIONS ) try: await cast(EventHubProducer, self._producers[partition_id]).send( @@ -431,7 +443,9 @@ async def close(self) -> None: """ async with self._lock: - for producer in self._producers.values(): - if producer: - await producer.close() + for pid in self._producers: + if self._producers[pid] is not None: + await self._producers[pid].close() # type: ignore + self._producers[pid] = None + await super(EventHubProducerClient, self)._close_async() diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py index 4eac51006c9f..e2c281d3a32a 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py @@ -189,15 +189,34 @@ async def test_send_and_receive_small_body_async(connstr_receivers, payload): async def test_send_partition_async(connstr_receivers): connection_str, receivers = connstr_receivers client = EventHubProducerClient.from_connection_string(connection_str) + + async with client: + batch = await client.create_batch() + batch.add(EventData(b"Data")) + await client.send_batch(batch) + async with client: batch = await client.create_batch(partition_id="1") batch.add(EventData(b"Data")) await client.send_batch(batch) partition_0 = receivers[0].receive_message_batch(timeout=5000) - assert len(partition_0) == 0 partition_1 = receivers[1].receive_message_batch(timeout=5000) - assert len(partition_1) == 1 + assert len(partition_0) + len(partition_1) == 2 + + async with client: + batch = await client.create_batch() + batch.add(EventData(b"Data")) + await client.send_batch(batch) + + async with client: + batch = await client.create_batch(partition_id="1") + batch.add(EventData(b"Data")) + await client.send_batch(batch) + + partition_0 = receivers[0].receive_message_batch(timeout=5000) + partition_1 = receivers[1].receive_message_batch(timeout=5000) + assert len(partition_0) + len(partition_1) == 2 @pytest.mark.liveTest diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py index 6bc42d0ea3eb..0872b55ae85f 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py @@ -206,15 +206,34 @@ def test_send_and_receive_small_body(connstr_receivers, payload): def test_send_partition(connstr_receivers): connection_str, receivers = connstr_receivers client = EventHubProducerClient.from_connection_string(connection_str) + + with client: + batch = client.create_batch() + batch.add(EventData(b"Data")) + client.send_batch(batch) + with client: batch = client.create_batch(partition_id="1") batch.add(EventData(b"Data")) client.send_batch(batch) partition_0 = receivers[0].receive_message_batch(timeout=5000) - assert len(partition_0) == 0 partition_1 = receivers[1].receive_message_batch(timeout=5000) - assert len(partition_1) == 1 + assert len(partition_0) + len(partition_1) == 2 + + with client: + batch = client.create_batch() + batch.add(EventData(b"Data")) + client.send_batch(batch) + + with client: + batch = client.create_batch(partition_id="1") + batch.add(EventData(b"Data")) + client.send_batch(batch) + + partition_0 = receivers[0].receive_message_batch(timeout=5000) + partition_1 = receivers[1].receive_message_batch(timeout=5000) + assert len(partition_0) + len(partition_1) == 2 @pytest.mark.liveTest