diff --git a/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/TracingHTTPClient.kt b/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/TracingHTTPClient.kt index f20acd70..a1824313 100644 --- a/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/TracingHTTPClient.kt +++ b/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/TracingHTTPClient.kt @@ -97,12 +97,12 @@ internal class TracingHTTPClient( return res } - override fun sendClose() { + override suspend fun sendClose() { printer.printlnWithStackTrace("Half-closing stream") delegate.sendClose() } - override fun receiveClose() { + override suspend fun receiveClose() { printer.printlnWithStackTrace("Closing stream") delegate.receiveClose() } diff --git a/library/src/main/kotlin/com/connectrpc/BidirectionalStreamInterface.kt b/library/src/main/kotlin/com/connectrpc/BidirectionalStreamInterface.kt index 87778622..7703aada 100644 --- a/library/src/main/kotlin/com/connectrpc/BidirectionalStreamInterface.kt +++ b/library/src/main/kotlin/com/connectrpc/BidirectionalStreamInterface.kt @@ -67,12 +67,12 @@ interface BidirectionalStreamInterface { /** * Close the send stream. No calls to [send] are valid after calling [sendClose]. */ - fun sendClose() + suspend fun sendClose() /** * Close the receive stream. */ - fun receiveClose() + suspend fun receiveClose() /** * Determine if the underlying client send stream is closed. diff --git a/library/src/main/kotlin/com/connectrpc/ClientOnlyStreamInterface.kt b/library/src/main/kotlin/com/connectrpc/ClientOnlyStreamInterface.kt index 9fe150a5..7975d0e2 100644 --- a/library/src/main/kotlin/com/connectrpc/ClientOnlyStreamInterface.kt +++ b/library/src/main/kotlin/com/connectrpc/ClientOnlyStreamInterface.kt @@ -57,13 +57,13 @@ interface ClientOnlyStreamInterface { /** * Close the stream. No calls to [send] are valid after calling [sendClose]. */ - fun sendClose() + suspend fun sendClose() /** * Cancels the stream. This closes both send and receive sides of the stream * without awaiting any server reply. */ - fun cancel() + suspend fun cancel() /** * Determine if the underlying client send stream is closed. diff --git a/library/src/main/kotlin/com/connectrpc/ProtocolClientConfig.kt b/library/src/main/kotlin/com/connectrpc/ProtocolClientConfig.kt index ae38bca6..84b6836c 100644 --- a/library/src/main/kotlin/com/connectrpc/ProtocolClientConfig.kt +++ b/library/src/main/kotlin/com/connectrpc/ProtocolClientConfig.kt @@ -22,6 +22,7 @@ import com.connectrpc.protocols.GRPCInterceptor import com.connectrpc.protocols.GRPCWebInterceptor import com.connectrpc.protocols.NetworkProtocol import java.net.URI +import kotlin.coroutines.CoroutineContext /** * Set of configuration used to set up clients. @@ -45,6 +46,14 @@ class ProtocolClientConfig @JvmOverloads constructor( // Compression pools that provide support for the provided `compressionName`, as well as any // other compression methods that need to be supported for inbound responses. compressionPools: List = listOf(GzipCompressionPool), + // The coroutine context to use for I/O, such as sending RPC messages. + // If null, the current/calling coroutine context is used. So the caller + // may need explicitly dispatch send calls using contexts where I/O is + // appropriate (using the withContext extension function). If non-null + // (such as Dispatchers.IO), operations that involve I/O pr other + // blocking will automatically be dispatched using the biven context, + // so the caller does not need to worry about it. + val ioCoroutineContext: CoroutineContext? = null, ) { private val internalInterceptorFactoryList = mutableListOf<(ProtocolClientConfig) -> Interceptor>() private val compressionPools = mutableMapOf() diff --git a/library/src/main/kotlin/com/connectrpc/ServerOnlyStreamInterface.kt b/library/src/main/kotlin/com/connectrpc/ServerOnlyStreamInterface.kt index f7adbfae..abf9b5b7 100644 --- a/library/src/main/kotlin/com/connectrpc/ServerOnlyStreamInterface.kt +++ b/library/src/main/kotlin/com/connectrpc/ServerOnlyStreamInterface.kt @@ -62,7 +62,7 @@ interface ServerOnlyStreamInterface { /** * Close the receive stream. */ - fun receiveClose() + suspend fun receiveClose() /** * Determine if the underlying client receive stream is closed. diff --git a/library/src/main/kotlin/com/connectrpc/http/HTTPClientInterface.kt b/library/src/main/kotlin/com/connectrpc/http/HTTPClientInterface.kt index 0ae96877..2424b86f 100644 --- a/library/src/main/kotlin/com/connectrpc/http/HTTPClientInterface.kt +++ b/library/src/main/kotlin/com/connectrpc/http/HTTPClientInterface.kt @@ -16,7 +16,6 @@ package com.connectrpc.http import com.connectrpc.StreamResult import okio.Buffer -import java.util.concurrent.atomic.AtomicBoolean typealias Cancelable = () -> Unit @@ -46,92 +45,3 @@ interface HTTPClientInterface { */ fun stream(request: HTTPRequest, duplex: Boolean, onResult: suspend (StreamResult) -> Unit): Stream } - -interface Stream { - suspend fun send(buffer: Buffer): Result - - fun sendClose() - - fun receiveClose() - - fun isSendClosed(): Boolean - - fun isReceiveClosed(): Boolean -} - -fun Stream( - onSend: suspend (Buffer) -> Result, - onSendClose: () -> Unit = {}, - onReceiveClose: () -> Unit = {}, -): Stream { - val isSendClosed = AtomicBoolean() - val isReceiveClosed = AtomicBoolean() - return object : Stream { - override suspend fun send(buffer: Buffer): Result { - if (isSendClosed()) { - return Result.failure(IllegalStateException("cannot send. underlying stream is closed")) - } - return try { - onSend(buffer) - } catch (e: Throwable) { - Result.failure(e) - } - } - - override fun sendClose() { - if (isSendClosed.compareAndSet(false, true)) { - onSendClose() - } - } - - override fun receiveClose() { - if (isReceiveClosed.compareAndSet(false, true)) { - try { - onReceiveClose() - } finally { - // When receive side is closed, the send side is - // implicitly closed as well. - // We don't use sendClose() because we don't want to - // invoke onSendClose() since that will try to actually - // half-close the HTTP stream, which will fail since - // closing the receive side cancels the entire thing. - isSendClosed.set(true) - } - } - } - - override fun isSendClosed(): Boolean { - return isSendClosed.get() - } - - override fun isReceiveClosed(): Boolean { - return isReceiveClosed.get() - } - } -} - -/** - * Returns a new stream that applies the given function to each - * buffer when send is called. The result of that function is - * what is passed along to the original stream. - */ -fun Stream.transform(apply: (Buffer) -> Buffer): Stream { - val delegate = this - return object : Stream { - override suspend fun send(buffer: Buffer): Result { - return delegate.send(apply(buffer)) - } - override fun sendClose() { - delegate.sendClose() - } - override fun receiveClose() { - delegate.receiveClose() - } - override fun isSendClosed(): Boolean { - return delegate.isSendClosed() - } - override fun isReceiveClosed(): Boolean { - return delegate.isReceiveClosed() - } - } -} diff --git a/library/src/main/kotlin/com/connectrpc/http/Stream.kt b/library/src/main/kotlin/com/connectrpc/http/Stream.kt new file mode 100644 index 00000000..f8a92b9a --- /dev/null +++ b/library/src/main/kotlin/com/connectrpc/http/Stream.kt @@ -0,0 +1,145 @@ +// Copyright 2022-2023 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.connectrpc.http + +import kotlinx.coroutines.withContext +import okio.Buffer +import java.util.concurrent.atomic.AtomicBoolean +import kotlin.coroutines.CoroutineContext + +/** + * Stream represents the communications for a single streaming RPC. + * It can be used to send messages and to close the stream. Receiving + * messages is done via callbacks provided when the stream is created. + * + * See HTTPClientInterface#stream. + */ +interface Stream { + suspend fun send(buffer: Buffer): Result + + suspend fun sendClose() + + suspend fun receiveClose() + + fun isSendClosed(): Boolean + + fun isReceiveClosed(): Boolean +} + +/** + * Creates a new stream whose implementation of sending and + * closing is delegated to the given lambdas. + */ +fun Stream( + onSend: suspend (Buffer) -> Result, + onSendClose: suspend () -> Unit = {}, + onReceiveClose: suspend () -> Unit = {}, +): Stream { + val isSendClosed = AtomicBoolean() + val isReceiveClosed = AtomicBoolean() + return object : Stream { + override suspend fun send(buffer: Buffer): Result { + if (isSendClosed()) { + return Result.failure(IllegalStateException("cannot send. underlying stream is closed")) + } + return try { + onSend(buffer) + } catch (e: Throwable) { + Result.failure(e) + } + } + + override suspend fun sendClose() { + if (isSendClosed.compareAndSet(false, true)) { + onSendClose() + } + } + + override suspend fun receiveClose() { + if (isReceiveClosed.compareAndSet(false, true)) { + try { + onReceiveClose() + } finally { + // When receive side is closed, the send side is + // implicitly closed as well. + // We don't use sendClose() because we don't want to + // invoke onSendClose() since that will try to actually + // half-close the HTTP stream, which will fail since + // closing the receive side cancels the entire thing. + isSendClosed.set(true) + } + } + } + + override fun isSendClosed(): Boolean { + return isSendClosed.get() + } + + override fun isReceiveClosed(): Boolean { + return isReceiveClosed.get() + } + } +} + +/** + * Returns a new stream that applies the given function to each + * buffer when send is called. The result of that function is + * what is passed along to the original stream. + */ +fun Stream.transform(apply: (Buffer) -> Buffer): Stream { + val delegate = this + return object : Stream { + override suspend fun send(buffer: Buffer): Result { + return delegate.send(apply(buffer)) + } + override suspend fun sendClose() { + delegate.sendClose() + } + override suspend fun receiveClose() { + delegate.receiveClose() + } + override fun isSendClosed(): Boolean { + return delegate.isSendClosed() + } + override fun isReceiveClosed(): Boolean { + return delegate.isReceiveClosed() + } + } +} + +/** + * Returns a new stream that dispatches suspending operations + * (sending and closing) using the given coroutine context. + */ +fun Stream.dispatchIn(context: CoroutineContext): Stream { + val delegate = this + return object : Stream { + override suspend fun send(buffer: Buffer): Result = withContext(context) { + delegate.send(buffer) + } + override suspend fun sendClose() = withContext(context) { + delegate.sendClose() + } + override suspend fun receiveClose() = withContext(context) { + delegate.receiveClose() + } + override fun isSendClosed(): Boolean { + return delegate.isSendClosed() + } + override fun isReceiveClosed(): Boolean { + return delegate.isReceiveClosed() + } + } +} diff --git a/library/src/main/kotlin/com/connectrpc/impl/BidirectionalStream.kt b/library/src/main/kotlin/com/connectrpc/impl/BidirectionalStream.kt index df2d01b8..45b56eba 100644 --- a/library/src/main/kotlin/com/connectrpc/impl/BidirectionalStream.kt +++ b/library/src/main/kotlin/com/connectrpc/impl/BidirectionalStream.kt @@ -59,11 +59,11 @@ internal class BidirectionalStream( return stream.isReceiveClosed() } - override fun sendClose() { + override suspend fun sendClose() { stream.sendClose() } - override fun receiveClose() { + override suspend fun receiveClose() { stream.receiveClose() } diff --git a/library/src/main/kotlin/com/connectrpc/impl/ClientOnlyStream.kt b/library/src/main/kotlin/com/connectrpc/impl/ClientOnlyStream.kt index 583793f6..0edbf07e 100644 --- a/library/src/main/kotlin/com/connectrpc/impl/ClientOnlyStream.kt +++ b/library/src/main/kotlin/com/connectrpc/impl/ClientOnlyStream.kt @@ -54,11 +54,11 @@ internal class ClientOnlyStream( return messageStream.responseTrailers() } - override fun sendClose() { + override suspend fun sendClose() { return messageStream.sendClose() } - override fun cancel() { + override suspend fun cancel() { return messageStream.receiveClose() } diff --git a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt index 07807730..909b5be8 100644 --- a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt +++ b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt @@ -31,11 +31,14 @@ import com.connectrpc.http.Cancelable import com.connectrpc.http.HTTPClientInterface import com.connectrpc.http.HTTPRequest import com.connectrpc.http.UnaryHTTPRequest +import com.connectrpc.http.dispatchIn import com.connectrpc.http.transform import com.connectrpc.protocols.GETConfiguration import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.runBlocking import kotlinx.coroutines.suspendCancellableCoroutine +import kotlinx.coroutines.withContext import java.net.URI import java.util.concurrent.CountDownLatch import kotlin.coroutines.resume @@ -138,6 +141,19 @@ class ProtocolClient( request: Input, headers: Headers, methodSpec: MethodSpec, + ): ResponseMessage { + if (config.ioCoroutineContext != null) { + return withContext(config.ioCoroutineContext) { + suspendUnary(request, headers, methodSpec) + } + } + return suspendUnary(request, headers, methodSpec) + } + + private suspend fun suspendUnary( + request: Input, + headers: Headers, + methodSpec: MethodSpec, ): ResponseMessage { return suspendCancellableCoroutine { continuation -> val cancelable = unary(request, headers, methodSpec) { responseMessage -> @@ -168,18 +184,11 @@ class ProtocolClient( return call } - override suspend fun stream( - headers: Headers, - methodSpec: MethodSpec, - ): BidirectionalStreamInterface { - return bidirectionalStream(methodSpec, headers) - } - override suspend fun serverStream( headers: Headers, methodSpec: MethodSpec, ): ServerOnlyStreamInterface { - val stream = bidirectionalStream(methodSpec, headers) + val stream = stream(headers, methodSpec) return ServerOnlyStream(stream) } @@ -191,10 +200,10 @@ class ProtocolClient( return ClientOnlyStream(stream) } - private suspend fun bidirectionalStream( - methodSpec: MethodSpec, + override suspend fun stream( headers: Headers, - ): BidirectionalStream = suspendCancellableCoroutine { continuation -> + methodSpec: MethodSpec, + ): BidirectionalStreamInterface { val channel = Channel(1) val responseHeaders = CompletableDeferred() val responseTrailers = CompletableDeferred() @@ -209,10 +218,13 @@ class ProtocolClient( val streamFunc = config.createStreamingInterceptorChain() val finalRequest = streamFunc.requestFunction(request) var isComplete = false - val httpStream = httpClient.stream(finalRequest, methodSpec.streamType == StreamType.BIDI) { initialResult -> + val httpStream = httpClient.stream( + finalRequest, + methodSpec.streamType == StreamType.BIDI, + ) httpStream@{ initialResult -> if (isComplete) { // No-op on remaining handlers after a completion. - return@stream + return@httpStream } // Pass through the interceptor chain. when (val streamResult = streamFunc.streamResultFunction(initialResult)) { @@ -257,22 +269,27 @@ class ProtocolClient( } } } - continuation.invokeOnCancellation { - httpStream.receiveClose() - } - val stream = httpStream.transform { streamFunc.requestBodyFunction(it) } - channel.invokeOnClose { - stream.receiveClose() - } - continuation.resume( - BidirectionalStream( + try { + channel.invokeOnClose { + runBlocking { httpStream.receiveClose() } + } + var stream = httpStream.transform { streamFunc.requestBodyFunction(it) } + if (config.ioCoroutineContext != null) { + stream = stream.dispatchIn(config.ioCoroutineContext) + } + return BidirectionalStream( stream, requestCodec, channel, responseHeaders, responseTrailers, - ), - ) + ) + } catch (ex: Throwable) { + // If something in these last steps prevents us + // from returning, don't leak the stream. + httpStream.receiveClose() + throw ex + } } private fun urlFromMethodSpec(methodSpec: MethodSpec<*, *>) = baseURIWithTrailingSlash.resolve(methodSpec.path).toURL() diff --git a/library/src/main/kotlin/com/connectrpc/impl/ServerOnlyStream.kt b/library/src/main/kotlin/com/connectrpc/impl/ServerOnlyStream.kt index d1fa9b3e..e04eb38c 100644 --- a/library/src/main/kotlin/com/connectrpc/impl/ServerOnlyStream.kt +++ b/library/src/main/kotlin/com/connectrpc/impl/ServerOnlyStream.kt @@ -46,7 +46,7 @@ internal class ServerOnlyStream( } } - override fun receiveClose() { + override suspend fun receiveClose() { messageStream.receiveClose() }