Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add timeout enforcement to ProtocolClient #276

Merged
merged 4 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions library/src/main/kotlin/com/connectrpc/ProtocolClientConfig.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,42 @@ package com.connectrpc

import com.connectrpc.compression.CompressionPool
import com.connectrpc.compression.GzipCompressionPool
import com.connectrpc.http.Timeout
import com.connectrpc.protocols.ConnectInterceptor
import com.connectrpc.protocols.GETConfiguration
import com.connectrpc.protocols.GRPCInterceptor
import com.connectrpc.protocols.GRPCWebInterceptor
import com.connectrpc.protocols.NetworkProtocol
import java.net.URI
import kotlin.coroutines.CoroutineContext
import kotlin.time.Duration
import kotlin.time.DurationUnit
import kotlin.time.toDuration

typealias TimeoutOracle = (MethodSpec<*, *>) -> Duration?

/**
* Returns an oracle that provides the given timeouts for unary or stream
* operations, respectively.
*/
fun simpleTimeouts(unaryTimeout: Duration?, streamTimeout: Duration?): TimeoutOracle {
return { methodSpec ->
when (methodSpec.streamType) {
StreamType.UNARY -> unaryTimeout
else -> streamTimeout
}
}
}

/**
* Set of configuration used to set up clients.
*/
class ProtocolClientConfig @JvmOverloads constructor(
// TODO: Use a block-based construction pattern instead of JvmOverloads
// so we can add new fields in the future without having to worry
// about their ordering or potentially breaking compatibility with
// already-compiled byte code.

// The host (e.g., https://connectrpc.com).
val host: String,
// The client to use for performing requests.
Expand All @@ -54,6 +78,17 @@ class ProtocolClientConfig @JvmOverloads constructor(
// blocking will automatically be dispatched using the given context,
// so the caller does not need to worry about it.
val ioCoroutineContext: CoroutineContext? = null,
// A function that is consulted to determine timeouts for each RPC. If
// the function returns null, no timeout is applied. If a non-null value
// is returned, the entire call must complete before it elapses. If the
// call is still active at the end of the timeout period, it is cancelled
// and will result in an exception with a Code.DEADLINE_EXCEEDED code.
//
// The default oracle, if not configured, returns a 10 second timeout for
// all operations.
val timeoutOracle: TimeoutOracle = { 10.toDuration(DurationUnit.SECONDS) },
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this to account for removal of the timeouts in the configuration helper below. So if users use ConnectOkHttpClient.configureClient and forget to set timeouts in this config, they get the same default behavior as the OkHttpClient was providing (except that this default applies to bidirectional streams, whereas the OkHttpClient timeouts do not).

// Schedules timeout actions.
val timeoutScheduler: Timeout.Scheduler = Timeout.DEFAULT_SCHEDULER,
) {
private val internalInterceptorFactoryList = mutableListOf<(ProtocolClientConfig) -> Interceptor>()
private val compressionPools = mutableMapOf<String, CompressionPool>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package com.connectrpc.http
import com.connectrpc.StreamResult
import okio.Buffer

/** A function that cancels an operation when called. */
typealias Cancelable = () -> Unit

/**
Expand Down
13 changes: 12 additions & 1 deletion library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import com.connectrpc.Headers
import com.connectrpc.MethodSpec
import okio.Buffer
import java.net.URL
import kotlin.time.Duration

enum class HTTPMethod(
val string: String,
Expand All @@ -34,6 +35,8 @@ open class HTTPRequest internal constructor(
val url: URL,
// Value to assign to the `content-type` header.
val contentType: String,
// The optional timeout for this request.
val timeout: Duration?,
// Additional outbound headers for the request.
val headers: Headers,
// The method spec associated with the request.
Expand All @@ -51,6 +54,8 @@ fun HTTPRequest.clone(
url: URL = this.url,
// Value to assign to the `content-type` header.
contentType: String = this.contentType,
// The optional timeout for this request.
timeout: Duration? = this.timeout,
// Additional outbound headers for the request.
headers: Headers = this.headers,
// The method spec associated with the request.
Expand All @@ -59,6 +64,7 @@ fun HTTPRequest.clone(
return HTTPRequest(
url,
contentType,
timeout,
headers,
methodSpec,
)
Expand All @@ -73,6 +79,8 @@ class UnaryHTTPRequest(
url: URL,
// Value to assign to the `content-type` header.
contentType: String,
// The optional timeout for this request.
timeout: Duration?,
// Additional outbound headers for the request.
headers: Headers,
// The method spec associated with the request.
Expand All @@ -82,13 +90,15 @@ class UnaryHTTPRequest(
// HTTP method to use with the request.
// Almost always POST, but side effect free unary RPCs may be made with GET.
val httpMethod: HTTPMethod = HTTPMethod.POST,
) : HTTPRequest(url, contentType, headers, methodSpec)
) : HTTPRequest(url, contentType, timeout, headers, methodSpec)

fun UnaryHTTPRequest.clone(
// The URL for the request.
url: URL = this.url,
// Value to assign to the `content-type` header.
contentType: String = this.contentType,
// The optional timeout for this request.
timeout: Duration? = this.timeout,
// Additional outbound headers for the request.
headers: Headers = this.headers,
// The method spec associated with the request.
Expand All @@ -101,6 +111,7 @@ fun UnaryHTTPRequest.clone(
return UnaryHTTPRequest(
url,
contentType,
timeout,
headers,
methodSpec,
message,
Expand Down
86 changes: 86 additions & 0 deletions library/src/main/kotlin/com/connectrpc/http/Timeout.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright 2022-2023 The Connect Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.connectrpc.http

import kotlinx.coroutines.delay
import java.util.Timer
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.concurrent.timerTask
import kotlin.time.Duration

/**
* Represents the timeout state for an RPC.
*/
class Timeout private constructor(
private val timeoutAction: Cancelable,
) {
private val done = AtomicBoolean(false)

@Volatile private var triggered: Boolean = false
private var onCancel: Cancelable? = null

/** Returns true if this timeout has lapsed and the associated RPC canceled. */
val timedOut: Boolean
get() = triggered

/**
* Cancels the timeout. Should only be called when the RPC completes before the
* timeout elapses. Returns true if the timeout was canceled or false if either
* it was already previously canceled or has already timed out. The `timedOut`
* property can be queried to distinguish between these two possibilities.
*/
fun cancel(): Boolean {
if (done.compareAndSet(false, true)) {
onCancel?.invoke()
return true
}
return false
}

private fun trigger() {
if (done.compareAndSet(false, true)) {
triggered = true
timeoutAction()
}
}

/** Schedules timeouts for RPCs. */
interface Scheduler {
/**
* Schedules a timeout that should invoke the given action to cancel
* an RPC after the given delay.
*/

fun scheduleTimeout(delay: Duration, action: Cancelable): Timeout
}

companion object {
/**
* A default implementation that a Timer backed by a single daemon thread.
* The thread isn't started until the first cancelation is scheduled.
*/
val DEFAULT_SCHEDULER = object : Scheduler {
override fun scheduleTimeout(delay: Duration, action: Cancelable): Timeout {
val timeout = Timeout(action)
val task = timerTask { timeout.trigger() }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works but we might consider tradeoffs vs. ScheduledThreadPoolExecutor. Probably not a big deal for timeout scheduling.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can change it. When I was searching, trying to figure out the idiomatic way to do this in Kotlin and Android apps, this was cited more than use of a ScheduledExecutorService. There were also a couple of Android-specific ways to do it, and then there was the coroutine way (that I couldn't get to work correctly -- to just create a coroutine and then delay(...) in the coroutine before executing the action).

For this, there's not really much of a tradeoff -- both solutions are roughly equivalent. Both approaches require creation of a heavyweight thread. Both implementations are similar, using a thread that polls a priority queue and then executes the tasks. The only potentially meaningful difference is that ScheduledExecutorService impls use a DelayQueue and the stuff in java.util.concurrent whereas the Timer just uses intrinsic locks and Object.wait and Object.notify.

But it's super easy to switch if you think that's more appropriate.

timer.value.schedule(task, delay.inWholeMilliseconds)
timeout.onCancel = { task.cancel() }
return timeout
}
}

private val timer = lazy { Timer(Scheduler::class.qualifiedName, true) }
}
}
72 changes: 65 additions & 7 deletions library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import com.connectrpc.http.Cancelable
import com.connectrpc.http.HTTPClientInterface
import com.connectrpc.http.HTTPRequest
import com.connectrpc.http.HTTPResponse
import com.connectrpc.http.Timeout
import com.connectrpc.http.UnaryHTTPRequest
import com.connectrpc.http.dispatchIn
import com.connectrpc.http.transform
Expand All @@ -43,6 +44,7 @@ import kotlinx.coroutines.suspendCancellableCoroutine
import kotlinx.coroutines.withContext
import okio.Buffer
import java.net.URI
import java.util.concurrent.atomic.AtomicReference
import kotlin.coroutines.resume

/**
Expand Down Expand Up @@ -85,22 +87,47 @@ class ProtocolClient(
} else {
requestCodec.serialize(request)
}
val requestTimeout = config.timeoutOracle(methodSpec)
val unaryRequest = UnaryHTTPRequest(
url = urlFromMethodSpec(methodSpec),
contentType = "application/${requestCodec.encodingName()}",
timeout = requestTimeout,
headers = headers,
methodSpec = methodSpec,
message = requestMessage,
)
val unaryFunc = config.createInterceptorChain()
val finalRequest = unaryFunc.requestFunction(unaryRequest)
val timeoutRef = AtomicReference<Timeout>(null)
val finalOnResult: (ResponseMessage<Output>) -> Unit = handleResult@{ result ->
when (result) {
is ResponseMessage.Failure -> {
val timeout = timeoutRef.get()
if (timeout != null) {
timeout.cancel()
if (result.cause.code == Code.CANCELED && timeout.timedOut) {
onResult(
ResponseMessage.Failure(
cause = ConnectException(Code.DEADLINE_EXCEEDED, exception = result.cause),
headers = result.headers,
trailers = result.trailers,
),
)
return@handleResult
}
}
onResult(result)
}
else -> onResult(result)
}
}
val cancelable = httpClient.unary(finalRequest) httpClientUnary@{ httpResponse ->
val finalResponse: HTTPResponse
try {
finalResponse = unaryFunc.responseFunction(httpResponse)
} catch (ex: Throwable) {
val connEx = asConnectException(ex)
onResult(
finalOnResult(
ResponseMessage.Failure(
connEx,
emptyMap(),
Expand All @@ -110,7 +137,7 @@ class ProtocolClient(
return@httpClientUnary
}
if (finalResponse.cause != null) {
onResult(
finalOnResult(
ResponseMessage.Failure(
finalResponse.cause,
finalResponse.headers,
Expand All @@ -124,7 +151,7 @@ class ProtocolClient(
try {
responseMessage = responseCodec.deserialize(finalResponse.message)
} catch (ex: Exception) {
onResult(
finalOnResult(
ResponseMessage.Failure(
asConnectException(ex, Code.INTERNAL_ERROR),
finalResponse.headers,
Expand All @@ -133,14 +160,17 @@ class ProtocolClient(
)
return@httpClientUnary
}
onResult(
finalOnResult(
ResponseMessage.Success(
responseMessage,
finalResponse.headers,
finalResponse.trailers,
),
)
}
if (requestTimeout != null) {
timeoutRef.set(config.timeoutScheduler.scheduleTimeout(requestTimeout, cancelable))
}
return cancelable
} catch (ex: Exception) {
val connEx = asConnectException(ex)
Expand Down Expand Up @@ -218,14 +248,17 @@ class ProtocolClient(
val responseTrailers = CompletableDeferred<Headers>()
val requestCodec = config.serializationStrategy.codec(methodSpec.requestClass)
val responseCodec = config.serializationStrategy.codec(methodSpec.responseClass)
val requestTimeout = config.timeoutOracle(methodSpec)
val request = HTTPRequest(
url = urlFromMethodSpec(methodSpec),
contentType = "application/connect+${requestCodec.encodingName()}",
timeout = requestTimeout,
headers = headers,
methodSpec = methodSpec,
)
val streamFunc = config.createStreamingInterceptorChain()
val finalRequest = streamFunc.requestFunction(request)
val timeoutRef = AtomicReference<Timeout>(null)
var isComplete = false
val httpStream = httpClient.stream(
request = finalRequest,
Expand Down Expand Up @@ -259,10 +292,18 @@ class ProtocolClient(
streamResult.message,
)
channel.send(message)
} catch (e: Throwable) {
} catch (ex: Throwable) {
isComplete = true
var connEx = asConnectException(ex)
val timeout = timeoutRef.get()
if (timeout != null) {
timeout.cancel()
if (connEx.code == Code.CANCELED && timeout.timedOut) {
connEx = ConnectException(Code.DEADLINE_EXCEEDED, exception = ex)
}
}
try {
channel.close(ConnectException(Code.UNKNOWN, exception = e))
channel.close(connEx)
} finally {
responseTrailers.complete(emptyMap())
}
Expand All @@ -273,14 +314,31 @@ class ProtocolClient(
// This is a no-op if we already received a StreamResult.Headers.
responseHeaders.complete(emptyMap())
isComplete = true
var connEx = streamResult.cause
val timeout = timeoutRef.get()
if (timeout != null) {
timeout.cancel()
if (connEx?.code == Code.CANCELED && timeout.timedOut) {
connEx = ConnectException(Code.DEADLINE_EXCEEDED, exception = streamResult.cause)
}
}
try {
channel.close(streamResult.cause)
channel.close(connEx)
} finally {
responseTrailers.complete(streamResult.trailers)
}
}
}
}
if (requestTimeout != null) {
timeoutRef.set(
config.timeoutScheduler.scheduleTimeout(requestTimeout) {
runBlocking {
channel.close(ConnectException(code = Code.DEADLINE_EXCEEDED, message = "$requestTimeout timeout elapsed"))
}
},
)
}
try {
channel.invokeOnClose {
runBlocking { httpStream.receiveClose() }
Expand Down
Loading