From 3089e0b7b897d25ed278fc8341003639ecccc0b0 Mon Sep 17 00:00:00 2001 From: Sakthivel Subramanian Date: Mon, 4 Nov 2024 07:11:43 +0530 Subject: [PATCH] fix: Avoid blocking thread in AsyncResultSet --- .../clirr-ignored-differences.xml | 6 ++- .../cloud/spanner/AbstractReadContext.java | 10 +++-- .../google/cloud/spanner/AsyncResultSet.java | 15 ++++--- .../cloud/spanner/AsyncResultSetImpl.java | 45 ++++++++++++------- .../cloud/spanner/ForwardingResultSet.java | 2 +- .../cloud/spanner/GrpcStreamIterator.java | 9 ++-- .../com/google/cloud/spanner/ResultSet.java | 4 ++ .../spanner/ResumableStreamIterator.java | 29 ++++++------ .../com/google/cloud/spanner/SessionPool.java | 3 +- .../cloud/spanner/AsyncResultSetImplTest.java | 13 ++++-- .../spanner/ResumableStreamIteratorTest.java | 6 ++- 11 files changed, 87 insertions(+), 55 deletions(-) diff --git a/google-cloud-spanner/clirr-ignored-differences.xml b/google-cloud-spanner/clirr-ignored-differences.xml index ec13415790..a9d8298831 100644 --- a/google-cloud-spanner/clirr-ignored-differences.xml +++ b/google-cloud-spanner/clirr-ignored-differences.xml @@ -790,5 +790,9 @@ com/google/cloud/spanner/connection/Connection boolean isAutoBatchDmlUpdateCountVerification() - + + 7012 + com/google/cloud/spanner/ResultSet + boolean initiateStreaming(com.google.cloud.spanner.AsyncResultSet$StreamMessageListener) + diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java index bb73f91b8a..3cb7df74c0 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java @@ -768,8 +768,9 @@ ResultSet executeQueryInternalWithOptions( rpc.getExecuteQueryRetrySettings(), rpc.getExecuteQueryRetryableCodes()) { @Override - CloseableIterator startStream(@Nullable ByteString resumeToken, - AsyncResultSet.StreamMessageListener streamListener) { + CloseableIterator startStream( + @Nullable ByteString resumeToken, + AsyncResultSet.StreamMessageListener streamListener) { GrpcStreamIterator stream = new GrpcStreamIterator(statement, prefetchChunks, cancelQueryWhenClientIsClosed); stream.registerListener(streamListener); @@ -961,8 +962,9 @@ ResultSet readInternalWithOptions( rpc.getReadRetrySettings(), rpc.getReadRetryableCodes()) { @Override - CloseableIterator startStream(@Nullable ByteString resumeToken, - AsyncResultSet.StreamMessageListener streamListener) { + CloseableIterator startStream( + @Nullable ByteString resumeToken, + AsyncResultSet.StreamMessageListener streamListener) { GrpcStreamIterator stream = new GrpcStreamIterator(prefetchChunks, cancelQueryWhenClientIsClosed); stream.registerListener(streamListener); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSet.java index 21b341d108..dbd7f93a3c 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSet.java @@ -19,7 +19,6 @@ import com.google.api.core.ApiFuture; import com.google.common.base.Function; import com.google.spanner.v1.PartialResultSet; - import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; @@ -227,16 +226,20 @@ interface ReadyCallback { List toList(Function transformer) throws SpannerException; /** - * An interface to register the listener for streaming gRPC request. It will be called when a chunk is received - * from gRPC streaming call. + * An interface to register the listener for streaming gRPC request. It will be called when a + * chunk is received from gRPC streaming call. */ interface StreamMessageListener { - void onStreamMessage(PartialResultSet partialResultSet, int prefetchChunks, int currentBufferSize, StreamMessageRequestor streamMessageRequestor); + void onStreamMessage( + PartialResultSet partialResultSet, + int prefetchChunks, + int currentBufferSize, + StreamMessageRequestor streamMessageRequestor); } /** - * An interface to request more messages from the gRPC streaming call. It will be implemented by the class which has access - * to SpannerRpc.StreamingCall object + * An interface to request more messages from the gRPC streaming call. It will be implemented by + * the class which has access to SpannerRpc.StreamingCall object */ interface StreamMessageRequestor { void requestMessages(int numOfMessages); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSetImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSetImpl.java index 28add6e34b..75015f7ea5 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSetImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSetImpl.java @@ -31,7 +31,6 @@ import com.google.spanner.v1.PartialResultSet; import com.google.spanner.v1.ResultSetMetadata; import com.google.spanner.v1.ResultSetStats; - import java.util.Collection; import java.util.LinkedList; import java.util.List; @@ -40,10 +39,10 @@ import java.util.logging.Logger; /** Default implementation for {@link AsyncResultSet}. */ -class AsyncResultSetImpl extends ForwardingStructReader implements ListenableAsyncResultSet, AsyncResultSet.StreamMessageListener { +class AsyncResultSetImpl extends ForwardingStructReader + implements ListenableAsyncResultSet, AsyncResultSet.StreamMessageListener { private static final Logger log = Logger.getLogger(AsyncResultSetImpl.class.getName()); - /** State of an {@link AsyncResultSetImpl}. */ private enum State { INITIALIZED, @@ -112,6 +111,9 @@ private enum State { private State state = State.INITIALIZED; + /** This variable indicates that produce rows thread is initiated */ + private volatile boolean produceRowsInitiated; + /** * This variable indicates whether all the results from the underlying result set have been read. */ @@ -458,7 +460,7 @@ private class InitiateStreamingRunnable implements Runnable { @Override public void run() { try { - if(!initiateStreaming(AsyncResultSetImpl.this)) { + if (!initiateStreaming(AsyncResultSetImpl.this)) { initiateProduceRows(); } } catch (SpannerException e) { @@ -489,7 +491,10 @@ public ApiFuture setCallback(Executor exec, ReadyCallback cb) { private void initiateProduceRows() { this.service.execute(new ProduceRowsRunnable()); - this.state = State.RUNNING; + if (this.state == State.IN_PROGRESS) { + this.state = State.RUNNING; + } + produceRowsInitiated = true; } Future getResult() { @@ -504,7 +509,6 @@ public void cancel() { "cannot cancel a result set without a callback"); state = State.CANCELLED; pausedLatch.countDown(); - this.result.setException(CANCELLED_EXCEPTION); } } @@ -625,18 +629,25 @@ public Struct getCurrentRowAsStruct() { } @Override - public void onStreamMessage(PartialResultSet partialResultSet, int prefetchChunks, int currentBufferSize, StreamMessageRequestor streamMessageRequestor) { + public void onStreamMessage( + PartialResultSet partialResultSet, + int prefetchChunks, + int currentBufferSize, + StreamMessageRequestor streamMessageRequestor) { synchronized (monitor) { - if (state == State.IN_PROGRESS) { - // if PartialResultSet contains resume token or buffer size is more than configured size or we have reached - // end of stream, we can start the thread - boolean startJobThread = !partialResultSet.getResumeToken().isEmpty() - || currentBufferSize > prefetchChunks || partialResultSet == GrpcStreamIterator.END_OF_STREAM; - if (startJobThread){ - initiateProduceRows(); - } else { - streamMessageRequestor.requestMessages(1); - } + if (produceRowsInitiated) { + return; + } + // if PartialResultSet contains resume token or buffer size is more than configured size or + // we have reached end of stream, we can start the thread + boolean startJobThread = + !partialResultSet.getResumeToken().isEmpty() + || currentBufferSize > prefetchChunks + || partialResultSet == GrpcStreamIterator.END_OF_STREAM; + if (startJobThread || state != State.IN_PROGRESS) { + initiateProduceRows(); + } else { + streamMessageRequestor.requestMessages(1); } } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java index 24fbf353bb..5ed39a92ff 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java @@ -105,6 +105,6 @@ public ResultSetMetadata getMetadata() { @Override public boolean initiateStreaming(AsyncResultSet.StreamMessageListener streamMessageListener) { - return delegate.get().initiateStreaming(streamMessageListener); + return delegate.get().initiateStreaming(streamMessageListener); } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java index f3047b13da..b74bc77318 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java @@ -23,15 +23,14 @@ import com.google.common.collect.AbstractIterator; import com.google.common.util.concurrent.Uninterruptibles; import com.google.spanner.v1.PartialResultSet; -import org.threeten.bp.Duration; - -import javax.annotation.Nullable; import java.util.Optional; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; +import javax.annotation.Nullable; +import org.threeten.bp.Duration; /** Adapts a streaming read/query call into an iterator over partial result sets. */ @VisibleForTesting @@ -199,7 +198,7 @@ public boolean cancelQueryWhenClientIsClosed() { } private void onStreamMessage(PartialResultSet partialResultSet) { - Optional.ofNullable(streamMessageListener).ifPresent( - sl -> sl.onStreamMessage(partialResultSet, prefetchChunks, stream.size(), this)); + Optional.ofNullable(streamMessageListener) + .ifPresent(sl -> sl.onStreamMessage(partialResultSet, prefetchChunks, stream.size(), this)); } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSet.java index 9c0b2dc512..9a21e8d40d 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResultSet.java @@ -83,6 +83,10 @@ default ResultSetMetadata getMetadata() { throw new UnsupportedOperationException("Method should be overridden"); } + /** + * Returns the {@link boolean} for this {@link ResultSet}. This method will be used by + * AsyncResultSet to initiate gRPC streaming + */ default boolean initiateStreaming(AsyncResultSet.StreamMessageListener streamMessageListener) { return false; } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResumableStreamIterator.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResumableStreamIterator.java index 840df9e86b..b94bdaf848 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResumableStreamIterator.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResumableStreamIterator.java @@ -16,6 +16,11 @@ package com.google.cloud.spanner; +import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerException; +import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerExceptionForCancellation; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + import com.google.api.client.util.BackOff; import com.google.api.client.util.ExponentialBackOff; import com.google.api.gax.grpc.GrpcStatusCode; @@ -30,8 +35,6 @@ import com.google.spanner.v1.PartialResultSet; import io.grpc.Context; import io.opentelemetry.api.common.Attributes; - -import javax.annotation.Nullable; import java.io.IOException; import java.util.LinkedList; import java.util.Objects; @@ -41,11 +44,7 @@ import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; - -import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerException; -import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerExceptionForCancellation; -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; +import javax.annotation.Nullable; /** * Wraps an iterator over partial result sets, supporting resuming RPCs on error. This class keeps @@ -198,8 +197,8 @@ public void execute(Runnable command) { } } - abstract CloseableIterator startStream(@Nullable ByteString resumeToken, - AsyncResultSet.StreamMessageListener streamMessageListener); + abstract CloseableIterator startStream( + @Nullable ByteString resumeToken, AsyncResultSet.StreamMessageListener streamMessageListener); /** * Prepares the iterator for a retry on a different gRPC channel. Returns true if that is @@ -226,7 +225,7 @@ public boolean isWithBeginTransaction() { @Override public boolean initiateStreaming(AsyncResultSet.StreamMessageListener streamMessageListener) { this.streamMessageListener = streamMessageListener; - eagerStartStreaming(); + startGrpcStreaming(); return true; } @@ -236,7 +235,7 @@ protected PartialResultSet computeNext() { Context context = Context.current(); while (true) { // Eagerly start stream before consuming any buffered items. - eagerStartStreaming(); + startGrpcStreaming(); // Buffer contains items up to a resume token or has reached capacity: flush. if (!buffer.isEmpty() && (finished || !safeToRetry || !buffer.getLast().getResumeToken().isEmpty())) { @@ -315,12 +314,12 @@ && prepareIteratorForRetryOnDifferentGrpcChannel()) { } } - private void eagerStartStreaming() { + private void startGrpcStreaming() { if (stream == null) { span.addAnnotation( - "Starting/Resuming stream", - "ResumeToken", - resumeToken == null ? "null" : resumeToken.toStringUtf8()); + "Starting/Resuming stream", + "ResumeToken", + resumeToken == null ? "null" : resumeToken.toStringUtf8()); try (IScope scope = tracer.withSpan(span)) { // When start a new stream set the Span as current to make the gRPC Span a child of // this Span. diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java index 891cc6e4c7..f98d96eef1 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java @@ -289,7 +289,8 @@ public boolean next() throws SpannerException { } @Override - public boolean initiateStreaming(AsyncResultSet.StreamMessageListener streamMessageListener) { + public boolean initiateStreaming( + AsyncResultSet.StreamMessageListener streamMessageListener) { try { boolean ret = super.initiateStreaming(streamMessageListener); if (beforeFirst) { diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncResultSetImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncResultSetImplTest.java index 1b8395c5b1..986b9d08f1 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncResultSetImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncResultSetImplTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -32,6 +33,7 @@ import com.google.cloud.spanner.AsyncResultSet.ReadyCallback; import com.google.common.base.Function; import com.google.common.collect.Range; +import com.google.spanner.v1.PartialResultSet; import java.util.List; import java.util.concurrent.BlockingDeque; import java.util.concurrent.CountDownLatch; @@ -48,8 +50,6 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -390,6 +390,13 @@ public void testCallbackIsNotCalledWhilePausedAndCanceled() try (AsyncResultSetImpl rs = new AsyncResultSetImpl(simpleProvider, delegate, AsyncResultSetImpl.DEFAULT_BUFFER_SIZE)) { + + when(delegate.initiateStreaming(any(AsyncResultSet.StreamMessageListener.class))) + .thenAnswer( + answer -> { + rs.onStreamMessage(PartialResultSet.newBuilder().build(), 4, 1, null); + return null; + }); callbackResult = rs.setCallback( executor, @@ -402,7 +409,7 @@ public void testCallbackIsNotCalledWhilePausedAndCanceled() SpannerException exception = assertThrows(SpannerException.class, () -> get(callbackResult)); assertEquals(ErrorCode.CANCELLED, exception.getErrorCode()); - assertEquals(0, callbackCounter.get()); + assertEquals(1, callbackCounter.get()); } } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java index 94797a1fcf..ebe8672467 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java @@ -64,7 +64,8 @@ public class ResumableStreamIteratorTest { interface Starter { AbstractResultSet.CloseableIterator startStream( - @Nullable ByteString resumeToken, AsyncResultSet.StreamMessageListener streamMessageListener); + @Nullable ByteString resumeToken, + AsyncResultSet.StreamMessageListener streamMessageListener); } interface ResultSetStream { @@ -164,7 +165,8 @@ private void initWithLimit(int maxBufferSize) { SpannerStubSettings.newBuilder().executeStreamingSqlSettings().getRetryableCodes()) { @Override AbstractResultSet.CloseableIterator startStream( - @Nullable ByteString resumeToken, AsyncResultSet.StreamMessageListener streamMessageListener) { + @Nullable ByteString resumeToken, + AsyncResultSet.StreamMessageListener streamMessageListener) { return starter.startStream(resumeToken, null); } };