Skip to content

Commit

Permalink
[EventHubs] Fix bug in reusing EventHubProducerClient (#21927)
Browse files Browse the repository at this point in the history
* fix bug

* run black

* fix mypy

* fix pylint

* update changelog to be more clear

* Update sdk/eventhub/azure-eventhub/CHANGELOG.md

Co-authored-by: swathipil <[email protected]>

Co-authored-by: swathipil <[email protected]>
  • Loading branch information
yunhaoling and swathipil authored Jan 4, 2022
1 parent 08819ce commit 811cf03
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 26 deletions.
2 changes: 2 additions & 0 deletions sdk/eventhub/azure-eventhub/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

### Bugs Fixed

- Fixed a bug that `EventHubProducerClient` could be reopened for sending events instead of encountering with `KeyError` when the client is previously closed (issue #21849).

### Other Changes

## 5.6.1 (2021-10-06)
Expand Down
36 changes: 25 additions & 11 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
from ._common import EventDataBatch, EventData

if TYPE_CHECKING:
from azure.core.credentials import TokenCredential, AzureSasCredential, AzureNamedKeyCredential
from azure.core.credentials import (
TokenCredential,
AzureSasCredential,
AzureNamedKeyCredential,
)

SendEventTypes = List[Union[EventData, AmqpAnnotatedMessage]]

Expand Down Expand Up @@ -143,7 +147,10 @@ def _start_producer(self, partition_id, send_timeout):
or cast(EventHubProducer, self._producers[partition_id]).closed
):
self._producers[partition_id] = self._create_producer(
partition_id=partition_id, send_timeout=send_timeout
partition_id=(
None if partition_id == ALL_PARTITIONS else partition_id
),
send_timeout=send_timeout,
)

def _create_producer(self, partition_id=None, send_timeout=None):
Expand Down Expand Up @@ -261,14 +268,21 @@ def send_batch(self, event_data_batch, **kwargs):

if isinstance(event_data_batch, EventDataBatch):
if partition_id or partition_key:
raise TypeError("partition_id and partition_key should be None when sending an EventDataBatch "
"because type EventDataBatch itself may have partition_id or partition_key")
raise TypeError(
"partition_id and partition_key should be None when sending an EventDataBatch "
"because type EventDataBatch itself may have partition_id or partition_key"
)
to_send_batch = event_data_batch
else:
to_send_batch = self.create_batch(partition_id=partition_id, partition_key=partition_key)
to_send_batch._load_events(event_data_batch) # pylint:disable=protected-access
to_send_batch = self.create_batch(
partition_id=partition_id, partition_key=partition_key
)
to_send_batch._load_events( # pylint:disable=protected-access
event_data_batch
)
partition_id = (
to_send_batch._partition_id or ALL_PARTITIONS # pylint:disable=protected-access
to_send_batch._partition_id # pylint:disable=protected-access
or ALL_PARTITIONS
)

if len(to_send_batch) == 0:
Expand Down Expand Up @@ -400,8 +414,8 @@ def close(self):
"""
with self._lock:
for producer in self._producers.values():
if producer:
producer.close()
self._producers = {}
for pid in self._producers:
if self._producers[pid]:
self._producers[pid].close() # type: ignore
self._producers[pid] = None
super(EventHubProducerClient, self)._close()
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

if TYPE_CHECKING:
from azure.core.credentials_async import AsyncTokenCredential
from uamqp.constants import TransportType # pylint: disable=ungrouped-imports
from uamqp.constants import TransportType # pylint: disable=ungrouped-imports

SendEventTypes = List[Union[EventData, AmqpAnnotatedMessage]]

Expand Down Expand Up @@ -80,7 +80,9 @@ def __init__(
self,
fully_qualified_namespace: str,
eventhub_name: str,
credential: Union["AsyncTokenCredential", AzureSasCredential, AzureNamedKeyCredential],
credential: Union[
"AsyncTokenCredential", AzureSasCredential, AzureNamedKeyCredential
],
**kwargs
) -> None:
super(EventHubProducerClient, self).__init__(
Expand Down Expand Up @@ -145,7 +147,10 @@ async def _start_producer(
or cast(EventHubProducer, self._producers[partition_id]).closed
):
self._producers[partition_id] = self._create_producer(
partition_id=partition_id, send_timeout=send_timeout
partition_id=(
None if partition_id == ALL_PARTITIONS else partition_id
),
send_timeout=send_timeout,
)

def _create_producer(
Expand Down Expand Up @@ -294,18 +299,25 @@ async def send_batch(

if isinstance(event_data_batch, EventDataBatch):
if partition_id or partition_key:
raise TypeError("partition_id and partition_key should be None when sending an EventDataBatch "
"because type EventDataBatch itself may have partition_id or partition_key")
raise TypeError(
"partition_id and partition_key should be None when sending an EventDataBatch "
"because type EventDataBatch itself may have partition_id or partition_key"
)
to_send_batch = event_data_batch
else:
to_send_batch = await self.create_batch(partition_id=partition_id, partition_key=partition_key)
to_send_batch._load_events(event_data_batch) # pylint:disable=protected-access
to_send_batch = await self.create_batch(
partition_id=partition_id, partition_key=partition_key
)
to_send_batch._load_events( # pylint:disable=protected-access
event_data_batch
)

if len(to_send_batch) == 0:
return

partition_id = (
to_send_batch._partition_id or ALL_PARTITIONS # pylint:disable=protected-access
to_send_batch._partition_id # pylint:disable=protected-access
or ALL_PARTITIONS
)
try:
await cast(EventHubProducer, self._producers[partition_id]).send(
Expand Down Expand Up @@ -431,7 +443,9 @@ async def close(self) -> None:
"""
async with self._lock:
for producer in self._producers.values():
if producer:
await producer.close()
for pid in self._producers:
if self._producers[pid] is not None:
await self._producers[pid].close() # type: ignore
self._producers[pid] = None

await super(EventHubProducerClient, self)._close_async()
Original file line number Diff line number Diff line change
Expand Up @@ -189,15 +189,34 @@ async def test_send_and_receive_small_body_async(connstr_receivers, payload):
async def test_send_partition_async(connstr_receivers):
connection_str, receivers = connstr_receivers
client = EventHubProducerClient.from_connection_string(connection_str)

async with client:
batch = await client.create_batch()
batch.add(EventData(b"Data"))
await client.send_batch(batch)

async with client:
batch = await client.create_batch(partition_id="1")
batch.add(EventData(b"Data"))
await client.send_batch(batch)

partition_0 = receivers[0].receive_message_batch(timeout=5000)
assert len(partition_0) == 0
partition_1 = receivers[1].receive_message_batch(timeout=5000)
assert len(partition_1) == 1
assert len(partition_0) + len(partition_1) == 2

async with client:
batch = await client.create_batch()
batch.add(EventData(b"Data"))
await client.send_batch(batch)

async with client:
batch = await client.create_batch(partition_id="1")
batch.add(EventData(b"Data"))
await client.send_batch(batch)

partition_0 = receivers[0].receive_message_batch(timeout=5000)
partition_1 = receivers[1].receive_message_batch(timeout=5000)
assert len(partition_0) + len(partition_1) == 2


@pytest.mark.liveTest
Expand Down
23 changes: 21 additions & 2 deletions sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,15 +206,34 @@ def test_send_and_receive_small_body(connstr_receivers, payload):
def test_send_partition(connstr_receivers):
connection_str, receivers = connstr_receivers
client = EventHubProducerClient.from_connection_string(connection_str)

with client:
batch = client.create_batch()
batch.add(EventData(b"Data"))
client.send_batch(batch)

with client:
batch = client.create_batch(partition_id="1")
batch.add(EventData(b"Data"))
client.send_batch(batch)

partition_0 = receivers[0].receive_message_batch(timeout=5000)
assert len(partition_0) == 0
partition_1 = receivers[1].receive_message_batch(timeout=5000)
assert len(partition_1) == 1
assert len(partition_0) + len(partition_1) == 2

with client:
batch = client.create_batch()
batch.add(EventData(b"Data"))
client.send_batch(batch)

with client:
batch = client.create_batch(partition_id="1")
batch.add(EventData(b"Data"))
client.send_batch(batch)

partition_0 = receivers[0].receive_message_batch(timeout=5000)
partition_1 = receivers[1].receive_message_batch(timeout=5000)
assert len(partition_0) + len(partition_1) == 2


@pytest.mark.liveTest
Expand Down

0 comments on commit 811cf03

Please sign in to comment.