Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rt): enforce only once shutdown logic for crt engine connections #497

Merged
merged 3 commits into from
Jan 11, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -76,36 +76,30 @@ public class CrtHttpEngine(public val config: CrtHttpEngineConfig) : HttpClientE
override suspend fun roundTrip(request: HttpRequest): HttpCall {
val callContext = callContext()
val manager = getManagerForUri(request.uri)

// LIFETIME: connection will be released back to the pool/manager when
// the response completes OR on exception (both handled by the completion handler registered on the stream
// handler)
val conn = withTimeoutOrNull(config.connectionAcquireTimeout) {
manager.acquireConnection()
} ?: throw ClientException("timed out waiting for an HTTP connection to be acquired from the pool")

try {
val reqTime = Instant.now()
val engineRequest = request.toCrtRequest(callContext)

// LIFETIME: connection will be released back to the pool/manager when
// the response completes OR on exception
val respHandler = SdkStreamResponseHandler(conn)
callContext.job.invokeOnCompletion {
// ensures the stream is driven to completion regardless of what the downstream consumer does
respHandler.complete()
}

val stream = conn.makeRequest(engineRequest, respHandler)
stream.activate()

val resp = respHandler.waitForResponse()

return HttpCall(request, resp, reqTime, Instant.now(), callContext)
} catch (ex: Exception) {
try {
manager.releaseConnection(conn)
} catch (ex2: Exception) {
ex.addSuppressed(ex2)
}
throw ex
val respHandler = SdkStreamResponseHandler(conn)
callContext.job.invokeOnCompletion {
logger.trace { "completing handler; cause=$it" }
// ensures the stream is driven to completion regardless of what the downstream consumer does
respHandler.complete()
}

val reqTime = Instant.now()
val engineRequest = request.toCrtRequest(callContext)

val stream = conn.makeRequest(engineRequest, respHandler)
stream.activate()

val resp = respHandler.waitForResponse()

return HttpCall(request, resp, reqTime, Instant.now(), callContext)
}

override fun shutdown() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ public class CrtHttpEngineConfig private constructor(builder: Builder) : HttpCli
* The default engine config. Most clients should use this.
*/
public val Default: CrtHttpEngineConfig = CrtHttpEngineConfig(Builder())

public operator fun invoke(block: Builder.() -> Unit): CrtHttpEngineConfig =
Builder().apply(block).build()
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import aws.smithy.kotlin.runtime.http.*
import aws.smithy.kotlin.runtime.http.HeadersBuilder
import aws.smithy.kotlin.runtime.http.response.HttpResponse
import aws.smithy.kotlin.runtime.io.SdkByteReadChannel
import aws.smithy.kotlin.runtime.logging.Logger
import kotlinx.atomicfu.locks.reentrantLock
import kotlinx.atomicfu.locks.withLock
import kotlinx.coroutines.ExperimentalCoroutinesApi
Expand All @@ -30,13 +31,17 @@ internal class SdkStreamResponseHandler(
// There is no great way to do that currently without either (1) closing the connection or (2) throwing an
// exception from a callback such that AWS_OP_ERROR is returned. Wait for HttpStream to have explicit cancellation

private val logger = Logger.getLogger<SdkStreamResponseHandler>()
private val responseReady = Channel<HttpResponse>(1)
private val headers = HeadersBuilder()

private var sdkBody: BufferedReadChannel? = null

private val lock = reentrantLock()
private val lock = reentrantLock() // protects crtStream and cancelled state
private var crtStream: HttpStream? = null
// if the (coroutine) job is completed before the stream's onResponseComplete callback is
// invoked (for any reason) we consider the stream "cancelled"
private var cancelled = false

private val Int.isMainHeadersBlock: Boolean
get() = when (this) {
Expand Down Expand Up @@ -115,7 +120,13 @@ internal class SdkStreamResponseHandler(
}

override fun onResponseBody(stream: HttpStream, bodyBytesIn: Buffer): Int {
lock.withLock { crtStream = stream }
val isCancelled = lock.withLock {
crtStream = stream
cancelled
}

// short circuit, stop buffering data and discard remaining incoming bytes
if (isCancelled) return bodyBytesIn.len

// we should have created a response channel if we expected a body
val sdkRespChan = checkNotNull(sdkBody) { "unexpected response body" }
Expand All @@ -134,10 +145,6 @@ internal class SdkStreamResponseHandler(
streamCompleted = true
}

// release it back to the pool, this is safe to do now since the body (and any other response data)
// has been copied to buffers we own by now
conn.close()

// close the body channel
if (errorCode != 0) {
val errorDescription = CRT.errorString(errorCode)
Expand All @@ -162,13 +169,19 @@ internal class SdkStreamResponseHandler(
internal fun complete() {
// We have no way of cancelling the stream, we have to drive it to exhaustion OR close the connection.
// At this point we know it's safe to release resources so if the stream hasn't completed yet
// we forcefully close the connection. This can happen when the stream's window is full and it's waiting
// we forcefully shutdown the connection. This can happen when the stream's window is full and it's waiting
// on the window to be incremented to proceed (i.e. the user didn't consume the stream for whatever reason
// and more data is pending arrival).
val forceClose = lock.withLock { !streamCompleted }
// and more data is pending arrival). It can also happen if the coroutine for this request is cancelled
// before onResponseComplete fires.
lock.withLock {
val forceClose = !streamCompleted

if (forceClose) {
logger.trace { "stream did not complete before job, forcing connection shutdown! handler=$this; conn=$conn; stream=$crtStream" }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment

Perhaps this should be a higher level message, like warn. From past experience it is often useful to know when partial results where received.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't necessarily mean partial results though. Technically it just means the coroutine job completed before consuming the entire response which could happen because of an exception or simply because the consumer doesn't read the entire body.

If we actually had a stream cancellation API then we wouldn't need to force close the connection which would make this less heavy handed.

I don't think I'd make it warn because in the case of an exception they'll be forced to handle it anyway and the exception is your signal. In the case of explicitly not consuming the entire body it's not a warning since the consumer is choosing it. Perhaps debug would be better than trace though?

conn.shutdown()
cancelled = true
}

if (forceClose) {
conn.shutdown()
conn.close()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ class SdkStreamResponseHandlerTest {
assertEquals(HttpStatusCode.OK, resp.status)

assertTrue(resp.body is HttpBody.Empty)
handler.onResponseComplete(stream, 0)

assertFalse(mockConn.isClosed)
handler.onResponseComplete(stream, 0)
handler.complete()
assertTrue(mockConn.isClosed)
}

Expand All @@ -65,7 +66,6 @@ class SdkStreamResponseHandlerTest {

val resp = handler.waitForResponse()
assertEquals(HttpStatusCode.OK, resp.status)
assertTrue(mockConn.isClosed)
}

@Test
Expand All @@ -80,8 +80,6 @@ class SdkStreamResponseHandlerTest {
assertFails {
handler.waitForResponse()
}

assertTrue(mockConn.isClosed)
}

@Test
Expand All @@ -107,7 +105,6 @@ class SdkStreamResponseHandlerTest {

assertFalse(mockConn.isClosed)
handler.onResponseComplete(stream, 0)
assertTrue(mockConn.isClosed)
assertTrue(respChan.isClosedForWrite)
}

Expand All @@ -134,7 +131,6 @@ class SdkStreamResponseHandlerTest {
assertTrue(resp.body is HttpBody.Streaming)
val respChan = (resp.body as HttpBody.Streaming).readFrom()

assertTrue(mockConn.isClosed)
assertTrue(respChan.isClosedForWrite)

assertEquals(data, respChan.readRemaining().decodeToString())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import aws.smithy.kotlin.runtime.http.request.url
import aws.smithy.kotlin.runtime.http.response.complete
import aws.smithy.kotlin.runtime.http.sdkHttpClient
import aws.smithy.kotlin.runtime.httptest.TestWithLocalServer
import aws.smithy.kotlin.runtime.testing.IgnoreWindows
import io.ktor.application.*
import io.ktor.response.*
import io.ktor.routing.*
Expand Down Expand Up @@ -71,7 +70,6 @@ class AsyncStressTest : TestWithLocalServer() {
}
}

@IgnoreWindows("https://github.com/awslabs/aws-sdk-kotlin/issues/413")
@OptIn(ExperimentalTime::class)
@Test
fun testStreamNotConsumed() = runSuspendTest {
Expand Down