Skip to content

Commit

Permalink
fix(rt): enforce only once shutdown logic for crt engine connections (#…
Browse files Browse the repository at this point in the history
…497)

Fixes the segfault that can happen when an exception is handled twice
leading to a connection being closed after it has been free'd. This
change refactors the handling of the connection close logic to be
handled in a single place regardless of why the connection is being
closed.
  • Loading branch information
aajtodd authored Jan 11, 2022
1 parent f123c6b commit df3d15a
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,36 +76,30 @@ public class CrtHttpEngine(public val config: CrtHttpEngineConfig) : HttpClientE
override suspend fun roundTrip(request: HttpRequest): HttpCall {
val callContext = callContext()
val manager = getManagerForUri(request.uri)

// LIFETIME: connection will be released back to the pool/manager when
// the response completes OR on exception (both handled by the completion handler registered on the stream
// handler)
val conn = withTimeoutOrNull(config.connectionAcquireTimeout) {
manager.acquireConnection()
} ?: throw ClientException("timed out waiting for an HTTP connection to be acquired from the pool")

try {
val reqTime = Instant.now()
val engineRequest = request.toCrtRequest(callContext)

// LIFETIME: connection will be released back to the pool/manager when
// the response completes OR on exception
val respHandler = SdkStreamResponseHandler(conn)
callContext.job.invokeOnCompletion {
// ensures the stream is driven to completion regardless of what the downstream consumer does
respHandler.complete()
}

val stream = conn.makeRequest(engineRequest, respHandler)
stream.activate()

val resp = respHandler.waitForResponse()

return HttpCall(request, resp, reqTime, Instant.now(), callContext)
} catch (ex: Exception) {
try {
manager.releaseConnection(conn)
} catch (ex2: Exception) {
ex.addSuppressed(ex2)
}
throw ex
val respHandler = SdkStreamResponseHandler(conn)
callContext.job.invokeOnCompletion {
logger.trace { "completing handler; cause=$it" }
// ensures the stream is driven to completion regardless of what the downstream consumer does
respHandler.complete()
}

val reqTime = Instant.now()
val engineRequest = request.toCrtRequest(callContext)

val stream = conn.makeRequest(engineRequest, respHandler)
stream.activate()

val resp = respHandler.waitForResponse()

return HttpCall(request, resp, reqTime, Instant.now(), callContext)
}

override fun shutdown() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ public class CrtHttpEngineConfig private constructor(builder: Builder) : HttpCli
* The default engine config. Most clients should use this.
*/
public val Default: CrtHttpEngineConfig = CrtHttpEngineConfig(Builder())

public operator fun invoke(block: Builder.() -> Unit): CrtHttpEngineConfig =
Builder().apply(block).build()
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import aws.smithy.kotlin.runtime.http.*
import aws.smithy.kotlin.runtime.http.HeadersBuilder
import aws.smithy.kotlin.runtime.http.response.HttpResponse
import aws.smithy.kotlin.runtime.io.SdkByteReadChannel
import aws.smithy.kotlin.runtime.logging.Logger
import kotlinx.atomicfu.locks.reentrantLock
import kotlinx.atomicfu.locks.withLock
import kotlinx.coroutines.ExperimentalCoroutinesApi
Expand All @@ -30,13 +31,17 @@ internal class SdkStreamResponseHandler(
// There is no great way to do that currently without either (1) closing the connection or (2) throwing an
// exception from a callback such that AWS_OP_ERROR is returned. Wait for HttpStream to have explicit cancellation

private val logger = Logger.getLogger<SdkStreamResponseHandler>()
private val responseReady = Channel<HttpResponse>(1)
private val headers = HeadersBuilder()

private var sdkBody: BufferedReadChannel? = null

private val lock = reentrantLock()
private val lock = reentrantLock() // protects crtStream and cancelled state
private var crtStream: HttpStream? = null
// if the (coroutine) job is completed before the stream's onResponseComplete callback is
// invoked (for any reason) we consider the stream "cancelled"
private var cancelled = false

private val Int.isMainHeadersBlock: Boolean
get() = when (this) {
Expand Down Expand Up @@ -115,7 +120,13 @@ internal class SdkStreamResponseHandler(
}

override fun onResponseBody(stream: HttpStream, bodyBytesIn: Buffer): Int {
lock.withLock { crtStream = stream }
val isCancelled = lock.withLock {
crtStream = stream
cancelled
}

// short circuit, stop buffering data and discard remaining incoming bytes
if (isCancelled) return bodyBytesIn.len

// we should have created a response channel if we expected a body
val sdkRespChan = checkNotNull(sdkBody) { "unexpected response body" }
Expand All @@ -134,10 +145,6 @@ internal class SdkStreamResponseHandler(
streamCompleted = true
}

// release it back to the pool, this is safe to do now since the body (and any other response data)
// has been copied to buffers we own by now
conn.close()

// close the body channel
if (errorCode != 0) {
val errorDescription = CRT.errorString(errorCode)
Expand All @@ -162,13 +169,19 @@ internal class SdkStreamResponseHandler(
internal fun complete() {
// We have no way of cancelling the stream, we have to drive it to exhaustion OR close the connection.
// At this point we know it's safe to release resources so if the stream hasn't completed yet
// we forcefully close the connection. This can happen when the stream's window is full and it's waiting
// we forcefully shutdown the connection. This can happen when the stream's window is full and it's waiting
// on the window to be incremented to proceed (i.e. the user didn't consume the stream for whatever reason
// and more data is pending arrival).
val forceClose = lock.withLock { !streamCompleted }
// and more data is pending arrival). It can also happen if the coroutine for this request is cancelled
// before onResponseComplete fires.
lock.withLock {
val forceClose = !streamCompleted

if (forceClose) {
logger.debug { "stream did not complete before job, forcing connection shutdown! handler=$this; conn=$conn; stream=$crtStream" }
conn.shutdown()
cancelled = true
}

if (forceClose) {
conn.shutdown()
conn.close()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ class SdkStreamResponseHandlerTest {
assertEquals(HttpStatusCode.OK, resp.status)

assertTrue(resp.body is HttpBody.Empty)
handler.onResponseComplete(stream, 0)

assertFalse(mockConn.isClosed)
handler.onResponseComplete(stream, 0)
handler.complete()
assertTrue(mockConn.isClosed)
}

Expand All @@ -65,7 +66,6 @@ class SdkStreamResponseHandlerTest {

val resp = handler.waitForResponse()
assertEquals(HttpStatusCode.OK, resp.status)
assertTrue(mockConn.isClosed)
}

@Test
Expand All @@ -80,8 +80,6 @@ class SdkStreamResponseHandlerTest {
assertFails {
handler.waitForResponse()
}

assertTrue(mockConn.isClosed)
}

@Test
Expand All @@ -107,7 +105,6 @@ class SdkStreamResponseHandlerTest {

assertFalse(mockConn.isClosed)
handler.onResponseComplete(stream, 0)
assertTrue(mockConn.isClosed)
assertTrue(respChan.isClosedForWrite)
}

Expand All @@ -134,7 +131,6 @@ class SdkStreamResponseHandlerTest {
assertTrue(resp.body is HttpBody.Streaming)
val respChan = (resp.body as HttpBody.Streaming).readFrom()

assertTrue(mockConn.isClosed)
assertTrue(respChan.isClosedForWrite)

assertEquals(data, respChan.readRemaining().decodeToString())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import aws.smithy.kotlin.runtime.http.request.url
import aws.smithy.kotlin.runtime.http.response.complete
import aws.smithy.kotlin.runtime.http.sdkHttpClient
import aws.smithy.kotlin.runtime.httptest.TestWithLocalServer
import aws.smithy.kotlin.runtime.testing.IgnoreWindows
import io.ktor.application.*
import io.ktor.response.*
import io.ktor.routing.*
Expand Down Expand Up @@ -71,7 +70,6 @@ class AsyncStressTest : TestWithLocalServer() {
}
}

@IgnoreWindows("https://github.com/awslabs/aws-sdk-kotlin/issues/413")
@OptIn(ExperimentalTime::class)
@Test
fun testStreamNotConsumed() = runSuspendTest {
Expand Down

0 comments on commit df3d15a

Please sign in to comment.