Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable CRC32 for PUT for MultipartS3AsyncClient #4898

Merged
merged 7 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,31 @@
import java.nio.ByteBuffer;
import java.security.SecureRandom;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;
import javax.crypto.KeyGenerator;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import software.amazon.awssdk.core.ClientType;
import software.amazon.awssdk.core.ResponseBytes;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.core.sync.ResponseTransformer;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3IntegrationTestBase;
import software.amazon.awssdk.services.s3.internal.crt.S3CrtAsyncClient;
import software.amazon.awssdk.services.s3.internal.multipart.MultipartS3AsyncClient;
import software.amazon.awssdk.services.s3.model.CopyObjectResponse;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.model.MetadataDirective;
Expand All @@ -49,6 +57,7 @@
@Timeout(value = 3, unit = TimeUnit.MINUTES)
public class S3ClientMultiPartCopyIntegrationTest extends S3IntegrationTestBase {
private static final String BUCKET = temporaryBucketName(S3ClientMultiPartCopyIntegrationTest.class);
private static final CapturingInterceptor CAPTURING_INTERCEPTOR = new CapturingInterceptor();
private static final String ORIGINAL_OBJ = "test_file.dat";
private static final String COPIED_OBJ = "test_file_copy.dat";
private static final String ORIGINAL_OBJ_SPECIAL_CHARACTER = "original-special-chars-@$%";
Expand All @@ -70,7 +79,8 @@ public static void setUp() throws Exception {
.region(DEFAULT_REGION)
.credentialsProvider(CREDENTIALS_PROVIDER_CHAIN)
.overrideConfiguration(o -> o.addExecutionInterceptor(
new UserAgentVerifyingExecutionInterceptor("NettyNio", ClientType.ASYNC)))
new UserAgentVerifyingExecutionInterceptor("NettyNio", ClientType.ASYNC))
.addExecutionInterceptor(CAPTURING_INTERCEPTOR))
.multipartEnabled(true)
.build();
}
Expand All @@ -82,6 +92,11 @@ public static void teardown() throws Exception {
deleteBucketAndAllContents(BUCKET);
}

/**
 * Clears the checksum header recorded by {@code CAPTURING_INTERCEPTOR} before each test, so a value captured by an
 * earlier test cannot satisfy a later assertion.
 */
@BeforeEach
public void reset() {
CAPTURING_INTERCEPTOR.reset();
}

/**
 * Supplies the async clients under test: {@code s3MpuClient} followed by {@code s3CrtAsyncClient}.
 * NOTE(review): presumably referenced by name from {@code @MethodSource} on the parameterized tests; confirm.
 *
 * @return a stream of the two clients, in that order
 */
public static Stream<S3AsyncClient> s3AsyncClient() {
return Stream.of(s3MpuClient, s3CrtAsyncClient);
}
Expand Down Expand Up @@ -132,6 +147,8 @@ void copy_ssecServerSideEncryption_shouldSucceed(S3AsyncClient s3AsyncClient) {
.sseCustomerKeyMD5(b64KeyMd5),
AsyncRequestBody.fromBytes(originalContent)).join();

CAPTURING_INTERCEPTOR.reset();

