[Eventhub] Fix Blocking Behavior of Buffered Producer Flush #25406

Merged · 17 commits · Aug 12, 2022
2 changes: 2 additions & 0 deletions sdk/eventhub/azure-eventhub/CHANGELOG.md
@@ -6,6 +6,8 @@

### Breaking Changes

### Bugs Fixed

+- Fixed a bug in `BufferedProducer` that would block when flushing the queue, causing the client to freeze up (issue #23510).
+
### Other Changes
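For context on the diffs that follow: the old `flush` loop keyed off the `_cur_buffered_len` counter while issuing blocking `queue.get()` calls, so a counter that drifted out of sync with the queue left the loop waiting forever on an empty queue. The fix switches the loop to `qsize()` plus a non-blocking `get`, and resets the counter once flushing completes. A minimal standalone sketch of that drain pattern (illustrative names, not the SDK code):

```python
import queue

def drain(q: queue.Queue) -> list:
    """Drain a queue without ever blocking the caller."""
    items = []
    while q.qsize() > 0:
        try:
            # block=False raises queue.Empty instead of waiting forever,
            # so a racing consumer can no longer freeze this loop
            items.append(q.get(block=False))
        except queue.Empty:
            break
        q.task_done()
    return items

q = queue.Queue()
for i in range(3):
    q.put(i)
assert drain(q) == [0, 1, 2]
```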
@@ -96,8 +96,7 @@ def put_events(self, events, timeout_time=None):
                new_events_len,
            )
            # flush the buffer
-            with self._lock:
-                self.flush(timeout_time=timeout_time)
+            self.flush(timeout_time=timeout_time)
        if timeout_time and time.time() > timeout_time:
            raise OperationTimeoutError(
                "Failed to enqueue events into buffer due to timeout."
@@ -107,14 +106,16 @@
            self._cur_batch.add(events)
        except AttributeError:  # if the input events is an EventDataBatch, put the whole batch into the buffer
            # if there are events in cur_batch, enqueue cur_batch to the buffer
-            if self._cur_batch:
-                self._buffered_queue.put(self._cur_batch)
-            self._buffered_queue.put(events)
+            with self._lock:
+                if self._cur_batch:
+                    self._buffered_queue.put(self._cur_batch)
+                self._buffered_queue.put(events)
            # create a new batch for incoming events
            self._cur_batch = EventDataBatch(self._max_message_size_on_link)
        except ValueError:
            # adding a single event exceeds the current batch size, create a new batch
-            self._buffered_queue.put(self._cur_batch)
+            with self._lock:
+                self._buffered_queue.put(self._cur_batch)
            self._cur_batch = EventDataBatch(self._max_message_size_on_link)
            self._cur_batch.add(events)
        self._cur_buffered_len += new_events_len
@@ -140,10 +141,13 @@ def flush(self, timeout_time=None, raise_error=True):
        _LOGGER.info("Partition: %r started flushing.", self.partition_id)
        if self._cur_batch:  # if there is a batch, enqueue it to the buffer first
            self._buffered_queue.put(self._cur_batch)
-        while self._cur_buffered_len:
+        while self._buffered_queue.qsize() > 0:
            remaining_time = timeout_time - time.time() if timeout_time else None
            if (remaining_time and remaining_time > 0) or remaining_time is None:
-                batch = self._buffered_queue.get()
+                try:
+                    batch = self._buffered_queue.get(block=False)
+                except queue.Empty:
+                    break
                self._buffered_queue.task_done()
                try:
                    _LOGGER.info("Partition %r is sending.", self.partition_id)
@@ -182,6 +186,8 @@ def flush(self, timeout_time=None, raise_error=True):
                    break
        # after finishing flushing, reset cur batch and put it into the buffer
        self._last_send_time = time.time()
+        # reset the buffered count
+        self._cur_buffered_len = 0
        self._cur_batch = EventDataBatch(self._max_message_size_on_link)
        _LOGGER.info("Partition %r finished flushing.", self.partition_id)

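A note on the lock change in `put_events` above: the public `flush()` presumably takes `self._lock` itself (its body is collapsed in this view), so the old `with self._lock: self.flush(...)` re-acquired a non-reentrant `threading.Lock` on the same thread and hung. A toy reproduction of that hazard, with hypothetical names:

```python
import threading

lock = threading.Lock()  # non-reentrant, standing in for the producer's _lock

def flush():
    # fixed pattern: flush manages its own locking
    if not lock.acquire(timeout=1):  # a bare `with lock:` would hang here instead
        raise RuntimeError("lock already held by our own caller -- the old deadlock")
    try:
        pass  # ... drain the buffered queue ...
    finally:
        lock.release()

def put_events_old():
    with lock:   # old code held the lock around the flush call ...
        flush()  # ... so the second acquire could never succeed

def put_events_new():
    flush()      # new code calls flush() without holding the lock

put_events_new()      # completes immediately
try:
    put_events_old()  # raises after the one-second timeout
except RuntimeError as exc:
    print(exc)
```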
@@ -89,17 +89,14 @@ async def put_events(self, events, timeout_time=None):
            new_events_len = len(events)
        except TypeError:
            new_events_len = 1
-
        if self._max_buffer_len - self._cur_buffered_len < new_events_len:
            _LOGGER.info(
                "The buffer for partition %r is full. Attempting to flush before adding %r events.",
                self.partition_id,
                new_events_len,
            )
            # flush the buffer
-            async with self._lock:
-                await self._flush(timeout_time=timeout_time)
-
+            await self.flush(timeout_time=timeout_time)
        if timeout_time and time.time() > timeout_time:
            raise OperationTimeoutError(
                "Failed to enqueue events into buffer due to timeout."
@@ -109,14 +106,16 @@
            self._cur_batch.add(events)
        except AttributeError:  # if the input events is an EventDataBatch, put the whole batch into the buffer
            # if there are events in cur_batch, enqueue cur_batch to the buffer
-            if self._cur_batch:
-                self._buffered_queue.put(self._cur_batch)
-            self._buffered_queue.put(events)
+            async with self._lock:
+                if self._cur_batch:
+                    self._buffered_queue.put(self._cur_batch)
+                self._buffered_queue.put(events)
            # create a new batch for incoming events
            self._cur_batch = EventDataBatch(self._max_message_size_on_link)
        except ValueError:
            # adding a single event exceeds the current batch size, create a new batch
-            self._buffered_queue.put(self._cur_batch)
+            async with self._lock:
+                self._buffered_queue.put(self._cur_batch)
            self._cur_batch = EventDataBatch(self._max_message_size_on_link)
            self._cur_batch.add(events)
        self._cur_buffered_len += new_events_len
@@ -146,10 +145,13 @@ async def _flush(self, timeout_time=None, raise_error=True):
        if self._cur_batch:  # if there is a batch, enqueue it to the buffer first
            self._buffered_queue.put(self._cur_batch)
            self._cur_batch = EventDataBatch(self._max_message_size_on_link)
-        while self._cur_buffered_len:
+        while self._buffered_queue.qsize() > 0:
            remaining_time = timeout_time - time.time() if timeout_time else None
            if (remaining_time and remaining_time > 0) or remaining_time is None:
-                batch = self._buffered_queue.get()
+                try:
+                    batch = self._buffered_queue.get(block=False)
+                except queue.Empty:
+                    break
                self._buffered_queue.task_done()
                try:
                    _LOGGER.info("Partition %r is sending.", self.partition_id)
@@ -187,6 +189,8 @@ async def _flush(self, timeout_time=None, raise_error=True):
                    break
        # after finishing flushing, reset cur batch and put it into the buffer
        self._last_send_time = time.time()
+        # reset the buffered count
+        self._cur_buffered_len = 0
        self._cur_batch = EventDataBatch(self._max_message_size_on_link)
        _LOGGER.info("Partition %r finished flushing.", self.partition_id)

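The async diff mirrors the sync one: each queue put is guarded by `async with self._lock`, and `put_events` now awaits the public `flush` rather than `_flush` under the lock. `asyncio.Lock` is just as non-reentrant as `threading.Lock`, so the same re-acquisition hazard applies. A small sketch of the guarded-put pattern under those assumptions (illustrative names, not the SDK's internals):

```python
import asyncio
import queue

class BufferSketch:
    """Toy stand-in for the buffered producer's queue handling."""

    def __init__(self) -> None:
        self._lock = asyncio.Lock()
        self._buffered_queue: queue.Queue = queue.Queue()

    async def put(self, item) -> None:
        # serialize puts against a concurrent flush, as the diff above does
        async with self._lock:
            self._buffered_queue.put(item)

    async def flush(self) -> list:
        # flush takes the lock itself, so callers must not hold it while awaiting
        async with self._lock:
            items = []
            while self._buffered_queue.qsize() > 0:
                try:
                    items.append(self._buffered_queue.get(block=False))
                except queue.Empty:
                    break
            return items

async def main() -> None:
    buf = BufferSketch()
    await asyncio.gather(*(buf.put(i) for i in range(5)))
    assert len(await buf.flush()) == 5

asyncio.run(main())
```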
@@ -490,3 +490,50 @@ async def on_error(events, pid, err):

    await consumer.close()
    await receive_thread
+
+@pytest.mark.liveTest
+@pytest.mark.asyncio
+async def test_long_wait_small_buffer(connection_str):
+    received_events = defaultdict(list)
+
+    async def on_event(partition_context, event):
+        received_events[partition_context.partition_id].append(event)
+
+    consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default")
+
+    receive_thread = asyncio.ensure_future(consumer.receive(on_event=on_event))
+
+    sent_events = defaultdict(list)
+
+    async def on_success(events, pid):
+        sent_events[pid].extend(events)
+
+    async def on_error(events, pid, err):
+        on_error.err = err
+
+    on_error.err = None  # ensure no error
+    producer = EventHubProducerClient.from_connection_string(
+        connection_str,
+        buffered_mode=True,
+        on_success=on_success,
+        on_error=on_error,
+        auth_timeout=3,
+        retry_total=3,
+        retry_mode='fixed',
+        retry_backoff_factor=0.01,
+        max_wait_time=10,
+        max_buffer_length=100
+    )
+
+    async with producer:
+        for i in range(100):
+            await producer.send_event(EventData("test"))
+
+    await asyncio.sleep(60)
+
+    assert not on_error.err
+    assert sum([len(sent_events[key]) for key in sent_events]) == 100
+    assert sum([len(received_events[key]) for key in received_events]) == 100
+
+    await consumer.close()
+    await receive_thread
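Both tests capture the last callback error by stashing it on the function object itself (`on_error.err`), a lightweight alternative to `nonlocal` or a helper class. A tiny self-contained example of the same trick:

```python
def on_error(events, pid, err):
    on_error.err = err  # attributes set on the function object persist across calls

on_error.err = None  # reset before the run
on_error(["evt"], "0", ValueError("boom"))
assert isinstance(on_error.err, ValueError)
```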
@@ -496,3 +496,51 @@ def on_error(events, pid, err):

    consumer.close()
    receive_thread.join()

+@pytest.mark.liveTest
+def test_long_wait_small_buffer(connection_str):
+    received_events = defaultdict(list)
+
+    def on_event(partition_context, event):
+        received_events[partition_context.partition_id].append(event)
+
+    consumer = EventHubConsumerClient.from_connection_string(connection_str, consumer_group="$default")
+    receive_thread = Thread(target=consumer.receive, args=(on_event,))
+    receive_thread.daemon = True
+    receive_thread.start()
+
+    sent_events = defaultdict(list)
+
+    def on_success(events, pid):
+        sent_events[pid].extend(events)
+
+    def on_error(events, pid, err):
+        on_error.err = err
+
+    on_error.err = None  # ensure no error
+    producer = EventHubProducerClient.from_connection_string(
+        connection_str,
+        buffered_mode=True,
+        on_success=on_success,
+        on_error=on_error,
+        auth_timeout=3,
+        retry_total=3,
+        retry_mode='fixed',
+        retry_backoff_factor=0.01,
+        max_wait_time=10,
+        max_buffer_length=100
+    )
+
+    with producer:
+        for i in range(100):
+            producer.send_event(EventData("test"))
+            time.sleep(.1)
+
+    time.sleep(60)
+
+    assert not on_error.err
+    assert sum([len(sent_events[key]) for key in sent_events]) == 100
+    assert sum([len(received_events[key]) for key in received_events]) == 100
+
+    consumer.close()
+    receive_thread.join()