Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix request cancellation issue in the AWS CRT-based S3 client that co… #4955

Merged
merged 4 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "bugfix",
"category": "AWS CRT-based S3 client",
"contributor": "",
"description": "Fixed memory leak issue when a request was cancelled in the AWS CRT-based S3 client."
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,8 @@
<!-- Allow private field declaration before public, to have correct initialization order -->
<suppress checks="DeclarationOrder"
files=".*SdkAdvancedClientOption\.java$"/>

<!-- Ignore usage of S3MetaRequest in S3MetaRequestWrapper. !-->
<suppress checks="Regexp"
files="software.amazon.awssdk.services.s3.internal.crt.S3MetaRequestWrapper.java"/>
</suppressions>
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,14 @@
<property name="ignoreComments" value="true"/>
</module>

<!-- Checks that we don't use S3MetaRequest -->
<module name="Regexp">
<property name="format" value="\bS3MetaRequest\b"/>
<property name="illegalPattern" value="true"/>
<property name="message" value="Don't use S3MetaRequest directly. Use software.amazon.awssdk.services.s3.internal.crt.S3MetaRequestWrapper instead"/>
<property name="ignoreComments" value="true"/>
</module>

<!-- Checks that we don't implement AutoCloseable/Closeable -->
<module name="Regexp">
<property name="format" value="(class|interface).*(implements|extends).*[^\w](Closeable|AutoCloseable)[^\w]"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import software.amazon.awssdk.crt.s3.ResumeToken;
import software.amazon.awssdk.crt.s3.S3MetaRequest;
import software.amazon.awssdk.services.s3.internal.crt.S3MetaRequestPauseObservable;
import software.amazon.awssdk.services.s3.internal.crt.S3MetaRequestWrapper;
import software.amazon.awssdk.services.s3.model.PutObjectResponse;
import software.amazon.awssdk.transfer.s3.internal.model.CrtFileUpload;
import software.amazon.awssdk.transfer.s3.internal.progress.DefaultTransferProgressSnapshot;
Expand All @@ -53,7 +54,7 @@ class CrtFileUploadTest {
private static final int NUM_OF_PARTS_COMPLETED = 5;
private static final long PART_SIZE_IN_BYTES = 8 * MB;
private static final String MULTIPART_UPLOAD_ID = "someId";
private S3MetaRequest metaRequest;
private S3MetaRequestPauseObservable observable;
private static FileSystem fileSystem;
private static File file;
private static ResumeToken token;
Expand All @@ -77,7 +78,7 @@ public static void tearDown() throws IOException {

@BeforeEach
void setUpBeforeEachTest() {
metaRequest = Mockito.mock(S3MetaRequest.class);
observable = Mockito.mock(S3MetaRequestPauseObservable.class);
}

@Test
Expand All @@ -102,17 +103,13 @@ void pause_futureCompleted_shouldReturnNormally() {
.sdkResponse(putObjectResponse)
.transferredBytes(0L)
.build());
S3MetaRequestPauseObservable observable = new S3MetaRequestPauseObservable();

UploadFileRequest request = uploadFileRequest();

CrtFileUpload fileUpload =
new CrtFileUpload(future, transferProgress, observable, request);

observable.subscribe(metaRequest);

ResumableFileUpload resumableFileUpload = fileUpload.pause();
Mockito.verify(metaRequest, Mockito.never()).pause();
Mockito.verify(observable, Mockito.never()).pause();
assertThat(resumableFileUpload.totalParts()).isEmpty();
assertThat(resumableFileUpload.partSizeInBytes()).isEmpty();
assertThat(resumableFileUpload.multipartUploadId()).isEmpty();
Expand All @@ -130,10 +127,7 @@ void pauseTwice_shouldReturnTheSame() {
.transferredBytes(1000L)
.build());
UploadFileRequest request = uploadFileRequest();

S3MetaRequestPauseObservable observable = new S3MetaRequestPauseObservable();
when(metaRequest.pause()).thenReturn(token);
observable.subscribe(metaRequest);
when(observable.pause()).thenReturn(token);

CrtFileUpload fileUpload =
new CrtFileUpload(future, transferProgress, observable, request);
Expand All @@ -154,10 +148,8 @@ void pause_crtThrowException_shouldPropogate() {
.build());
UploadFileRequest request = uploadFileRequest();

S3MetaRequestPauseObservable observable = new S3MetaRequestPauseObservable();
CrtRuntimeException exception = new CrtRuntimeException("exception");
when(metaRequest.pause()).thenThrow(exception);
observable.subscribe(metaRequest);
when(observable.pause()).thenThrow(exception);

CrtFileUpload fileUpload =
new CrtFileUpload(future, transferProgress, observable, request);
Expand All @@ -173,17 +165,14 @@ void pause_futureNotComplete_shouldPause() {
when(transferProgress.snapshot()).thenReturn(DefaultTransferProgressSnapshot.builder()
.transferredBytes(0L)
.build());
S3MetaRequestPauseObservable observable = new S3MetaRequestPauseObservable();
when(metaRequest.pause()).thenReturn(token);
when(observable.pause()).thenReturn(token);
UploadFileRequest request = uploadFileRequest();

CrtFileUpload fileUpload =
new CrtFileUpload(future, transferProgress, observable, request);

observable.subscribe(metaRequest);

ResumableFileUpload resumableFileUpload = fileUpload.pause();
Mockito.verify(metaRequest).pause();
Mockito.verify(observable).pause();
assertThat(resumableFileUpload.totalParts()).hasValue(TOTAL_PARTS);
assertThat(resumableFileUpload.partSizeInBytes()).hasValue(PART_SIZE_IN_BYTES);
assertThat(resumableFileUpload.multipartUploadId()).hasValue(MULTIPART_UPLOAD_ID);
Expand All @@ -204,17 +193,14 @@ void pause_singlePart_shouldPause() {
.sdkResponse(putObjectResponse)
.transferredBytes(0L)
.build());
S3MetaRequestPauseObservable observable = new S3MetaRequestPauseObservable();
when(metaRequest.pause()).thenThrow(new CrtRuntimeException(6));
when(observable.pause()).thenThrow(new CrtRuntimeException(6));
UploadFileRequest request = uploadFileRequest();

CrtFileUpload fileUpload =
new CrtFileUpload(future, transferProgress, observable, request);

observable.subscribe(metaRequest);

ResumableFileUpload resumableFileUpload = fileUpload.pause();
Mockito.verify(metaRequest).pause();
Mockito.verify(observable).pause();
assertThat(resumableFileUpload.totalParts()).isEmpty();
assertThat(resumableFileUpload.partSizeInBytes()).isEmpty();
assertThat(resumableFileUpload.multipartUploadId()).isEmpty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
import software.amazon.awssdk.crt.s3.ResumeToken;
import software.amazon.awssdk.crt.s3.S3Client;
import software.amazon.awssdk.crt.s3.S3ClientOptions;
import software.amazon.awssdk.crt.s3.S3MetaRequest;
import software.amazon.awssdk.crt.s3.S3MetaRequestOptions;
import software.amazon.awssdk.http.Header;
import software.amazon.awssdk.http.SdkHttpExecutionAttributes;
Expand Down Expand Up @@ -133,10 +132,12 @@ public CompletableFuture<Void> execute(AsyncExecuteRequest asyncRequest) {
URI uri = asyncRequest.request().getUri();
HttpRequest httpRequest = toCrtRequest(asyncRequest);
SdkHttpExecutionAttributes httpExecutionAttributes = asyncRequest.httpExecutionAttributes();
CompletableFuture<S3MetaRequestWrapper> s3MetaRequestFuture = new CompletableFuture<>();
S3CrtResponseHandlerAdapter responseHandler =
new S3CrtResponseHandlerAdapter(executeFuture,
asyncRequest.responseHandler(),
httpExecutionAttributes.getAttribute(CRT_PROGRESS_LISTENER));
httpExecutionAttributes.getAttribute(CRT_PROGRESS_LISTENER),
s3MetaRequestFuture);

S3MetaRequestOptions.MetaRequestType requestType = requestType(asyncRequest);

Expand All @@ -160,16 +161,19 @@ public CompletableFuture<Void> execute(AsyncExecuteRequest asyncRequest) {
.withRequestFilePath(requestFilePath)
.withSigningConfig(signingConfig);

S3MetaRequest s3MetaRequest = crtS3Client.makeMetaRequest(requestOptions);
S3MetaRequestPauseObservable observable =
httpExecutionAttributes.getAttribute(METAREQUEST_PAUSE_OBSERVABLE);
try {
S3MetaRequestWrapper requestWrapper = new S3MetaRequestWrapper(crtS3Client.makeMetaRequest(requestOptions));
s3MetaRequestFuture.complete(requestWrapper);

responseHandler.metaRequest(s3MetaRequest);
S3MetaRequestPauseObservable observable =
httpExecutionAttributes.getAttribute(METAREQUEST_PAUSE_OBSERVABLE);

if (observable != null) {
observable.subscribe(s3MetaRequest);
if (observable != null) {
observable.subscribe(requestWrapper);
}
} finally {
signingConfig.close();
}
closeResourceCallback(executeFuture, s3MetaRequest, responseHandler, signingConfig);

return executeFuture;
}
Expand Down Expand Up @@ -215,23 +219,6 @@ private static S3MetaRequestOptions.MetaRequestType requestType(AsyncExecuteRequ
return S3MetaRequestOptions.MetaRequestType.DEFAULT;
}

private static void closeResourceCallback(CompletableFuture<Void> executeFuture,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved this to S3CrtResponseHandlerAdapter because I figured it makes more sense to manage the closure of the S3MetaRequest in one place.

S3MetaRequest s3MetaRequest,
S3CrtResponseHandlerAdapter responseHandler,
AwsSigningConfig signingConfig) {
executeFuture.whenComplete((r, t) -> {
if (executeFuture.isCancelled()) {
log.debug(() -> "The request is cancelled, cancelling meta request");
responseHandler.cancelRequest();
s3MetaRequest.cancel();
signingConfig.close();
} else {
s3MetaRequest.close();
signingConfig.close();
}
});
}

private static HttpRequest toCrtRequest(AsyncExecuteRequest asyncRequest) {
SdkHttpRequest sdkRequest = asyncRequest.request();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,8 @@
import software.amazon.awssdk.crt.CRT;
import software.amazon.awssdk.crt.http.HttpHeader;
import software.amazon.awssdk.crt.s3.S3FinishedResponseContext;
import software.amazon.awssdk.crt.s3.S3MetaRequest;
import software.amazon.awssdk.crt.s3.S3MetaRequestProgress;
import software.amazon.awssdk.crt.s3.S3MetaRequestResponseHandler;
import software.amazon.awssdk.http.SdkCancellationException;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler;
import software.amazon.awssdk.utils.Logger;
Expand All @@ -46,20 +44,44 @@ public final class S3CrtResponseHandlerAdapter implements S3MetaRequestResponseH
private final SimplePublisher<ByteBuffer> responsePublisher = new SimplePublisher<>();

private final SdkHttpResponse.Builder initialHeadersResponse = SdkHttpResponse.builder();
private volatile S3MetaRequest metaRequest;
private final CompletableFuture<S3MetaRequestWrapper> metaRequestFuture;

private final PublisherListener<S3MetaRequestProgress> progressListener;

private volatile boolean responseHandlingInitiated;

public S3CrtResponseHandlerAdapter(CompletableFuture<Void> executeFuture,
SdkAsyncHttpResponseHandler responseHandler,
PublisherListener<S3MetaRequestProgress> progressListener) {
PublisherListener<S3MetaRequestProgress> progressListener,
CompletableFuture<S3MetaRequestWrapper> metaRequestFuture) {
this.resultFuture = executeFuture;
this.metaRequestFuture = metaRequestFuture;

resultFuture.whenComplete((r, t) -> {
S3MetaRequestWrapper s3MetaRequest = s3MetaRequest();
if (s3MetaRequest == null) {
return;
}

if (t != null) {
s3MetaRequest.cancel();
}
s3MetaRequest.close();
});

this.responseHandler = responseHandler;
this.progressListener = progressListener == null ? new NoOpPublisherListener() : progressListener;
}

private S3MetaRequestWrapper s3MetaRequest() {
if (!metaRequestFuture.isDone()) {
return null;
}

S3MetaRequestWrapper s3MetaRequest = metaRequestFuture.join();
return s3MetaRequest;
}

@Override
public void onResponseHeaders(int statusCode, HttpHeader[] headers) {
// Note, we cannot call responseHandler.onHeaders() here because the response status code and headers may not represent
Expand Down Expand Up @@ -87,7 +109,13 @@ public int onResponseBody(ByteBuffer bodyBytesIn, long objectRangeStart, long ob
return;
}

metaRequest.incrementReadWindow(bytesReceived);
if (s3MetaRequest() == null) {
// should not happen
failResponseHandlerAndFuture(SdkClientException.create("Unexpected exception occurred: s3metaRequest is not "
+ "initialized yet"));
return;
}
s3MetaRequest().incrementReadWindow(bytesReceived);
zoewangg marked this conversation as resolved.
Show resolved Hide resolved
});

// Returning 0 to disable flow control because we manually increase read window above
Expand Down Expand Up @@ -115,22 +143,10 @@ private void onSuccessfulResponseComplete() {
return;
}
this.progressListener.subscriberOnComplete();
completeFutureAndCloseRequest();
resultFuture.complete(null);
});
}

private void completeFutureAndCloseRequest() {
resultFuture.complete(null);
runAndLogError(log.logger(), "Exception thrown in S3MetaRequest#close, ignoring",
() -> metaRequest.close());
}

public void cancelRequest() {
SdkCancellationException sdkClientException =
new SdkCancellationException("request is cancelled");
failResponseHandlerAndFuture(sdkClientException);
}

private void handleError(S3FinishedResponseContext context) {
int crtCode = context.getErrorCode();
HttpHeader[] headers = context.getErrorHeaders();
Expand Down Expand Up @@ -168,27 +184,21 @@ private void onErrorResponseComplete(byte[] errorPayload) {
failResponseHandlerAndFuture(throwable);
return null;
}
completeFutureAndCloseRequest();
resultFuture.complete(null);
return null;
});
}