CompletableFuture<CopyObjectResponse> future = s3AsyncClient.copyObject(c -> c
.sourceBucket(BUCKET)
.sourceKey(ORIGINAL_OBJ)
Expand All @@ -148,6 +165,7 @@ void copy_ssecServerSideEncryption_shouldSucceed(S3AsyncClient s3AsyncClient) {
CopyObjectResponse copyObjectResponse = future.join();
assertThat(copyObjectResponse.responseMetadata().requestId()).isNotNull();
assertThat(copyObjectResponse.sdkHttpResponse()).isNotNull();
verifyCopyContainsCrc32Header(s3AsyncClient);
}

private static byte[] generateSecretKey() {
Expand All @@ -166,6 +184,8 @@ private void createOriginalObject(byte[] originalContent, String originalKey) {
s3CrtAsyncClient.putObject(r -> r.bucket(BUCKET)
.key(originalKey),
AsyncRequestBody.fromBytes(originalContent)).join();

CAPTURING_INTERCEPTOR.reset();
davidh44 marked this conversation as resolved.
Show resolved Hide resolved
}

private void copyObject(String original, String destination, S3AsyncClient s3AsyncClient) {
Expand All @@ -178,6 +198,13 @@ private void copyObject(String original, String destination, S3AsyncClient s3Asy
CopyObjectResponse copyObjectResponse = future.join();
assertThat(copyObjectResponse.responseMetadata().requestId()).isNotNull();
assertThat(copyObjectResponse.sdkHttpResponse()).isNotNull();
verifyCopyContainsCrc32Header(s3AsyncClient);
}

/**
 * Asserts that the last checksum-algorithm header captured by {@code CAPTURING_INTERCEPTOR} is {@code "CRC32"}.
 * The check is only performed when the supplied client is a {@link MultipartS3AsyncClient}; for any other client
 * this method does nothing.
 *
 * @param s3AsyncClient the client that executed the copy
 */
private void verifyCopyContainsCrc32Header(S3AsyncClient s3AsyncClient) {
    boolean isMultipartClient = s3AsyncClient instanceof MultipartS3AsyncClient;
    if (!isMultipartClient) {
        return;
    }
    assertThat(CAPTURING_INTERCEPTOR.checksumHeader).isEqualTo("CRC32");
}

private void validateCopiedObject(byte[] originalContent, String originalKey) {
Expand All @@ -192,4 +219,21 @@ public static byte[] randomBytes(long size) {
ThreadLocalRandom.current().nextBytes(bytes);
return bytes;
}

private static final class CapturingInterceptor implements ExecutionInterceptor {
private String checksumHeader;
@Override
davidh44 marked this conversation as resolved.
Show resolved Hide resolved
public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) {
SdkHttpRequest sdkHttpRequest = context.httpRequest();
Map<String, List<String>> headers = sdkHttpRequest.headers();
String headerName1 = "x-amz-checksum-algorithm";
davidh44 marked this conversation as resolved.
Show resolved Hide resolved
if (headers.containsKey(headerName1)) {
checksumHeader = headers.get(headerName1).get(0);
}
}

public void reset() {
checksumHeader = null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,51 @@

import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;
import static software.amazon.awssdk.services.s3.model.ServerSideEncryption.AES256;
import static software.amazon.awssdk.testutils.service.S3BucketUtils.temporaryBucketName;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.security.SecureRandom;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import javax.crypto.KeyGenerator;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.reactivestreams.Subscriber;
import software.amazon.awssdk.core.ClientType;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.core.internal.async.FileAsyncRequestBody;
import software.amazon.awssdk.core.sync.ResponseTransformer;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3IntegrationTestBase;
import software.amazon.awssdk.services.s3.model.ChecksumAlgorithm;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.utils.ChecksumUtils;
import software.amazon.awssdk.utils.Md5Utils;

@Timeout(value = 30, unit = SECONDS)
public class S3MultipartClientPutObjectIntegrationTest extends S3IntegrationTestBase {

private static final String TEST_BUCKET = temporaryBucketName(S3MultipartClientPutObjectIntegrationTest.class);
private static final String TEST_KEY = "testfile.dat";
private static final int OBJ_SIZE = 19 * 1024 * 1024;

private static final CapturingInterceptor CAPTURING_INTERCEPTOR = new CapturingInterceptor();
private static File testFile;
private static S3AsyncClient mpuS3Client;

Expand All @@ -65,8 +78,8 @@ public static void setup() throws Exception {
.builder()
.region(DEFAULT_REGION)
.credentialsProvider(CREDENTIALS_PROVIDER_CHAIN)
.overrideConfiguration(o -> o.addExecutionInterceptor(
new UserAgentVerifyingExecutionInterceptor("NettyNio", ClientType.ASYNC)))
.overrideConfiguration(o -> o.addExecutionInterceptor(new UserAgentVerifyingExecutionInterceptor("NettyNio", ClientType.ASYNC))
.addExecutionInterceptor(CAPTURING_INTERCEPTOR))
.multipartEnabled(true)
.build();
}
Expand All @@ -78,11 +91,18 @@ public static void teardown() throws Exception {
deleteBucketAndAllContents(TEST_BUCKET);
}

/**
 * Clears the checksum header recorded by {@code CAPTURING_INTERCEPTOR} before each test, so a value captured by an
 * earlier test cannot satisfy a later assertion.
 */
@BeforeEach
public void reset() {
CAPTURING_INTERCEPTOR.reset();
}

@Test
void putObject_fileRequestBody_objectSentCorrectly() throws Exception {
AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath());
mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join();

assertThat(CAPTURING_INTERCEPTOR.checksumHeader).isEqualTo("CRC32");

ResponseInputStream<GetObjectResponse> objContent =
S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY),
ResponseTransformer.toInputStream());
Expand All @@ -98,6 +118,8 @@ void putObject_byteAsyncRequestBody_objectSentCorrectly() throws Exception {
AsyncRequestBody body = AsyncRequestBody.fromBytes(bytes);
mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join();

assertThat(CAPTURING_INTERCEPTOR.checksumHeader).isEqualTo("CRC32");

ResponseInputStream<GetObjectResponse> objContent =
S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY),
ResponseTransformer.toInputStream());
Expand All @@ -124,6 +146,8 @@ public void subscribe(Subscriber<? super ByteBuffer> s) {
}
}).get(30, SECONDS);

assertThat(CAPTURING_INTERCEPTOR.checksumHeader).isEqualTo("CRC32");

ResponseInputStream<GetObjectResponse> objContent =
S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY),
ResponseTransformer.toInputStream());
Expand All @@ -133,4 +157,63 @@ public void subscribe(Subscriber<? super ByteBuffer> s) {
assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum);
}

