diff --git a/sdk/eventhub/azure-eventhub/CHANGELOG.md b/sdk/eventhub/azure-eventhub/CHANGELOG.md
index 1a15dab8841dd..abeb952b47f69 100644
--- a/sdk/eventhub/azure-eventhub/CHANGELOG.md
+++ b/sdk/eventhub/azure-eventhub/CHANGELOG.md
@@ -6,6 +6,8 @@
 
 ### Breaking Changes
 
 ### Bugs Fixed
 
+- Fixed a bug in the `BufferedProducer` that would block when flushing the queue, causing the client to freeze (issue #23510).
+
 ### Other Changes
diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py
index ff18a87921bdb..d61ba8089e0a3 100644
--- a/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py
+++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_buffered_producer/_buffered_producer.py
@@ -96,8 +96,7 @@ def put_events(self, events, timeout_time=None):
                 new_events_len,
             )
             # flush the buffer
-            with self._lock:
-                self.flush(timeout_time=timeout_time)
+            self.flush(timeout_time=timeout_time)
             if timeout_time and time.time() > timeout_time:
                 raise OperationTimeoutError(
                     "Failed to enqueue events into buffer due to timeout."
@@ -107,14 +106,16 @@ def put_events(self, events, timeout_time=None):
             self._cur_batch.add(events)
         except AttributeError:  # if the input events is a EventDataBatch, put the whole into the buffer
             # if there are events in cur_batch, enqueue cur_batch to the buffer
-            if self._cur_batch:
-                self._buffered_queue.put(self._cur_batch)
-            self._buffered_queue.put(events)
+            with self._lock:
+                if self._cur_batch:
+                    self._buffered_queue.put(self._cur_batch)
+                self._buffered_queue.put(events)
             # create a new batch for incoming events
             self._cur_batch = EventDataBatch(self._max_message_size_on_link)
         except ValueError:
             # add single event exceeds the cur batch size, create new batch
-            self._buffered_queue.put(self._cur_batch)
+            with self._lock:
+                self._buffered_queue.put(self._cur_batch)
             self._cur_batch = EventDataBatch(self._max_message_size_on_link)
             self._cur_batch.add(events)
         self._cur_buffered_len += new_events_len
@@ -140,10 +141,13 @@ def flush(self, timeout_time=None, raise_error=True):
         _LOGGER.info("Partition: %r started flushing.", self.partition_id)
         if self._cur_batch:  # if there is batch, enqueue it to the buffer first
             self._buffered_queue.put(self._cur_batch)
-        while self._cur_buffered_len:
+        while self._buffered_queue.qsize() > 0:
             remaining_time = timeout_time - time.time() if timeout_time else None
             if (remaining_time and remaining_time > 0) or remaining_time is None:
-                batch = self._buffered_queue.get()
+                try:
+                    batch = self._buffered_queue.get(block=False)
+                except queue.Empty:
+                    break
                 self._buffered_queue.task_done()
                 try:
                     _LOGGER.info("Partition %r is sending.", self.partition_id)
@@ -182,6 +186,8 @@ def flush(self, timeout_time=None, raise_error=True):
                 break
         # after finishing flushing, reset cur batch and put it into the buffer
         self._last_send_time = time.time()
+        # reset the buffered count
+        self._cur_buffered_len = 0
         self._cur_batch = EventDataBatch(self._max_message_size_on_link)
         _LOGGER.info("Partition %r finished flushing.", self.partition_id)
 
diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py
index 2d98878d51464..17e5be1fde828 100644
--- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py
+++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_buffered_producer/_buffered_producer_async.py
@@ -89,7 +89,6 @@ async def put_events(self, events, timeout_time=None):
             new_events_len = len(events)
         except TypeError:
             new_events_len = 1
-
         if self._max_buffer_len - self._cur_buffered_len < new_events_len:
             _LOGGER.info(
                 "The buffer for partition %r is full. Attempting to flush before adding %r events.",
@@ -97,9 +96,7 @@ async def put_events(self, events, timeout_time=None):
                 new_events_len,
             )
             # flush the buffer
-            async with self._lock:
-                await self._flush(timeout_time=timeout_time)
-
+            await self.flush(timeout_time=timeout_time)
             if timeout_time and time.time() > timeout_time:
                 raise OperationTimeoutError(
                     "Failed to enqueue events into buffer due to timeout."
@@ -109,14 +106,16 @@ async def put_events(self, events, timeout_time=None):
             self._cur_batch.add(events)
         except AttributeError:  # if the input events is a EventDataBatch, put the whole into the buffer
             # if there are events in cur_batch, enqueue cur_batch to the buffer
-            if self._cur_batch:
-                self._buffered_queue.put(self._cur_batch)
-            self._buffered_queue.put(events)
+            async with self._lock:
+                if self._cur_batch:
+                    self._buffered_queue.put(self._cur_batch)
+                self._buffered_queue.put(events)
             # create a new batch for incoming events
             self._cur_batch = EventDataBatch(self._max_message_size_on_link)
         except ValueError:
             # add single event exceeds the cur batch size, create new batch
-            self._buffered_queue.put(self._cur_batch)
+            async with self._lock:
+                self._buffered_queue.put(self._cur_batch)
             self._cur_batch = EventDataBatch(self._max_message_size_on_link)
             self._cur_batch.add(events)
         self._cur_buffered_len += new_events_len
@@ -146,10 +145,13 @@ async def _flush(self, timeout_time=None, raise_error=True):
         if self._cur_batch:  # if there is batch, enqueue it to the buffer first
             self._buffered_queue.put(self._cur_batch)
             self._cur_batch = EventDataBatch(self._max_message_size_on_link)
-        while self._cur_buffered_len:
+        while self._buffered_queue.qsize() > 0:
             remaining_time = timeout_time - time.time() if timeout_time else None
             if (remaining_time and remaining_time > 0) or remaining_time is None:
-                batch = self._buffered_queue.get()
+                try:
+                    batch = self._buffered_queue.get(block=False)
+                except queue.Empty:
+                    break
                 self._buffered_queue.task_done()
                 try:
                     _LOGGER.info("Partition %r is sending.", self.partition_id)
@@ -187,6 +189,8 @@ async def _flush(self, timeout_time=None, raise_error=True):
                 break
         # after finishing flushing, reset cur batch and put it into the buffer
         self._last_send_time = time.time()
+        # reset the buffered count
+        self._cur_buffered_len = 0
         self._cur_batch = EventDataBatch(self._max_message_size_on_link)
         _LOGGER.info("Partition %r finished flushing.", self.partition_id)
 
diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_buffered_producer_async.py b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_buffered_producer_async.py
index 3937a686d0fb3..411a946f328d0 100644
--- a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_buffered_producer_async.py
+++ b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_buffered_producer_async.py
@@ -490,3 +490,50 @@ async def on_error(events, pid, err):
 
     await consumer.close()
     await receive_thread
+
+@pytest.mark.liveTest
+@pytest.mark.asyncio
+async def test_long_wait_small_buffer(connection_str):
+    received_events = defaultdict(list)
+
+    async def on_event(partition_context, event):
+        received_events[partition_context.partition_id].append(event)
+
+    consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default")
+
+    receive_thread = asyncio.ensure_future(consumer.receive(on_event=on_event))
+
+    sent_events = defaultdict(list)
+
+    async def on_success(events, pid):
+        sent_events[pid].extend(events)
+
+    async def on_error(events, pid, err):
+        on_error.err = err
+
+    on_error.err = None  # ensure no error
+    producer = EventHubProducerClient.from_connection_string(
+        connection_str,
+        buffered_mode=True,
+        on_success=on_success,
+        on_error=on_error,
+        auth_timeout=3,
+        retry_total=3,
+        retry_mode='fixed',
+        retry_backoff_factor=0.01,
+        max_wait_time=10,
+        max_buffer_length=100
+    )
+
+    async with producer:
+        for _ in range(100):
+            await producer.send_event(EventData("test"))
+
+        await asyncio.sleep(60)
+
+    assert not on_error.err
+    assert sum([len(sent_events[key]) for key in sent_events]) == 100
+    assert sum([len(received_events[key]) for key in received_events]) == 100
+
+    await consumer.close()
+    await receive_thread
diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_buffered_producer.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_buffered_producer.py
index 1360e699b8bfe..70dea52116dba 100644
--- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_buffered_producer.py
+++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_buffered_producer.py
@@ -496,3 +496,51 @@ def on_error(events, pid, err):
 
     consumer.close()
     receive_thread.join()
+
+@pytest.mark.liveTest
+def test_long_wait_small_buffer(connection_str):
+    received_events = defaultdict(list)
+
+    def on_event(partition_context, event):
+        received_events[partition_context.partition_id].append(event)
+
+    consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default")
+    receive_thread = Thread(target=consumer.receive, args=(on_event,))
+    receive_thread.daemon = True
+    receive_thread.start()
+
+    sent_events = defaultdict(list)
+
+    def on_success(events, pid):
+        sent_events[pid].extend(events)
+
+    def on_error(events, pid, err):
+        on_error.err = err
+
+    on_error.err = None  # ensure no error
+    producer = EventHubProducerClient.from_connection_string(
+        connection_str,
+        buffered_mode=True,
+        on_success=on_success,
+        on_error=on_error,
+        auth_timeout=3,
+        retry_total=3,
+        retry_mode='fixed',
+        retry_backoff_factor=0.01,
+        max_wait_time=10,
+        max_buffer_length=100
+    )
+
+    with producer:
+        for _ in range(100):
+            producer.send_event(EventData("test"))
+            time.sleep(0.1)
+
+        time.sleep(60)
+
+    assert not on_error.err
+    assert sum([len(sent_events[key]) for key in sent_events]) == 100
+    assert sum([len(received_events[key]) for key in received_events]) == 100
+
+    consumer.close()
+    receive_thread.join()
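
---

A note for reviewers on the mechanics of the fix. Reading the diff, the freeze in issue #23510 appears to come from `flush()` looping on the cached `self._cur_buffered_len`, which was never reset after a flush, while draining the queue with a blocking `queue.get()`: once the queue emptied with the counter still non-zero, `get()` blocked forever, and because `put_events()` held `self._lock` around that call, every other producer operation froze behind it. The patch addresses each of these: `put_events()` no longer wraps `flush()` in the lock, the loop keys off `qsize()`, the `get()` is non-blocking, and the counter is reset after draining. The snippet below is a minimal, self-contained sketch of that pattern, not the SDK's actual code; `SketchBufferedProducer` and its members are hypothetical stand-ins whose names merely mirror the patch.

```python
# Hypothetical stand-in for the real BufferedProducer, illustrating the
# post-fix flush pattern; this is a sketch, not the SDK implementation.
import queue
import threading


class SketchBufferedProducer:
    def __init__(self, max_buffer_len=100):
        self._lock = threading.Lock()
        self._buffered_queue = queue.Queue()
        self._max_buffer_len = max_buffer_len
        self._cur_buffered_len = 0

    def put_events(self, event):
        if self._cur_buffered_len >= self._max_buffer_len:
            # Post-fix behavior: flush WITHOUT holding self._lock, so a
            # stall inside flush() can no longer freeze every other caller.
            self.flush()
        with self._lock:
            self._buffered_queue.put(event)
            self._cur_buffered_len += 1

    def flush(self):
        # Loop on the queue's real size rather than the cached counter, and
        # use a non-blocking get(): an empty queue raises queue.Empty and we
        # break out instead of blocking forever on a stale count.
        while self._buffered_queue.qsize() > 0:
            try:
                batch = self._buffered_queue.get(block=False)
            except queue.Empty:
                break
            self._buffered_queue.task_done()
            print("sending", batch)  # the real client sends to Event Hubs here
        self._cur_buffered_len = 0  # reset the cached count after draining


if __name__ == "__main__":
    producer = SketchBufferedProducer(max_buffer_len=2)
    for n in range(5):
        producer.put_events(f"event-{n}")
    producer.flush()
```

The new `test_long_wait_small_buffer` live tests exercise the same scenario end to end: the small `max_buffer_length` forces repeated flush-on-full from `put_events`, and the long sleep lets the `max_wait_time` background flush run against an already-drained queue, which previously hung.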