diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java index a89090e34d..cecf462bd2 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java @@ -768,9 +768,14 @@ ResultSet executeQueryInternalWithOptions( rpc.getExecuteQueryRetrySettings(), rpc.getExecuteQueryRetryableCodes()) { @Override - CloseableIterator startStream(@Nullable ByteString resumeToken) { + CloseableIterator startStream( + @Nullable ByteString resumeToken, + AsyncResultSet.StreamMessageListener streamListener) { GrpcStreamIterator stream = new GrpcStreamIterator(statement, prefetchChunks, cancelQueryWhenClientIsClosed); + if (streamListener != null) { + stream.registerListener(streamListener); + } if (partitionToken != null) { request.setPartitionToken(partitionToken); } @@ -791,8 +796,8 @@ CloseableIterator startStream(@Nullable ByteString resumeToken getTransactionChannelHint(), isRouteToLeader()); session.markUsed(clock.instant()); - call.request(prefetchChunks); stream.setCall(call, request.getTransaction().hasBegin()); + call.request(prefetchChunks); return stream; } @@ -959,9 +964,14 @@ ResultSet readInternalWithOptions( rpc.getReadRetrySettings(), rpc.getReadRetryableCodes()) { @Override - CloseableIterator startStream(@Nullable ByteString resumeToken) { + CloseableIterator startStream( + @Nullable ByteString resumeToken, + AsyncResultSet.StreamMessageListener streamListener) { GrpcStreamIterator stream = new GrpcStreamIterator(prefetchChunks, cancelQueryWhenClientIsClosed); + if (streamListener != null) { + stream.registerListener(streamListener); + } TransactionSelector selector = null; if (resumeToken != null) { builder.setResumeToken(resumeToken); @@ -980,8 +990,8 @@ CloseableIterator startStream(@Nullable ByteString resumeToken getTransactionChannelHint(), isRouteToLeader()); session.markUsed(clock.instant()); - call.request(prefetchChunks); stream.setCall(call, /* withBeginTransaction = */ builder.getTransaction().hasBegin()); + call.request(prefetchChunks); return stream; } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java index fdc0398d5f..3dca970f96 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractResultSet.java @@ -150,6 +150,14 @@ interface CloseableIterator extends Iterator { void close(@Nullable String message); boolean isWithBeginTransaction(); + + /** + * @param streamMessageListener A class object which implements StreamMessageListener + * @return true if streaming is supported by the iterator, otherwise false + */ + default boolean initiateStreaming(AsyncResultSet.StreamMessageListener streamMessageListener) { + return false; + } } static double valueProtoToFloat64(com.google.protobuf.Value proto) { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSet.java index dfedcc4f8b..2b3225bfc5 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSet.java @@ -18,6 +18,7 @@ import com.google.api.core.ApiFuture; import com.google.common.base.Function; +import com.google.spanner.v1.PartialResultSet; import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; @@ -223,4 +224,12 @@ interface ReadyCallback { * @param transformer function which will be used to transform the row. It should not return null. */ List toList(Function transformer) throws SpannerException; + + /** + * An interface to register the listener for streaming gRPC request. It will be called when a + * chunk is received from gRPC streaming call. + */ + interface StreamMessageListener { + void onStreamMessage(PartialResultSet partialResultSet, boolean bufferIsFull); + } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSetImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSetImpl.java index fa7cc158c1..1161822cd1 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSetImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncResultSetImpl.java @@ -18,7 +18,6 @@ import com.google.api.core.ApiFuture; import com.google.api.core.ApiFutures; -import com.google.api.core.ListenableFutureToApiFuture; import com.google.api.core.SettableApiFuture; import com.google.api.gax.core.ExecutorProvider; import com.google.cloud.spanner.AbstractReadContext.ListenableAsyncResultSet; @@ -29,13 +28,13 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListeningScheduledExecutorService; import com.google.common.util.concurrent.MoreExecutors; +import com.google.spanner.v1.PartialResultSet; import com.google.spanner.v1.ResultSetMetadata; import com.google.spanner.v1.ResultSetStats; import java.util.Collection; import java.util.LinkedList; import java.util.List; import java.util.concurrent.BlockingDeque; -import java.util.concurrent.Callable; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; @@ -45,12 +44,14 @@ import java.util.logging.Logger; /** Default implementation for {@link AsyncResultSet}. */ -class AsyncResultSetImpl extends ForwardingStructReader implements ListenableAsyncResultSet { +class AsyncResultSetImpl extends ForwardingStructReader + implements ListenableAsyncResultSet, AsyncResultSet.StreamMessageListener { private static final Logger log = Logger.getLogger(AsyncResultSetImpl.class.getName()); /** State of an {@link AsyncResultSetImpl}. */ private enum State { INITIALIZED, + STREAMING_INITIALIZED, /** SYNC indicates that the {@link ResultSet} is used in sync pattern. */ SYNC, CONSUMING, @@ -115,12 +116,15 @@ private enum State { private State state = State.INITIALIZED; + /** This variable indicates that produce rows thread is initiated */ + private volatile boolean produceRowsInitiated; + /** * This variable indicates whether all the results from the underlying result set have been read. */ private volatile boolean finished; - private volatile ApiFuture result; + private volatile SettableApiFuture result; /** * This variable indicates whether {@link #tryNext()} has returned {@link CursorState#DONE} or a @@ -329,12 +333,12 @@ public void run() { private final CallbackRunnable callbackRunnable = new CallbackRunnable(); /** - * {@link ProduceRowsCallable} reads data from the underlying {@link ResultSet}, places these in + * {@link ProduceRowsRunnable} reads data from the underlying {@link ResultSet}, places these in * the buffer and dispatches the {@link CallbackRunnable} when data is ready to be consumed. */ - private class ProduceRowsCallable implements Callable { + private class ProduceRowsRunnable implements Runnable { @Override - public Void call() throws Exception { + public void run() { boolean stop = false; boolean hasNext = false; try { @@ -393,12 +397,17 @@ public Void call() throws Exception { } // Call the callback if there are still rows in the buffer that need to be processed. while (!stop) { - waitIfPaused(); - startCallbackIfNecessary(); - // Make sure we wait until the callback runner has actually finished. - consumingLatch.await(); - synchronized (monitor) { - stop = cursorReturnedDoneOrException; + try { + waitIfPaused(); + startCallbackIfNecessary(); + // Make sure we wait until the callback runner has actually finished. + consumingLatch.await(); + synchronized (monitor) { + stop = cursorReturnedDoneOrException; + } + } catch (Throwable e) { + result.setException(e); + return; } } } finally { @@ -410,14 +419,14 @@ public Void call() throws Exception { } synchronized (monitor) { if (executionException != null) { - throw executionException; - } - if (state == State.CANCELLED) { - throw CANCELLED_EXCEPTION; + result.setException(executionException); + } else if (state == State.CANCELLED) { + result.setException(CANCELLED_EXCEPTION); + } else { + result.set(null); } } } - return null; } private void waitIfPaused() throws InterruptedException { @@ -449,6 +458,26 @@ private void startCallbackWithBufferLatchIfNecessary(int bufferLatch) { } } + private class InitiateStreamingRunnable implements Runnable { + + @Override + public void run() { + try { + // This method returns true if the underlying result set is a streaming result set (e.g. a + // GrpcResultSet). + // Those result sets will trigger initiateProduceRows() when the first results are received. + // Non-streaming result sets do not trigger this callback, and for those result sets, we + // need to eagerly start the ProduceRowsRunnable. + if (!initiateStreaming(AsyncResultSetImpl.this)) { + initiateProduceRows(); + } + } catch (Throwable exception) { + executionException = SpannerExceptionFactory.asSpannerException(exception); + initiateProduceRows(); + } + } + } + /** Sets the callback for this {@link AsyncResultSet}. */ @Override public ApiFuture setCallback(Executor exec, ReadyCallback cb) { @@ -458,16 +487,24 @@ public ApiFuture setCallback(Executor exec, ReadyCallback cb) { this.state == State.INITIALIZED, "callback may not be set multiple times"); // Start to fetch data and buffer these. - this.result = - new ListenableFutureToApiFuture<>(this.service.submit(new ProduceRowsCallable())); + this.result = SettableApiFuture.create(); + this.state = State.STREAMING_INITIALIZED; + this.service.execute(new InitiateStreamingRunnable()); this.executor = MoreExecutors.newSequentialExecutor(Preconditions.checkNotNull(exec)); this.callback = Preconditions.checkNotNull(cb); - this.state = State.RUNNING; pausedLatch.countDown(); return result; } } + private void initiateProduceRows() { + if (this.state == State.STREAMING_INITIALIZED) { + this.state = State.RUNNING; + } + produceRowsInitiated = true; + this.service.execute(new ProduceRowsRunnable()); + } + Future getResult() { return result; } @@ -578,6 +615,10 @@ public ResultSetMetadata getMetadata() { return delegateResultSet.get().getMetadata(); } + boolean initiateStreaming(StreamMessageListener streamMessageListener) { + return StreamingUtil.initiateStreaming(delegateResultSet.get(), streamMessageListener); + } + @Override protected void checkValidState() { synchronized (monitor) { @@ -593,4 +634,22 @@ public Struct getCurrentRowAsStruct() { checkValidState(); return currentRow; } + + @Override + public void onStreamMessage(PartialResultSet partialResultSet, boolean bufferIsFull) { + synchronized (monitor) { + if (produceRowsInitiated) { + return; + } + // if PartialResultSet contains a resume token or buffer size is full, or + // we have reached the end of the stream, we can start the thread. + boolean startJobThread = + !partialResultSet.getResumeToken().isEmpty() + || bufferIsFull + || partialResultSet == GrpcStreamIterator.END_OF_STREAM; + if (startJobThread || state != State.STREAMING_INITIALIZED) { + initiateProduceRows(); + } + } + } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java index babbb310a4..3c4883e658 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ForwardingResultSet.java @@ -16,6 +16,7 @@ package com.google.cloud.spanner; +import com.google.api.core.InternalApi; import com.google.common.base.Preconditions; import com.google.common.base.Supplier; import com.google.common.base.Suppliers; @@ -23,7 +24,8 @@ import com.google.spanner.v1.ResultSetStats; /** Forwarding implementation of ResultSet that forwards all calls to a delegate. */ -public class ForwardingResultSet extends ForwardingStructReader implements ProtobufResultSet { +public class ForwardingResultSet extends ForwardingStructReader + implements ProtobufResultSet, StreamingResultSet { private Supplier delegate; @@ -102,4 +104,10 @@ public ResultSetStats getStats() { public ResultSetMetadata getMetadata() { return delegate.get().getMetadata(); } + + @Override + @InternalApi + public boolean initiateStreaming(AsyncResultSet.StreamMessageListener streamMessageListener) { + return StreamingUtil.initiateStreaming(delegate.get(), streamMessageListener); + } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java index 23c9dd7c2d..c2a4ee5a58 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcResultSet.java @@ -19,6 +19,7 @@ import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerException; import static com.google.common.base.Preconditions.checkState; +import com.google.api.core.InternalApi; import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.Value; import com.google.spanner.v1.PartialResultSet; @@ -30,7 +31,8 @@ import javax.annotation.Nullable; @VisibleForTesting -class GrpcResultSet extends AbstractResultSet> implements ProtobufResultSet { +class GrpcResultSet extends AbstractResultSet> + implements ProtobufResultSet, StreamingResultSet { private final GrpcValueIterator iterator; private final Listener listener; private final DecodeMode decodeMode; @@ -123,6 +125,12 @@ public ResultSetMetadata getMetadata() { return metadata; } + @Override + @InternalApi + public boolean initiateStreaming(AsyncResultSet.StreamMessageListener streamMessageListener) { + return iterator.initiateStreaming(streamMessageListener); + } + @Override public void close() { synchronized (this) { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java index af6b568350..79c02eab58 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java @@ -20,9 +20,11 @@ import com.google.cloud.spanner.AbstractResultSet.CloseableIterator; import com.google.cloud.spanner.spi.v1.SpannerRpc; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; import com.google.common.collect.AbstractIterator; import com.google.common.util.concurrent.Uninterruptibles; import com.google.spanner.v1.PartialResultSet; +import java.util.Optional; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; @@ -36,7 +38,8 @@ class GrpcStreamIterator extends AbstractIterator implements CloseableIterator { private static final Logger logger = Logger.getLogger(GrpcStreamIterator.class.getName()); - private static final PartialResultSet END_OF_STREAM = PartialResultSet.newBuilder().build(); + static final PartialResultSet END_OF_STREAM = PartialResultSet.newBuilder().build(); + private AsyncResultSet.StreamMessageListener streamMessageListener; private final ConsumerImpl consumer; private final BlockingQueue stream; @@ -66,6 +69,10 @@ protected final SpannerRpc.ResultStreamConsumer consumer() { return consumer; } + void registerListener(AsyncResultSet.StreamMessageListener streamMessageListener) { + this.streamMessageListener = Preconditions.checkNotNull(streamMessageListener); + } + public void setCall(SpannerRpc.StreamingCall call, boolean withBeginTransaction) { this.call = call; this.withBeginTransaction = withBeginTransaction; @@ -135,6 +142,7 @@ protected final PartialResultSet computeNext() { private void addToStream(PartialResultSet results) { // We assume that nothing from the user will interrupt gRPC event threads. Uninterruptibles.putUninterruptibly(stream, results); + onStreamMessage(results); } private class ConsumerImpl implements SpannerRpc.ResultStreamConsumer { @@ -182,4 +190,9 @@ public boolean cancelQueryWhenClientIsClosed() { return this.cancelQueryWhenClientIsClosed; } } + + private void onStreamMessage(PartialResultSet partialResultSet) { + Optional.ofNullable(streamMessageListener) + .ifPresent(sl -> sl.onStreamMessage(partialResultSet, stream.remainingCapacity() <= 1)); + } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcValueIterator.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcValueIterator.java index 1a3df8b912..24c431eec3 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcValueIterator.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcValueIterator.java @@ -127,6 +127,10 @@ ResultSetStats getStats() { return statistics; } + boolean initiateStreaming(AsyncResultSet.StreamMessageListener streamMessageListener) { + return stream.initiateStreaming(streamMessageListener); + } + Type type() { checkState(type != null, "metadata has not been received"); return type; diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResumableStreamIterator.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResumableStreamIterator.java index 3e82ab7d5f..39165da2d3 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResumableStreamIterator.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ResumableStreamIterator.java @@ -23,6 +23,7 @@ import com.google.api.client.util.BackOff; import com.google.api.client.util.ExponentialBackOff; +import com.google.api.core.InternalApi; import com.google.api.gax.grpc.GrpcStatusCode; import com.google.api.gax.retrying.RetrySettings; import com.google.api.gax.rpc.StatusCode.Code; @@ -58,6 +59,7 @@ abstract class ResumableStreamIterator extends AbstractIterator retryableCodes; private static final Logger logger = Logger.getLogger(ResumableStreamIterator.class.getName()); @@ -196,7 +198,8 @@ public void execute(Runnable command) { } } - abstract CloseableIterator startStream(@Nullable ByteString resumeToken); + abstract CloseableIterator startStream( + @Nullable ByteString resumeToken, AsyncResultSet.StreamMessageListener streamMessageListener); /** * Prepares the iterator for a retry on a different gRPC channel. Returns true if that is @@ -220,23 +223,21 @@ public boolean isWithBeginTransaction() { return stream != null && stream.isWithBeginTransaction(); } + @Override + @InternalApi + public boolean initiateStreaming(AsyncResultSet.StreamMessageListener streamMessageListener) { + this.streamMessageListener = streamMessageListener; + startGrpcStreaming(); + return true; + } + @Override protected PartialResultSet computeNext() { int numAttemptsOnOtherChannel = 0; Context context = Context.current(); while (true) { // Eagerly start stream before consuming any buffered items. - if (stream == null) { - span.addAnnotation( - "Starting/Resuming stream", - "ResumeToken", - resumeToken == null ? "null" : resumeToken.toStringUtf8()); - try (IScope scope = tracer.withSpan(span)) { - // When start a new stream set the Span as current to make the gRPC Span a child of - // this Span. - stream = checkNotNull(startStream(resumeToken)); - } - } + startGrpcStreaming(); // Buffer contains items up to a resume token or has reached capacity: flush. if (!buffer.isEmpty() && (finished || !safeToRetry || !buffer.getLast().getResumeToken().isEmpty())) { @@ -315,6 +316,20 @@ && prepareIteratorForRetryOnDifferentGrpcChannel()) { } } + private void startGrpcStreaming() { + if (stream == null) { + span.addAnnotation( + "Starting/Resuming stream", + "ResumeToken", + resumeToken == null ? "null" : resumeToken.toStringUtf8()); + try (IScope scope = tracer.withSpan(span)) { + // When start a new stream set the Span as current to make the gRPC Span a child of + // this Span. + stream = checkNotNull(startStream(resumeToken, streamMessageListener)); + } + } + } + boolean isRetryable(SpannerException spannerException) { return spannerException.isRetryable() || retryableCodes.contains( diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/StreamingResultSet.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/StreamingResultSet.java new file mode 100644 index 0000000000..47b10d852c --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/StreamingResultSet.java @@ -0,0 +1,31 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.api.core.InternalApi; + +/** Streaming implementation of ResultSet that supports streaming of chunks */ +interface StreamingResultSet extends ResultSet { + + /** + * Returns the {@link boolean} for this {@link ResultSet}. This method will be used by + * AsyncResultSet internally to initiate gRPC streaming. This method should not be called by the + * users. + */ + @InternalApi + boolean initiateStreaming(AsyncResultSet.StreamMessageListener streamMessageListener); +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/StreamingUtil.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/StreamingUtil.java new file mode 100644 index 0000000000..54496d39f9 --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/StreamingUtil.java @@ -0,0 +1,30 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +final class StreamingUtil { + + private StreamingUtil() {} + + static boolean initiateStreaming( + ResultSet resultSet, AsyncResultSet.StreamMessageListener streamMessageListener) { + if (resultSet instanceof StreamingResultSet) { + return ((StreamingResultSet) resultSet).initiateStreaming(streamMessageListener); + } + return false; + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncResultSetImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncResultSetImplTest.java index 98497fbf14..0ba924ef74 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncResultSetImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncResultSetImplTest.java @@ -22,7 +22,9 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.when; import com.google.api.core.ApiFuture; @@ -32,6 +34,9 @@ import com.google.cloud.spanner.AsyncResultSet.ReadyCallback; import com.google.common.base.Function; import com.google.common.collect.Range; +import com.google.protobuf.ByteString; +import com.google.protobuf.Value; +import com.google.spanner.v1.PartialResultSet; import java.util.List; import java.util.concurrent.BlockingDeque; import java.util.concurrent.CountDownLatch; @@ -48,6 +53,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -381,13 +387,20 @@ public Boolean answer(InvocationOnMock invocation) throws Throwable { public void testCallbackIsNotCalledWhilePausedAndCanceled() throws InterruptedException, ExecutionException { Executor executor = Executors.newSingleThreadExecutor(); - ResultSet delegate = mock(ResultSet.class); + StreamingResultSet delegate = mock(StreamingResultSet.class); final AtomicInteger callbackCounter = new AtomicInteger(); ApiFuture callbackResult; try (AsyncResultSetImpl rs = new AsyncResultSetImpl(simpleProvider, delegate, AsyncResultSetImpl.DEFAULT_BUFFER_SIZE)) { + + when(delegate.initiateStreaming(any(AsyncResultSet.StreamMessageListener.class))) + .thenAnswer( + answer -> { + rs.onStreamMessage(PartialResultSet.newBuilder().build(), false); + return null; + }); callbackResult = rs.setCallback( executor, @@ -498,4 +511,60 @@ public void callbackReturnsDoneBeforeEnd_shouldStopIteration() throws Exception rs.getResult().get(10L, TimeUnit.SECONDS); } } + + @Test + public void testOnStreamMessageWhenResumeTokenIsPresent() { + StreamingResultSet delegate = mock(StreamingResultSet.class); + try (AsyncResultSetImpl rs = + new AsyncResultSetImpl(mockedProvider, delegate, AsyncResultSetImpl.DEFAULT_BUFFER_SIZE)) { + // Marking Streaming as supported + Mockito.when( + delegate.initiateStreaming(Mockito.any(AsyncResultSet.StreamMessageListener.class))) + .thenReturn(true); + + rs.setCallback(Executors.newSingleThreadExecutor(), ignored -> CallbackResponse.DONE); + rs.onStreamMessage( + PartialResultSet.newBuilder().addValues(Value.newBuilder().build()).build(), false); + + rs.onStreamMessage( + PartialResultSet.newBuilder().setResumeToken(ByteString.copyFromUtf8("test")).build(), + false); + Mockito.verify(mockedProvider.getExecutor(), times(2)).execute(Mockito.any()); + } + } + + @Test + public void testOnStreamMessageWhenCurrentBufferSizeReachedPrefetchChunkSize() { + StreamingResultSet delegate = mock(StreamingResultSet.class); + try (AsyncResultSetImpl rs = + new AsyncResultSetImpl(mockedProvider, delegate, AsyncResultSetImpl.DEFAULT_BUFFER_SIZE)) { + // Marking Streaming as supported + Mockito.when( + delegate.initiateStreaming(Mockito.any(AsyncResultSet.StreamMessageListener.class))) + .thenReturn(true); + + rs.setCallback(Executors.newSingleThreadExecutor(), ignored -> CallbackResponse.DONE); + rs.onStreamMessage( + PartialResultSet.newBuilder().addValues(Value.newBuilder().build()).build(), true); + Mockito.verify(mockedProvider.getExecutor(), times(2)).execute(Mockito.any()); + } + } + + @Test + public void testOnStreamMessageWhenAsyncResultIsCancelled() { + StreamingResultSet delegate = mock(StreamingResultSet.class); + try (AsyncResultSetImpl rs = + new AsyncResultSetImpl(mockedProvider, delegate, AsyncResultSetImpl.DEFAULT_BUFFER_SIZE)) { + // Marking Streaming as supported + Mockito.when( + delegate.initiateStreaming(Mockito.any(AsyncResultSet.StreamMessageListener.class))) + .thenReturn(true); + + rs.setCallback(Executors.newSingleThreadExecutor(), ignored -> CallbackResponse.DONE); + rs.cancel(); + rs.onStreamMessage( + PartialResultSet.newBuilder().addValues(Value.newBuilder().build()).build(), false); + Mockito.verify(mockedProvider.getExecutor(), times(2)).execute(Mockito.any()); + } + } } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java index d126719ebb..ebe8672467 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ResumableStreamIteratorTest.java @@ -64,7 +64,8 @@ public class ResumableStreamIteratorTest { interface Starter { AbstractResultSet.CloseableIterator startStream( - @Nullable ByteString resumeToken); + @Nullable ByteString resumeToken, + AsyncResultSet.StreamMessageListener streamMessageListener); } interface ResultSetStream { @@ -164,8 +165,9 @@ private void initWithLimit(int maxBufferSize) { SpannerStubSettings.newBuilder().executeStreamingSqlSettings().getRetryableCodes()) { @Override AbstractResultSet.CloseableIterator startStream( - @Nullable ByteString resumeToken) { - return starter.startStream(resumeToken); + @Nullable ByteString resumeToken, + AsyncResultSet.StreamMessageListener streamMessageListener) { + return starter.startStream(resumeToken, null); } }; } @@ -173,7 +175,7 @@ AbstractResultSet.CloseableIterator startStream( @Test public void simple() { ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(null, "a")) .thenReturn(resultSet(null, "b")) @@ -195,7 +197,7 @@ public void closedOTSpan() { setInternalState(ResumableStreamIterator.class, this.resumableStreamIterator, "span", span); ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenReturn(resultSet(ByteString.copyFromUtf8("r2"), "b")) @@ -218,7 +220,7 @@ public void closedOCSpan() { setInternalState(ResumableStreamIterator.class, this.resumableStreamIterator, "span", span); ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenReturn(resultSet(ByteString.copyFromUtf8("r2"), "b")) @@ -232,14 +234,14 @@ public void closedOCSpan() { @Test public void restart() { ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenReturn(resultSet(ByteString.copyFromUtf8("r2"), "b")) .thenThrow(new RetryableException(errorCodeParameter, "failed by test")); ResultSetStream s2 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(ByteString.copyFromUtf8("r2"))) + Mockito.when(starter.startStream(ByteString.copyFromUtf8("r2"), null)) .thenReturn(new ResultSetIterator(s2)); Mockito.when(s2.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r3"), "c")) @@ -251,7 +253,7 @@ public void restart() { @Test public void restartWithHoldBack() { ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenReturn(resultSet(ByteString.copyFromUtf8("r2"), "b")) @@ -260,7 +262,7 @@ public void restartWithHoldBack() { .thenThrow(new RetryableException(errorCodeParameter, "failed by test")); ResultSetStream s2 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(ByteString.copyFromUtf8("r2"))) + Mockito.when(starter.startStream(ByteString.copyFromUtf8("r2"), null)) .thenReturn(new ResultSetIterator(s2)); Mockito.when(s2.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r3"), "c")) @@ -272,7 +274,7 @@ public void restartWithHoldBack() { @Test public void restartWithHoldBackMidStream() { ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenReturn(resultSet(null, "b")) @@ -281,7 +283,7 @@ public void restartWithHoldBackMidStream() { .thenThrow(new RetryableException(errorCodeParameter, "failed by test")); ResultSetStream s2 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(ByteString.copyFromUtf8("r2"))) + Mockito.when(starter.startStream(ByteString.copyFromUtf8("r2"), null)) .thenReturn(new ResultSetIterator(s2)); Mockito.when(s2.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r3"), "e")) @@ -304,7 +306,7 @@ public void retryableErrorWithoutRetryInfo() throws IOException { ResumableStreamIterator.class, this.resumableStreamIterator, "backOff", backOff); ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenThrow( @@ -312,7 +314,7 @@ public void retryableErrorWithoutRetryInfo() throws IOException { ErrorCode.UNAVAILABLE, "failed by test", Status.UNAVAILABLE.asRuntimeException())); ResultSetStream s2 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(ByteString.copyFromUtf8("r1"))) + Mockito.when(starter.startStream(ByteString.copyFromUtf8("r1"), null)) .thenReturn(new ResultSetIterator(s2)); Mockito.when(s2.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r2"), "b")) @@ -324,7 +326,7 @@ public void retryableErrorWithoutRetryInfo() throws IOException { @Test public void nonRetryableError() { ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenReturn(resultSet(ByteString.copyFromUtf8("r2"), "b")) @@ -343,7 +345,7 @@ public void bufferLimitSimple() { initWithLimit(1); ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(null, "a")) .thenReturn(resultSet(null, "b")) @@ -356,7 +358,7 @@ public void bufferLimitSimpleWithRestartTokens() { initWithLimit(1); ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenReturn(resultSet(ByteString.copyFromUtf8("r2"), "b")) @@ -369,14 +371,14 @@ public void bufferLimitRestart() { initWithLimit(1); ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenReturn(resultSet(ByteString.copyFromUtf8("r2"), "b")) .thenThrow(new RetryableException(errorCodeParameter, "failed by test")); ResultSetStream s2 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(ByteString.copyFromUtf8("r2"))) + Mockito.when(starter.startStream(ByteString.copyFromUtf8("r2"), null)) .thenReturn(new ResultSetIterator(s2)); Mockito.when(s2.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r3"), "c")) @@ -390,13 +392,13 @@ public void bufferLimitRestartWithinLimitAtStartOfResults() { initWithLimit(1); ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(null, "XXXXXX")) .thenThrow(new RetryableException(errorCodeParameter, "failed by test")); ResultSetStream s2 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s2)); + Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s2)); Mockito.when(s2.next()) .thenReturn(resultSet(null, "a")) .thenReturn(resultSet(null, "b")) @@ -409,14 +411,14 @@ public void bufferLimitRestartWithinLimitMidResults() { initWithLimit(1); ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenReturn(resultSet(null, "XXXXXX")) .thenThrow(new RetryableException(errorCodeParameter, "failed by test")); ResultSetStream s2 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(ByteString.copyFromUtf8("r1"))) + Mockito.when(starter.startStream(ByteString.copyFromUtf8("r1"), null)) .thenReturn(new ResultSetIterator(s2)); Mockito.when(s2.next()) .thenReturn(resultSet(null, "b")) @@ -430,7 +432,7 @@ public void bufferLimitMissingTokensUnsafeToRetry() { initWithLimit(1); ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenReturn(resultSet(null, "b")) @@ -447,7 +449,7 @@ public void bufferLimitMissingTokensSafeToRetry() { initWithLimit(1); ResultSetStream s1 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(null)).thenReturn(new ResultSetIterator(s1)); + Mockito.when(starter.startStream(null, null)).thenReturn(new ResultSetIterator(s1)); Mockito.when(s1.next()) .thenReturn(resultSet(ByteString.copyFromUtf8("r1"), "a")) .thenReturn(resultSet(null, "b")) @@ -455,7 +457,7 @@ public void bufferLimitMissingTokensSafeToRetry() { .thenThrow(new RetryableException(errorCodeParameter, "failed by test")); ResultSetStream s2 = Mockito.mock(ResultSetStream.class); - Mockito.when(starter.startStream(ByteString.copyFromUtf8("r3"))) + Mockito.when(starter.startStream(ByteString.copyFromUtf8("r3"), null)) .thenReturn(new ResultSetIterator(s2)); Mockito.when(s2.next()).thenReturn(resultSet(null, "d")).thenReturn(null);