/**
 * Uploads {@code testFile} through the multipart client with SSE-C parameters and an explicit CRC32 checksum
 * algorithm, asserts that the interceptor captured {@code "CRC32"}, then downloads the object with the same SSE-C
 * key and compares its content length and checksum against the local file.
 *
 * @throws Exception if reading the local file or the downloaded content fails
 */
@Test
void putObject_withSSECAndChecksum_objectSentCorrectly() throws Exception {
    byte[] secretKey = generateSecretKey();
    String b64Key = Base64.getEncoder().encodeToString(secretKey);
    String b64KeyMd5 = Md5Utils.md5AsBase64(secretKey);

    AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath());
    mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET)
                                .key(TEST_KEY)
                                .sseCustomerKey(b64Key)
                                .sseCustomerAlgorithm(AES256.name())
                                .sseCustomerKeyMD5(b64KeyMd5)
                                .checksumAlgorithm(ChecksumAlgorithm.CRC32),
                          body).join();

    assertThat(CAPTURING_INTERCEPTOR.checksumHeader).isEqualTo("CRC32");

    // Both streams are closed deterministically so the HTTP connection and the file handle are released even
    // when an assertion fails part-way through.
    try (ResponseInputStream<GetObjectResponse> objContent =
             S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET)
                                                      .key(TEST_KEY)
                                                      .sseCustomerKey(b64Key)
                                                      .sseCustomerAlgorithm(AES256.name())
                                                      .sseCustomerKeyMD5(b64KeyMd5),
                                                ResponseTransformer.toInputStream());
         java.io.InputStream localContent = Files.newInputStream(testFile.toPath())) {

        assertThat(objContent.response().contentLength()).isEqualTo(testFile.length());
        byte[] expectedSum = ChecksumUtils.computeCheckSum(localContent);
        assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum);
    }
}

/**
 * Generates a fresh 256-bit AES key using a new {@link SecureRandom}.
 *
 * @return the encoded key bytes
 * @throws IllegalStateException if no AES {@link KeyGenerator} is available; the failure is surfaced here instead
 *                               of returning {@code null} and failing later with an unrelated NullPointerException
 */
private static byte[] generateSecretKey() {
    try {
        KeyGenerator generator = KeyGenerator.getInstance("AES");
        generator.init(256, new SecureRandom());
        return generator.generateKey().getEncoded();
    } catch (java.security.NoSuchAlgorithmException e) {
        throw new IllegalStateException("Unable to create an AES key generator", e);
    }
}

/**
 * Execution interceptor that remembers the first value of the {@code x-amz-sdk-checksum-algorithm} header from the
 * most recent outgoing request that carried that header. Requests without the header leave the remembered value
 * untouched.
 */
private static final class CapturingInterceptor implements ExecutionInterceptor {
    private static final String CHECKSUM_ALGORITHM_HEADER = "x-amz-sdk-checksum-algorithm";

    private String checksumHeader;

