From 61fd0b9b777eb69f29cde4a80471f3ee8bd2a723 Mon Sep 17 00:00:00 2001 From: zachjhum <123217734+zachjhum@users.noreply.github.com> Date: Wed, 21 Feb 2024 13:19:50 -0800 Subject: [PATCH] Prevent improper error logging during worker shutdown (#1257) * Move throwOnIllegalState call to drain queue method to prevent improper error logging during worker shutdown * Fix unit tests that expected IllegalStateException thrown * Changed names of unit tests to reflect new behavior --- .../polling/PrefetchRecordsPublisher.java | 3 ++- .../polling/PrefetchRecordsPublisherTest.java | 25 ++++++++++++++++--- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java index eb5937f70..1f7267b08 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisher.java @@ -289,7 +289,6 @@ private void throwOnIllegalState() { } private PrefetchRecordsRetrieved peekNextResult() { - throwOnIllegalState(); return publisherSession.peekNextRecord(); } @@ -336,6 +335,7 @@ public void restartFrom(RecordsRetrieved recordsRetrieved) { @Override public void subscribe(Subscriber s) { + throwOnIllegalState(); subscriber = s; subscriber.onSubscribe(new Subscription() { @Override @@ -389,6 +389,7 @@ synchronized void drainQueueForRequests() { // If there is an event available to drain and if there is at least one demand, // then schedule it for delivery if (publisherSession.hasDemandToPublish() && canDispatchRecord(recordsToDeliver)) { + throwOnIllegalState(); subscriber.onNext(recordsToDeliver.prepareForPublish()); recordsToDeliver.dispatched(); lastEventDeliveryTime = Instant.now(); diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java index 8d88151b1..aeacab8e2 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/retrieval/polling/PrefetchRecordsPublisherTest.java @@ -31,8 +31,10 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -64,6 +66,7 @@ import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.invocation.InvocationOnMock; import org.mockito.runners.MockitoJUnitRunner; @@ -375,15 +378,29 @@ record = Record.builder().data(createByteBufferWithSize(1024)).build(); } @Test(expected = IllegalStateException.class) - public void testGetNextRecordsWithoutStarting() { + public void testSubscribeWithoutStarting() { verify(executorService, never()).execute(any()); - getRecordsCache.drainQueueForRequests(); + Subscriber mockSubscriber = mock(Subscriber.class); + getRecordsCache.subscribe(mockSubscriber); } @Test(expected = IllegalStateException.class) - public void testCallAfterShutdown() { + public void testRequestRecordsOnSubscriptionAfterShutdown() { + GetRecordsResponse response = GetRecordsResponse.builder().records( + Record.builder().data(SdkBytes.fromByteArray(new byte[] { 1, 2, 3 })).sequenceNumber("123").build()) + .nextShardIterator(NEXT_SHARD_ITERATOR).build(); + when(getRecordsRetrievalStrategy.getRecords(anyInt())).thenReturn(response); + + getRecordsCache.start(sequenceNumber, initialPosition); + + verify(getRecordsRetrievalStrategy, timeout(100).atLeastOnce()).getRecords(anyInt()); + when(executorService.isShutdown()).thenReturn(true); - getRecordsCache.drainQueueForRequests(); + Subscriber mockSubscriber = mock(Subscriber.class); + getRecordsCache.subscribe(mockSubscriber); + ArgumentCaptor subscriptionCaptor = ArgumentCaptor.forClass(Subscription.class); + verify(mockSubscriber).onSubscribe(subscriptionCaptor.capture()); + subscriptionCaptor.getValue().request(1); } @Test