diff --git a/sdk/core/azure-core-amqp/src/main/java/com/azure/core/amqp/implementation/ReactorReceiver.java b/sdk/core/azure-core-amqp/src/main/java/com/azure/core/amqp/implementation/ReactorReceiver.java index e6418bbe85404..6e29839f2626d 100644 --- a/sdk/core/azure-core-amqp/src/main/java/com/azure/core/amqp/implementation/ReactorReceiver.java +++ b/sdk/core/azure-core-amqp/src/main/java/com/azure/core/amqp/implementation/ReactorReceiver.java @@ -55,21 +55,20 @@ protected ReactorReceiver(String entityPath, Receiver receiver, ReceiveLinkHandl this.dispatcher = dispatcher; this.messagesProcessor = this.handler.getDeliveredMessages() .map(delivery -> decodeDelivery(delivery)) - .subscribeWith(EmitterProcessor.create()); - - this.messagesProcessor.doOnNext(next -> { - if (receiver.getRemoteCredit() == 0) { - final Supplier supplier = creditSupplier.get(); - if (supplier == null) { - return; + .doOnNext(next -> { + if (receiver.getRemoteCredit() == 0) { + final Supplier supplier = creditSupplier.get(); + if (supplier == null) { + return; + } + + final Integer credits = supplier.get(); + if (credits != null && credits > 0) { + addCredits(credits); + } } - - final Integer credits = supplier.get(); - if (credits != null && credits > 0) { - addCredits(credits); - } - } - }); + }) + .subscribeWith(EmitterProcessor.create()); this.subscriptions = Disposables.composite( this.handler.getEndpointStates().subscribe( diff --git a/sdk/core/azure-core-amqp/src/test/java/com/azure/core/amqp/implementation/ReactorReceiverTest.java b/sdk/core/azure-core-amqp/src/test/java/com/azure/core/amqp/implementation/ReactorReceiverTest.java index 0199deb862311..cfbc5b53546c0 100644 --- a/sdk/core/azure-core-amqp/src/test/java/com/azure/core/amqp/implementation/ReactorReceiverTest.java +++ b/sdk/core/azure-core-amqp/src/test/java/com/azure/core/amqp/implementation/ReactorReceiverTest.java @@ -4,12 +4,14 @@ package com.azure.core.amqp.implementation; import com.azure.core.amqp.AmqpEndpointState; +import com.azure.core.amqp.AmqpMessageConstant; import com.azure.core.amqp.ClaimsBasedSecurityNode; import com.azure.core.amqp.exception.AmqpErrorCondition; import com.azure.core.amqp.implementation.handler.ReceiveLinkHandler; import org.apache.qpid.proton.amqp.Symbol; import org.apache.qpid.proton.amqp.messaging.Source; import org.apache.qpid.proton.amqp.transport.ErrorCondition; +import org.apache.qpid.proton.engine.Delivery; import org.apache.qpid.proton.engine.EndpointState; import org.apache.qpid.proton.engine.Event; import org.apache.qpid.proton.engine.Link; @@ -35,9 +37,12 @@ import java.io.IOException; import java.time.Duration; import java.util.List; +import java.util.Map; +import java.util.function.Supplier; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; @@ -59,6 +64,8 @@ class ReactorReceiverTest { private Record record; @Mock private ReactorDispatcher dispatcher; + @Mock + private Supplier creditSupplier; @Captor private ArgumentCaptor dispatcherCaptor; @@ -77,7 +84,7 @@ static void afterAll() { } @BeforeEach - void setup() throws IOException { + void setup() { MockitoAnnotations.initMocks(this); when(cbsNode.authorize(any(), any())).thenReturn(Mono.empty()); @@ -120,7 +127,7 @@ void addCredits() throws IOException { final List invocations = dispatcherCaptor.getAllValues(); assertEquals(1, invocations.size()); - // Apply the invocation. This should actually set the credits. + // Apply the invocation. invocations.get(0).run(); verify(receiver).flow(credits); @@ -195,4 +202,62 @@ void closesOnNonAmqpException() { .expectComplete() .verify(Duration.ofSeconds(10)); } + + @Test + void addsMoreCreditsWhenPrefetchIsDone() throws IOException { + // Arrange + // This message was copied from one that was received. + final byte[] messageBytes = new byte[] { 0, 83, 114, -63, 73, 6, -93, 21, 120, 45, 111, 112, 116, 45, 115, 101, + 113, 117, 101, 110, 99, 101, 45, 110, 117, 109, 98, 101, 114, 85, 0, -93, 12, 120, 45, 111, 112, 116, 45, + 111, 102, 102, 115, 101, 116, -95, 1, 48, -93, 19, 120, 45, 111, 112, 116, 45, 101, 110, 113, 117, 101, 117, + 101, 100, 45, 116, 105, 109, 101, -125, 0, 0, 1, 112, -54, 124, -41, 90, 0, 83, 117, -96, 12, 80, 111, 115, + 105, 116, 105, 111, 110, 53, 58, 32, 48}; + final Link link = mock(Link.class); + final Delivery delivery = mock(Delivery.class); + + when(event.getLink()).thenReturn(link); + when(event.getDelivery()).thenReturn(delivery); + + when(delivery.getLink()).thenReturn(receiver); + when(delivery.isPartial()).thenReturn(false); + when(delivery.isSettled()).thenReturn(false); + when(delivery.pending()).thenReturn(messageBytes.length); + + when(receiver.getRemoteCredit()).thenReturn(0); + when(receiver.recv(any(), eq(0), eq(messageBytes.length))).thenAnswer(invocation -> { + final byte[] buffer = invocation.getArgument(0); + System.arraycopy(messageBytes, 0, buffer, 0, messageBytes.length); + return messageBytes.length; + }); + + when(creditSupplier.get()).thenReturn(10); + reactorReceiver.setEmptyCreditListener(creditSupplier); + + // Act & Assert + StepVerifier.create(reactorReceiver.receive()) + .then(() -> receiverHandler.onDelivery(event)) + .assertNext(message -> { + Assertions.assertNotNull(message.getMessageAnnotations()); + + final Map values = message.getMessageAnnotations().getValue(); + Assertions.assertTrue(values.containsKey(Symbol.getSymbol(AmqpMessageConstant.OFFSET_ANNOTATION_NAME.getValue()))); + Assertions.assertTrue(values.containsKey(Symbol.getSymbol(AmqpMessageConstant.SEQUENCE_NUMBER_ANNOTATION_NAME.getValue()))); + Assertions.assertTrue(values.containsKey(Symbol.getSymbol(AmqpMessageConstant.ENQUEUED_TIME_UTC_ANNOTATION_NAME.getValue()))); + }) + .thenCancel() + .verify(); + + verify(creditSupplier).get(); + + // Verify that the get addCredits was called on that dispatcher. + verify(dispatcher).invoke(dispatcherCaptor.capture()); + + final List invocations = dispatcherCaptor.getAllValues(); + assertEquals(1, invocations.size()); + + // Apply the invocation. + invocations.get(0).run(); + + verify(receiver).flow(10); + } }