diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java index b645b75e056a..d0436591515c 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java @@ -25,6 +25,7 @@ import com.google.api.client.util.BackOff; import com.google.api.client.util.ExponentialBackOff; import com.google.api.gax.paging.Page; +import com.google.api.gax.rpc.ServerStream; import com.google.api.pathtemplate.PathTemplate; import com.google.cloud.BaseService; import com.google.cloud.ByteArray; @@ -828,7 +829,7 @@ public ReadContext singleUse() { @Override public ReadContext singleUse(TimestampBound bound) { - return setActive(new SingleReadContext(this, bound, rawGrpcRpc, defaultPrefetchChunks)); + return setActive(new SingleReadContext(this, bound, gapicRpc, defaultPrefetchChunks)); } @Override @@ -839,7 +840,7 @@ public ReadOnlyTransaction singleUseReadOnlyTransaction() { @Override public ReadOnlyTransaction singleUseReadOnlyTransaction(TimestampBound bound) { return setActive( - new SingleUseReadOnlyTransaction(this, bound, rawGrpcRpc, defaultPrefetchChunks)); + new SingleUseReadOnlyTransaction(this, bound, gapicRpc, defaultPrefetchChunks)); } @Override @@ -850,12 +851,12 @@ public ReadOnlyTransaction readOnlyTransaction() { @Override public ReadOnlyTransaction readOnlyTransaction(TimestampBound bound) { return setActive( - new MultiUseReadOnlyTransaction(this, bound, rawGrpcRpc, defaultPrefetchChunks)); + new MultiUseReadOnlyTransaction(this, bound, gapicRpc, defaultPrefetchChunks)); } @Override public TransactionRunner readWriteTransaction() { - return setActive(new TransactionRunnerImpl(this, rawGrpcRpc, defaultPrefetchChunks)); + return setActive(new TransactionRunnerImpl(this, gapicRpc, defaultPrefetchChunks)); } @Override @@ -1055,20 +1056,14 @@ ResultSet executeQueryInternalWithOptions( new ResumableStreamIterator(MAX_BUFFERED_CHUNKS, QUERY) { @Override CloseableIterator startStream(@Nullable ByteString resumeToken) { - GrpcStreamIterator stream = new GrpcStreamIterator(prefetchChunks); - SpannerRpc.StreamingCall call = - rpc.executeQuery( - resumeToken == null - ? request - : request.toBuilder().setResumeToken(resumeToken).build(), - stream.consumer(), - session.options); - // We get one message for free. - if (prefetchChunks > 1) { - call.request(prefetchChunks - 1); - } - stream.setCall(call); - return stream; + return new CloseableServerStreamIterator(rpc.executeQuery( + resumeToken == null + ? request + : request.toBuilder().setResumeToken(resumeToken).build(), + null, + session.options)); + + // let resume fail for now } }; return new GrpcResultSet(stream, this, queryMode); @@ -1168,20 +1163,14 @@ ResultSet readInternalWithOptions( new ResumableStreamIterator(MAX_BUFFERED_CHUNKS, READ) { @Override CloseableIterator startStream(@Nullable ByteString resumeToken) { - GrpcStreamIterator stream = new GrpcStreamIterator(prefetchChunks); - SpannerRpc.StreamingCall call = - rpc.read( - resumeToken == null - ? request - : request.toBuilder().setResumeToken(resumeToken).build(), - stream.consumer(), - session.options); - // We get one message for free. - if (prefetchChunks > 1) { - call.request(prefetchChunks - 1); - } - stream.setCall(call); - return stream; + return new CloseableServerStreamIterator(rpc.read( + resumeToken == null + ? request + : request.toBuilder().setResumeToken(resumeToken).build(), + null, + session.options)); + + // let resume fail for now } }; GrpcResultSet resultSet = @@ -2287,6 +2276,32 @@ interface CloseableIterator extends Iterator { void close(@Nullable String message); } + private static final class CloseableServerStreamIterator implements CloseableIterator { + + private final ServerStream stream; + private final Iterator iterator; + + public CloseableServerStreamIterator(ServerStream stream) { + this.stream = stream; + this.iterator = stream.iterator(); + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public T next() { + return iterator.next(); + } + + @Override + public void close(@Nullable String message) { + stream.cancel(); + } + } + /** Adapts a streaming read/query call into an iterator over partial result sets. */ @VisibleForTesting static class GrpcStreamIterator extends AbstractIterator diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java index 0872a84f08d9..902650a8d56b 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java @@ -21,10 +21,12 @@ import com.google.api.gax.core.CredentialsProvider; import com.google.api.gax.core.GaxProperties; import com.google.api.gax.grpc.GaxGrpcProperties; +import com.google.api.gax.grpc.GrpcCallContext; import com.google.api.gax.grpc.GrpcTransportChannel; import com.google.api.gax.rpc.ApiClientHeaderProvider; import com.google.api.gax.rpc.FixedTransportChannelProvider; import com.google.api.gax.rpc.HeaderProvider; +import com.google.api.gax.rpc.ServerStream; import com.google.api.gax.rpc.TransportChannelProvider; import com.google.api.pathtemplate.PathTemplate; import com.google.cloud.ServiceOptions; @@ -72,6 +74,7 @@ import com.google.spanner.v1.PartitionQueryRequest; import com.google.spanner.v1.PartitionReadRequest; import com.google.spanner.v1.PartitionResponse; +import com.google.spanner.v1.PartialResultSet; import com.google.spanner.v1.ReadRequest; import com.google.spanner.v1.RollbackRequest; import com.google.spanner.v1.Session; @@ -335,15 +338,19 @@ public void deleteSession(String sessionName, @Nullable Map options) } @Override - public StreamingCall read( + public ServerStream read( ReadRequest request, ResultStreamConsumer consumer, @Nullable Map options) { - throw new UnsupportedOperationException("Not implemented yet."); + GrpcCallContext context = GrpcCallContext.createDefault() + .withChannelAffinity(Option.CHANNEL_HINT.getLong(options).intValue()); + return stub.streamingReadCallable().call(request, context); } @Override - public StreamingCall executeQuery( + public ServerStream executeQuery( ExecuteSqlRequest request, ResultStreamConsumer consumer, @Nullable Map options) { - throw new UnsupportedOperationException("Not implemented yet."); + GrpcCallContext context = GrpcCallContext.createDefault() + .withChannelAffinity(Option.CHANNEL_HINT.getLong(options).intValue()); + return stub.executeStreamingSqlCallable().call(request, context); } @Override diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcSpannerRpc.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcSpannerRpc.java index d6b3e132e505..e91b140ef4a0 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcSpannerRpc.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GrpcSpannerRpc.java @@ -22,6 +22,7 @@ import com.google.api.gax.grpc.GaxGrpcProperties; import com.google.api.gax.rpc.ApiClientHeaderProvider; import com.google.api.gax.rpc.HeaderProvider; +import com.google.api.gax.rpc.ServerStream; import com.google.api.pathtemplate.PathTemplate; import com.google.cloud.NoCredentials; import com.google.cloud.ServiceOptions; @@ -366,25 +367,15 @@ public void deleteSession(String sessionName, @Nullable Map options) } @Override - public StreamingCall read( + public ServerStream read( ReadRequest request, ResultStreamConsumer consumer, @Nullable Map options) { - return doStreamingCall( - SpannerGrpc.METHOD_STREAMING_READ, - request, - consumer, - request.getSession(), - Option.CHANNEL_HINT.getLong(options)); + throw new UnsupportedOperationException("Not implemented: read"); } @Override - public StreamingCall executeQuery( + public ServerStream executeQuery( ExecuteSqlRequest request, ResultStreamConsumer consumer, @Nullable Map options) { - return doStreamingCall( - SpannerGrpc.METHOD_EXECUTE_STREAMING_SQL, - request, - consumer, - request.getSession(), - Option.CHANNEL_HINT.getLong(options)); + throw new UnsupportedOperationException("Not implemented: executeQuery"); } @Override diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java index 9d7066c5e55c..2c0ba3be4cc1 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java @@ -16,6 +16,7 @@ package com.google.cloud.spanner.spi.v1; +import com.google.api.gax.rpc.ServerStream; import com.google.cloud.ServiceRpc; import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.spi.v1.SpannerRpc.Option; @@ -197,10 +198,10 @@ Session createSession(String databaseName, @Nullable Map labels, void deleteSession(String sessionName, @Nullable Map options) throws SpannerException; - StreamingCall read( + ServerStream read( ReadRequest request, ResultStreamConsumer consumer, @Nullable Map options); - StreamingCall executeQuery( + ServerStream executeQuery( ExecuteSqlRequest request, ResultStreamConsumer consumer, @Nullable Map options); Transaction beginTransaction(BeginTransactionRequest request, @Nullable Map options) diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ServerStreamingStashCallable.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ServerStreamingStashCallable.java new file mode 100644 index 000000000000..026d10f95cf0 --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ServerStreamingStashCallable.java @@ -0,0 +1,116 @@ +/* + * Copyright 2018 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.api.gax.rpc.ApiCallContext; +import com.google.api.gax.rpc.ResponseObserver; +import com.google.api.gax.rpc.ServerStream; +import com.google.api.gax.rpc.ServerStreamingCallable; +import com.google.api.gax.rpc.StreamController; +import com.google.common.base.Preconditions; +import com.google.common.collect.Queues; +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.CancellationException; + +public class ServerStreamingStashCallable + extends ServerStreamingCallable { + private List responseList; + + public ServerStreamingStashCallable() { + responseList = new ArrayList<>(); + } + + public ServerStreamingStashCallable(List responseList) { + this.responseList = responseList; + } + + @Override + public void call( + RequestT request, ResponseObserver responseObserver, ApiCallContext context) { + Preconditions.checkNotNull(responseObserver); + + StreamControllerStash controller = + new StreamControllerStash<>(responseList, responseObserver); + controller.start(); + } + + // Minimal implementation of back pressure aware stream controller. Not threadsafe + private static class StreamControllerStash implements StreamController { + final ResponseObserver observer; + final Queue queue; + boolean autoFlowControl = true; + long numPending; + Throwable error; + boolean delivering, closed; + + public StreamControllerStash( + List responseList, ResponseObserver observer) { + this.observer = observer; + this.queue = Queues.newArrayDeque(responseList); + } + + public void start() { + observer.onStart(this); + if (autoFlowControl) { + numPending = Integer.MAX_VALUE; + } + deliver(); + } + + @Override + public void disableAutoInboundFlowControl() { + autoFlowControl = false; + } + + @Override + public void request(int count) { + numPending += count; + deliver(); + } + + @Override + public void cancel() { + error = new CancellationException("User cancelled stream"); + deliver(); + } + + private void deliver() { + if (delivering || closed) return; + delivering = true; + + try { + while (error == null && numPending > 0 && !queue.isEmpty()) { + numPending--; + observer.onResponse(queue.poll()); + } + + if (error != null || queue.isEmpty()) { + if (error != null) { + observer.onError(error); + } else { + observer.onComplete(); + } + closed = true; + } + } finally { + delivering = false; + } + } + } +} \ No newline at end of file diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java index 08e1bf1e14f3..e0cee1bcb506 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java @@ -19,6 +19,8 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import com.google.api.gax.rpc.ServerStream; +import com.google.api.gax.rpc.ServerStreamingCallable; import com.google.cloud.Timestamp; import com.google.cloud.spanner.spi.v1.SpannerRpc; import com.google.protobuf.ByteString; @@ -280,18 +282,15 @@ public void request(int numMessages) {} } private void mockRead(final PartialResultSet myResultSet) { - final ArgumentCaptor consumer = - ArgumentCaptor.forClass(SpannerRpc.ResultStreamConsumer.class); - Mockito.when(rpc.read(Mockito.any(), consumer.capture(), Mockito.eq(options))) - .then( - new Answer() { - @Override - public SpannerRpc.StreamingCall answer(InvocationOnMock invocation) throws Throwable { - consumer.getValue().onPartialResultSet(myResultSet); - consumer.getValue().onCompleted(); - return new NoOpStreamingCall(); - } - }); + ServerStreamingCallable serverStreamingCallable = + new ServerStreamingStashCallable(Arrays.asList(myResultSet)); + final ServerStream mockServerStream = serverStreamingCallable.call(null); + Mockito.when( + rpc.read( + Mockito.any(), + Mockito.any(), + Mockito.eq(options))) + .thenReturn(mockServerStream); } @Test