    @Override
    public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) {
        SdkHttpRequest sdkHttpRequest = context.httpRequest();
        Map<String, List<String>> headers = sdkHttpRequest.headers();
        // No console output here: dumping every request's headers only adds noise to the test log.
        if (headers.containsKey(CHECKSUM_ALGORITHM_HEADER)) {
            checksumHeader = headers.get(CHECKSUM_ALGORITHM_HEADER).get(0);
        }
    }

    /**
     * Forgets any previously captured header value.
     */
    public void reset() {
        checksumHeader = null;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,22 @@

package software.amazon.awssdk.services.s3.internal.multipart;

import static software.amazon.awssdk.services.s3.internal.multipart.SdkPojoConversionUtils.toCompleteMultipartUploadRequest;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.atomic.AtomicReferenceArray;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.IntStream;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.exception.SdkException;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest;
import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest;
import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse;
import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload;
import software.amazon.awssdk.services.s3.model.CompletedPart;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.S3Request;
import software.amazon.awssdk.services.s3.model.S3Response;
import software.amazon.awssdk.utils.Logger;
Expand Down Expand Up @@ -81,28 +81,13 @@ public int determinePartCount(long contentLength, long partSize) {
}

public CompletableFuture<CompleteMultipartUploadResponse> completeMultipartUpload(
RequestT request, String uploadId, CompletedPart[] parts) {
PutObjectRequest request, String uploadId, CompletedPart[] parts) {
log.debug(() -> String.format("Sending completeMultipartUploadRequest, uploadId: %s",
uploadId));
CompleteMultipartUploadRequest completeMultipartUploadRequest =
CompleteMultipartUploadRequest.builder()
.bucket(request.getValueForField("Bucket", String.class).get())
.key(request.getValueForField("Key", String.class).get())
.uploadId(uploadId)
.multipartUpload(CompletedMultipartUpload.builder()
.parts(parts)
.build())
.build();
return s3AsyncClient.completeMultipartUpload(completeMultipartUploadRequest);
}

public CompletableFuture<CompleteMultipartUploadResponse> completeMultipartUpload(
RequestT request, String uploadId, AtomicReferenceArray<CompletedPart> completedParts) {
CompletedPart[] parts =
IntStream.range(0, completedParts.length())
.mapToObj(completedParts::get)
.toArray(CompletedPart[]::new);
return completeMultipartUpload(request, uploadId, parts);
CompleteMultipartUploadRequest completeMultipartUploadRequest = toCompleteMultipartUploadRequest(request, uploadId,
parts);
return s3AsyncClient.completeMultipartUpload(completeMultipartUploadRequest);
}

public BiFunction<CompleteMultipartUploadResponse, Throwable, Void> handleExceptionOrResponse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import software.amazon.awssdk.services.s3.DelegatingS3AsyncClient;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.internal.UserAgentUtils;
import software.amazon.awssdk.services.s3.model.ChecksumAlgorithm;
import software.amazon.awssdk.services.s3.model.CopyObjectRequest;
import software.amazon.awssdk.services.s3.model.CopyObjectResponse;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
Expand All @@ -36,8 +37,9 @@
import software.amazon.awssdk.utils.Validate;

/**
* An {@link S3AsyncClient} that automatically converts put, copy requests to their respective multipart call. Note: get is not
* yet supported.
* An {@link S3AsyncClient} that automatically converts PUT, COPY requests to their respective multipart call. CRC32 will be
* enabled for the PUT and COPY requests.
* Note: GET is not yet supported.
*
* @see MultipartConfiguration
*/
Expand All @@ -62,11 +64,13 @@ private MultipartS3AsyncClient(S3AsyncClient delegate, MultipartConfiguration mu

@Override
public CompletableFuture<PutObjectResponse> putObject(PutObjectRequest putObjectRequest, AsyncRequestBody requestBody) {
davidh44 marked this conversation as resolved.
Show resolved Hide resolved
putObjectRequest = putObjectRequest.toBuilder().checksumAlgorithm(ChecksumAlgorithm.CRC32).build();
davidh44 marked this conversation as resolved.
Show resolved Hide resolved
return mpuHelper.uploadObject(putObjectRequest, requestBody);
}

@Override
public CompletableFuture<CopyObjectResponse> copyObject(CopyObjectRequest copyObjectRequest) {
copyObjectRequest = copyObjectRequest.toBuilder().checksumAlgorithm(ChecksumAlgorithm.CRC32).build();
davidh44 marked this conversation as resolved.
Show resolved Hide resolved
return copyObjectHelper.copyObject(copyObjectRequest);
}

Expand Down
Loading
Loading