Skip to content

Commit

Permalink
Fixed a thread safety issue that could cause application to crash in … (
Browse files Browse the repository at this point in the history
#4839)

* Fixed a thread safety issue that could cause application to crash in the edge case in AWS CRT HTTP client

* Add tests
  • Loading branch information
zoewangg authored Jan 22, 2024
1 parent 10121d4 commit 6143fe5
Show file tree
Hide file tree
Showing 9 changed files with 195 additions and 25 deletions.
6 changes: 6 additions & 0 deletions .changes/next-release/bugfix-AWSCRTHTTPClient-7b95a65.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "bugfix",
"category": "AWS CRT HTTP Client",
"contributor": "",
"description": "Fixed a thread safety issue that could cause application to crash in the edge case where the SDK attempted to invoke `incrementWindow` after the stream is closed in AWS CRT HTTP Client."
}
6 changes: 6 additions & 0 deletions bom-internal/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,12 @@
<version>${mockito.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
<version>${mockito.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>nl.jqno.equalsverifier</groupId>
<artifactId>equalsverifier</artifactId>
Expand Down
2 changes: 1 addition & 1 deletion http-clients/aws-crt-client/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<artifactId>mockito-inline</artifactId>
<scope>test</scope>
</dependency>
<dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.nio.ByteBuffer;
import java.util.concurrent.CompletableFuture;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.annotations.SdkTestInternalApi;
import software.amazon.awssdk.crt.CRT;
import software.amazon.awssdk.crt.http.HttpClientConnection;
import software.amazon.awssdk.crt.http.HttpException;
Expand All @@ -46,19 +47,29 @@ public final class CrtResponseAdapter implements HttpStreamResponseHandler {
private final HttpClientConnection connection;
private final CompletableFuture<Void> completionFuture;
private final SdkAsyncHttpResponseHandler responseHandler;
private final SimplePublisher<ByteBuffer> responsePublisher = new SimplePublisher<>();
private final SimplePublisher<ByteBuffer> responsePublisher;

private final SdkHttpResponse.Builder responseBuilder;
private final ResponseHandlerHelper responseHandlerHelper;

private CrtResponseAdapter(HttpClientConnection connection,
CompletableFuture<Void> completionFuture,
SdkAsyncHttpResponseHandler responseHandler) {
this(connection, completionFuture, responseHandler, new SimplePublisher<>());
}


@SdkTestInternalApi
public CrtResponseAdapter(HttpClientConnection connection,
CompletableFuture<Void> completionFuture,
SdkAsyncHttpResponseHandler responseHandler,
SimplePublisher<ByteBuffer> simplePublisher) {
this.connection = Validate.paramNotNull(connection, "connection");
this.completionFuture = Validate.paramNotNull(completionFuture, "completionFuture");
this.responseHandler = Validate.paramNotNull(responseHandler, "responseHandler");
this.responseBuilder = SdkHttpResponse.builder();
this.responseHandlerHelper = new ResponseHandlerHelper(responseBuilder, connection);
this.responsePublisher = simplePublisher;
}

public static HttpStreamResponseHandler toCrtResponseHandler(HttpClientConnection crtConn,
Expand Down Expand Up @@ -95,9 +106,7 @@ public int onResponseBody(HttpStream stream, byte[] bodyBytesIn) {
return;
}

if (!responseHandlerHelper.connectionClosed().get()) {
stream.incrementWindow(bodyBytesIn.length);
}
responseHandlerHelper.incrementWindow(stream, bodyBytesIn.length);
});

return 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.nio.ByteBuffer;
import java.util.concurrent.CompletableFuture;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.annotations.SdkTestInternalApi;
import software.amazon.awssdk.crt.CRT;
import software.amazon.awssdk.crt.http.HttpClientConnection;
import software.amazon.awssdk.crt.http.HttpException;
Expand All @@ -42,7 +43,7 @@
public final class InputStreamAdaptingHttpStreamResponseHandler implements HttpStreamResponseHandler {
private static final Logger log = Logger.loggerFor(InputStreamAdaptingHttpStreamResponseHandler.class);
private volatile AbortableInputStreamSubscriber inputStreamSubscriber;
private final SimplePublisher<ByteBuffer> simplePublisher = new SimplePublisher<>();
private final SimplePublisher<ByteBuffer> simplePublisher;

private final CompletableFuture<SdkHttpFullResponse> requestCompletionFuture;
private final HttpClientConnection crtConn;
Expand All @@ -52,10 +53,18 @@ public final class InputStreamAdaptingHttpStreamResponseHandler implements HttpS

public InputStreamAdaptingHttpStreamResponseHandler(HttpClientConnection crtConn,
CompletableFuture<SdkHttpFullResponse> requestCompletionFuture) {
this(crtConn, requestCompletionFuture, new SimplePublisher<>());
}

@SdkTestInternalApi
public InputStreamAdaptingHttpStreamResponseHandler(HttpClientConnection crtConn,
CompletableFuture<SdkHttpFullResponse> requestCompletionFuture,
SimplePublisher<ByteBuffer> simplePublisher) {
this.crtConn = crtConn;
this.requestCompletionFuture = requestCompletionFuture;
this.responseBuilder = SdkHttpResponse.builder();
this.responseHandlerHelper = new ResponseHandlerHelper(responseBuilder, crtConn);
this.simplePublisher = simplePublisher;
}

@Override
Expand Down Expand Up @@ -101,11 +110,8 @@ public int onResponseBody(HttpStream stream, byte[] bodyBytesIn) {
failFutureAndCloseConnection(stream, failure);
return;
}

if (!responseHandlerHelper.connectionClosed().get()) {
// increment the window upon buffer consumption.
stream.incrementWindow(bodyBytesIn.length);
}
// increment the window upon buffer consumption.
responseHandlerHelper.incrementWindow(stream, bodyBytesIn.length);
});

// Window will be incremented after the subscriber consumes the data, returning 0 here to disable it.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

package software.amazon.awssdk.http.crt.internal.response;

import java.util.concurrent.atomic.AtomicBoolean;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.crt.http.HttpClientConnection;
import software.amazon.awssdk.crt.http.HttpHeader;
Expand All @@ -30,14 +29,16 @@
*
* CRT connection will only be closed, i.e., not reused, in one of the following conditions:
* 1. 5xx server error OR
* 2. It fails to read the response.
* 2. It fails to read the response OR
* 3. the response stream is closed/aborted by the caller.
*/
@SdkInternalApi
public class ResponseHandlerHelper {

private final SdkHttpResponse.Builder responseBuilder;
private final HttpClientConnection connection;
private AtomicBoolean connectionClosed = new AtomicBoolean(false);
private boolean connectionClosed;
private final Object lock = new Object();

public ResponseHandlerHelper(SdkHttpResponse.Builder responseBuilder, HttpClientConnection connection) {
this.responseBuilder = responseBuilder;
Expand All @@ -57,20 +58,34 @@ public void onResponseHeaders(HttpStream stream, int responseStatusCode, int hea
* Release the connection back to the pool so that it can be reused.
*/
public void releaseConnection(HttpStream stream) {
if (connectionClosed.compareAndSet(false, true)) {
connection.close();
stream.close();
synchronized (lock) {
if (!connectionClosed) {
connectionClosed = true;
connection.close();
stream.close();
}
}
}

public void incrementWindow(HttpStream stream, int windowSize) {
synchronized (lock) {
if (!connectionClosed) {
stream.incrementWindow(windowSize);
}
}
}

/**
* Close the connection completely
*/
public void closeConnection(HttpStream stream) {
if (connectionClosed.compareAndSet(false, true)) {
connection.shutdown();
connection.close();
stream.close();
synchronized (lock) {
if (!connectionClosed) {
connectionClosed = true;
connection.shutdown();
connection.close();
stream.close();
}
}
}

Expand All @@ -82,8 +97,4 @@ public void cleanUpConnectionBasedOnStatusCode(HttpStream stream) {
releaseConnection(stream);
}
}

public AtomicBoolean connectionClosed() {
return connectionClosed;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,34 @@
package software.amazon.awssdk.http.crt.internal;

import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.crt.http.HttpClientConnection;
import software.amazon.awssdk.crt.http.HttpException;
import software.amazon.awssdk.crt.http.HttpHeader;
import software.amazon.awssdk.crt.http.HttpHeaderBlock;
import software.amazon.awssdk.crt.http.HttpStream;
import software.amazon.awssdk.crt.http.HttpStreamResponseHandler;
import software.amazon.awssdk.http.crt.internal.response.InputStreamAdaptingHttpStreamResponseHandler;
import software.amazon.awssdk.utils.async.SimplePublisher;

@ExtendWith(MockitoExtension.class)
public abstract class BaseHttpStreamResponseHandlerTest {
Expand All @@ -44,10 +53,15 @@ public abstract class BaseHttpStreamResponseHandlerTest {
@Mock
HttpStream httpStream;

@Mock
SimplePublisher<ByteBuffer> simplePublisher;

HttpStreamResponseHandler responseHandler;

abstract HttpStreamResponseHandler responseHandler();

abstract HttpStreamResponseHandler responseHandlerWithMockedPublisher(SimplePublisher<ByteBuffer> simplePublisher);

@BeforeEach
public void setUp() {
requestFuture = new CompletableFuture<>();
Expand Down Expand Up @@ -113,6 +127,101 @@ void streamClosed_shouldNotIncreaseStreamWindow() throws InterruptedException {
verify(httpStream, never()).incrementWindow(anyInt());
}

@Test
void publisherWritesFutureFails_shouldShutdownConnection() {
SimplePublisher<ByteBuffer> simplePublisher = Mockito.mock(SimplePublisher.class);
CompletableFuture<Void> future = new CompletableFuture<>();
when(simplePublisher.send(any(ByteBuffer.class))).thenReturn(future);

HttpStreamResponseHandler handler = responseHandlerWithMockedPublisher(simplePublisher);
HttpHeader[] httpHeaders = getHttpHeaders();

handler.onResponseHeaders(httpStream, 200, HttpHeaderBlock.MAIN.getValue(),
httpHeaders);
handler.onResponseHeadersDone(httpStream, 0);
handler.onResponseBody(httpStream,
RandomStringUtils.random(1 * 1024 * 1024).getBytes(StandardCharsets.UTF_8));
RuntimeException runtimeException = new RuntimeException();
future.completeExceptionally(runtimeException);

try {
requestFuture.join();
} catch (Exception e) {
// we don't verify here because it behaves differently in async and sync
}

verify(crtConn).shutdown();
verify(crtConn).close();
verify(httpStream).close();
verify(httpStream, never()).incrementWindow(anyInt());
}

@Test
void publisherWritesFutureCompletesAfterConnectionClosed_shouldNotInvokeIncrementWindow() {
CompletableFuture<Void> future = new CompletableFuture<>();
when(simplePublisher.send(any(ByteBuffer.class))).thenReturn(future);
when(simplePublisher.complete()).thenReturn(future);

HttpStreamResponseHandler handler = responseHandlerWithMockedPublisher(simplePublisher);


HttpHeader[] httpHeaders = getHttpHeaders();

handler.onResponseHeaders(httpStream, 200, HttpHeaderBlock.MAIN.getValue(),
httpHeaders);
handler.onResponseHeadersDone(httpStream, 0);
handler.onResponseBody(httpStream,
RandomStringUtils.random(1 * 1024 * 1024).getBytes(StandardCharsets.UTF_8));
handler.onResponseComplete(httpStream, 0);
future.complete(null);

requestFuture.join();
verify(crtConn, never()).shutdown();
verify(crtConn).close();
verify(httpStream).close();
verify(httpStream, never()).incrementWindow(anyInt());
}

@Test
void publisherWritesFutureCompletesWhenConnectionClosed_shouldNotInvokeIncrementWindow() {
CompletableFuture<Void> future = new CompletableFuture<>();
when(simplePublisher.send(any(ByteBuffer.class))).thenReturn(future);
when(simplePublisher.complete()).thenReturn(future);

HttpStreamResponseHandler handler = responseHandlerWithMockedPublisher(simplePublisher);


HttpHeader[] httpHeaders = getHttpHeaders();

handler.onResponseHeaders(httpStream, 200, HttpHeaderBlock.MAIN.getValue(),
httpHeaders);
handler.onResponseHeadersDone(httpStream, 0);
handler.onResponseBody(httpStream,
RandomStringUtils.random(1 * 1024 * 1024).getBytes(StandardCharsets.UTF_8));

// This tracker tracks which of the two operation completes first
AtomicInteger whenCompleteTracker = new AtomicInteger(0);
CompletableFuture<Void> onResponseComplete = CompletableFuture.runAsync(() -> handler.onResponseComplete(httpStream, 0))
.whenComplete((r, t) -> whenCompleteTracker.compareAndSet(0, 1));

CompletableFuture<Void> writeComplete = CompletableFuture.runAsync(() -> future.complete(null))
.whenComplete((r, t) -> whenCompleteTracker.compareAndSet(0, 2));
requestFuture.join();

CompletableFuture.allOf(onResponseComplete, writeComplete).join();

if (whenCompleteTracker.get() == 1) {
// onResponseComplete finishes first
verify(httpStream, never()).incrementWindow(anyInt());
} else {
verify(httpStream).incrementWindow(anyInt());
}

verify(crtConn, never()).shutdown();
verify(crtConn).close();
verify(httpStream).close();
}

static HttpHeader[] getHttpHeaders() {
HttpHeader[] httpHeaders = new HttpHeader[1];
httpHeaders[0] = new HttpHeader("Content-Length", "1");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler;
import software.amazon.awssdk.http.crt.internal.response.CrtResponseAdapter;
import software.amazon.awssdk.http.crt.internal.response.InputStreamAdaptingHttpStreamResponseHandler;
import software.amazon.awssdk.utils.async.SimplePublisher;

public class CrtResponseHandlerTest extends BaseHttpStreamResponseHandlerTest {

Expand All @@ -53,6 +54,15 @@ HttpStreamResponseHandler responseHandler() {
return CrtResponseAdapter.toCrtResponseHandler(crtConn, requestFuture, responseHandler);
}

@Override
HttpStreamResponseHandler responseHandlerWithMockedPublisher(SimplePublisher<ByteBuffer> simplePublisher) {
AsyncResponseHandler<Void> responseHandler = new AsyncResponseHandler<>((response,
executionAttributes) -> null, Function.identity(), new ExecutionAttributes());

responseHandler.prepare();
return new CrtResponseAdapter(crtConn, requestFuture, responseHandler, simplePublisher);
}

@Test
void publisherFailedToDeliverEvents_shouldShutDownConnection() {
SdkAsyncHttpResponseHandler responseHandler = new TestAsyncHttpResponseHandler();
Expand Down
Loading

0 comments on commit 6143fe5

Please sign in to comment.