private void failResponseHandlerAndFuture(Throwable exception) {
resultFuture.completeExceptionally(exception);
runAndLogError(log.logger(), "Exception thrown in SdkAsyncHttpResponseHandler#onError, ignoring",
() -> responseHandler.onError(exception));
runAndLogError(log.logger(), "Exception thrown in S3MetaRequest#close, ignoring",
() -> metaRequest.close());
resultFuture.completeExceptionally(exception);
}

private static boolean isErrorResponse(int responseStatus) {
return responseStatus != 0;
}

public void metaRequest(S3MetaRequest s3MetaRequest) {
metaRequest = s3MetaRequest;
}

@Override
public void onProgress(S3MetaRequestProgress progress) {
this.progressListener.subscriberOnNext(progress);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,24 @@
import java.util.function.Function;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.crt.s3.ResumeToken;
import software.amazon.awssdk.crt.s3.S3MetaRequest;

/**
* An observable that notifies the observer {@link S3CrtAsyncHttpClient} to pause the request.
*/
@SdkInternalApi
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class should've been a protected API. I will create a separate PR to fix it.

public class S3MetaRequestPauseObservable {

private final Function<S3MetaRequest, ResumeToken> pause;
private volatile S3MetaRequest request;
private final Function<S3MetaRequestWrapper, ResumeToken> pause;
private volatile S3MetaRequestWrapper request;

public S3MetaRequestPauseObservable() {
this.pause = S3MetaRequest::pause;
this.pause = S3MetaRequestWrapper::pause;
}

/**
* Subscribe {@link S3MetaRequest} to be potentially paused later.
* Subscribe {@link S3MetaRequestWrapper} to be potentially paused later.
*/
public void subscribe(S3MetaRequest request) {
public void subscribe(S3MetaRequestWrapper request) {
this.request = request;
}

Expand Down
Loading
Loading