diff --git a/library/src/main/kotlin/build/buf/connect/ClientOnlyStreamInterface.kt b/library/src/main/kotlin/build/buf/connect/ClientOnlyStreamInterface.kt index 4513b786..d63b64b6 100644 --- a/library/src/main/kotlin/build/buf/connect/ClientOnlyStreamInterface.kt +++ b/library/src/main/kotlin/build/buf/connect/ClientOnlyStreamInterface.kt @@ -19,6 +19,10 @@ package build.buf.connect * eventually receives a response) that can send request messages and initiate closes. */ interface ClientOnlyStreamInterface { + /** + * + */ + suspend fun receiveAndClose(): StreamResult /** * Send a request to the server over the stream. * diff --git a/library/src/main/kotlin/build/buf/connect/http/HTTPClientInterface.kt b/library/src/main/kotlin/build/buf/connect/http/HTTPClientInterface.kt index 796b1f44..9e77cbfd 100644 --- a/library/src/main/kotlin/build/buf/connect/http/HTTPClientInterface.kt +++ b/library/src/main/kotlin/build/buf/connect/http/HTTPClientInterface.kt @@ -49,9 +49,11 @@ interface HTTPClientInterface { class Stream( private val onSend: (Buffer) -> Unit, - private val onClose: () -> Unit + private val onSendClose: () -> Unit = {}, + private val onReceiveClose: () -> Unit = {} ) { - private val isClosed = AtomicReference(false) + private val isSendClosed = AtomicReference(false) + private val isReceiveClosed = AtomicReference(false) fun send(buffer: Buffer): Result { if (isClosed()) { @@ -65,13 +67,19 @@ class Stream( } } - fun close() { - if (!isClosed.getAndSet(true)) { - onClose() + fun sendClose() { + if (!isSendClosed.getAndSet(true)) { + onSendClose() + } + } + + fun receiveClose() { + if (!isReceiveClosed.getAndSet(true)) { + onReceiveClose() } } fun isClosed(): Boolean { - return isClosed.get() + return isSendClosed.get() } } diff --git a/library/src/main/kotlin/build/buf/connect/impl/BidirectionalStream.kt b/library/src/main/kotlin/build/buf/connect/impl/BidirectionalStream.kt index e452e11c..d00e5691 100644 --- a/library/src/main/kotlin/build/buf/connect/impl/BidirectionalStream.kt +++ b/library/src/main/kotlin/build/buf/connect/impl/BidirectionalStream.kt @@ -45,7 +45,7 @@ internal class BidirectionalStream( } override fun close() { - stream.close() + stream.sendClose() } override fun isClosed(): Boolean { diff --git a/library/src/main/kotlin/build/buf/connect/impl/ClientOnlyStream.kt b/library/src/main/kotlin/build/buf/connect/impl/ClientOnlyStream.kt index 9e683c97..22489645 100644 --- a/library/src/main/kotlin/build/buf/connect/impl/ClientOnlyStream.kt +++ b/library/src/main/kotlin/build/buf/connect/impl/ClientOnlyStream.kt @@ -16,6 +16,7 @@ package build.buf.connect.impl import build.buf.connect.BidirectionalStreamInterface import build.buf.connect.ClientOnlyStreamInterface +import build.buf.connect.StreamResult /** * Concrete implementation of [ClientOnlyStreamInterface]. @@ -27,6 +28,15 @@ internal class ClientOnlyStream( return messageStream.send(input) } + override suspend fun receiveAndClose(): StreamResult { + val resultChannel = messageStream.resultChannel() + try { + return resultChannel.receive() + } finally { + resultChannel.cancel() + } + } + override fun close() { messageStream.close() } diff --git a/library/src/main/kotlin/build/buf/connect/impl/ProtocolClient.kt b/library/src/main/kotlin/build/buf/connect/impl/ProtocolClient.kt index d3e96f22..bd19800a 100644 --- a/library/src/main/kotlin/build/buf/connect/impl/ProtocolClient.kt +++ b/library/src/main/kotlin/build/buf/connect/impl/ProtocolClient.kt @@ -191,18 +191,19 @@ class ProtocolClient( channel.send(result) } continuation.invokeOnCancellation { - httpStream.close() + httpStream.sendClose() + } + val stream = Stream( + onSend = { buffer -> + httpStream.send(streamFunc.requestBodyFunction(buffer)) + } + ) + channel.invokeOnClose { + stream.receiveClose() } continuation.resume( BidirectionalStream( - Stream( - onSend = { buffer -> - httpStream.send(streamFunc.requestBodyFunction(buffer)) - }, - onClose = { - httpStream.close() - } - ), + stream, requestCodec, channel ) diff --git a/okhttp/src/main/kotlin/build/buf/connect/okhttp/OkHttpStream.kt b/okhttp/src/main/kotlin/build/buf/connect/okhttp/OkHttpStream.kt index 08eb8b62..b786d7b1 100644 --- a/okhttp/src/main/kotlin/build/buf/connect/okhttp/OkHttpStream.kt +++ b/okhttp/src/main/kotlin/build/buf/connect/okhttp/OkHttpStream.kt @@ -48,7 +48,8 @@ internal fun OkHttpClient.initializeStream( request: HTTPRequest, onResult: suspend (StreamResult) -> Unit ): Stream { - val isClosed = AtomicBoolean(false) + val isSendClosed = AtomicBoolean(false) + val isReceiveClosed = AtomicBoolean(false) val duplexRequestBody = PipeDuplexRequestBody(request.contentType.toMediaType()) val builder = Request.Builder() .url(request.url) @@ -60,23 +61,21 @@ internal fun OkHttpClient.initializeStream( } val callRequest = builder.build() val call = newCall(callRequest) - call.enqueue(ResponseCallback(onResult, isClosed)) + call.enqueue(ResponseCallback(onResult, isSendClosed)) return Stream( onSend = { buffer -> - if (!isClosed.get()) { + if (!isSendClosed.get()) { duplexRequestBody.forConsume(buffer) } }, - onClose = { - try { - isClosed.set(true) - call.cancel() - duplexRequestBody.close() - } catch (_: Throwable) { - // No-op - } + onSendClose = { + isSendClosed.set(true) + duplexRequestBody.close() } - ) + ) { + isReceiveClosed.set(true) + call.cancel() + } } private class ResponseCallback(