Skip to content

Commit

Permalink
Enable CRC32 for PUT for MultipartS3AsyncClient (#4898)
Browse files Browse the repository at this point in the history
* Fix bug to allow MpuS3Client to PUT COPY with SSE-C and Checksum

* Enable CRC32 for Multipart PUT COPY

* Address comments

* add changelog and update checksum check

* Address comments

* update javadocs

* Add unit tests
  • Loading branch information
davidh44 authored Feb 8, 2024
1 parent 2f2e6e1 commit 6411744
Show file tree
Hide file tree
Showing 9 changed files with 397 additions and 37 deletions.
6 changes: 6 additions & 0 deletions .changes/next-release/bugfix-AmazonS3-04d48e6.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"category": "Amazon S3",
"contributor": "",
"type": "bugfix",
"description": "Fix bug where PUT fails when using SSE-C with Checksum when using S3AsyncClient with multipart enabled. Enable CRC32 for putObject when using multipart client if checksum validation is not disabled and checksum is not set by user"
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import java.nio.ByteBuffer;
import java.security.SecureRandom;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
Expand All @@ -37,10 +39,16 @@
import software.amazon.awssdk.core.ClientType;
import software.amazon.awssdk.core.ResponseBytes;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.core.sync.ResponseTransformer;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3IntegrationTestBase;
import software.amazon.awssdk.services.s3.internal.crt.S3CrtAsyncClient;
import software.amazon.awssdk.services.s3.internal.multipart.MultipartS3AsyncClient;
import software.amazon.awssdk.services.s3.model.ChecksumAlgorithm;
import software.amazon.awssdk.services.s3.model.CopyObjectResponse;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.model.MetadataDirective;
Expand All @@ -49,6 +57,7 @@
@Timeout(value = 3, unit = TimeUnit.MINUTES)
public class S3ClientMultiPartCopyIntegrationTest extends S3IntegrationTestBase {
private static final String BUCKET = temporaryBucketName(S3ClientMultiPartCopyIntegrationTest.class);
private static final CapturingInterceptor CAPTURING_INTERCEPTOR = new CapturingInterceptor();
private static final String ORIGINAL_OBJ = "test_file.dat";
private static final String COPIED_OBJ = "test_file_copy.dat";
private static final String ORIGINAL_OBJ_SPECIAL_CHARACTER = "original-special-chars-@$%";
Expand All @@ -70,7 +79,8 @@ public static void setUp() throws Exception {
.region(DEFAULT_REGION)
.credentialsProvider(CREDENTIALS_PROVIDER_CHAIN)
.overrideConfiguration(o -> o.addExecutionInterceptor(
new UserAgentVerifyingExecutionInterceptor("NettyNio", ClientType.ASYNC)))
new UserAgentVerifyingExecutionInterceptor("NettyNio", ClientType.ASYNC))
.addExecutionInterceptor(CAPTURING_INTERCEPTOR))
.multipartEnabled(true)
.build();
}
Expand Down Expand Up @@ -115,7 +125,7 @@ void copy_specialCharacters_hasSameContent(S3AsyncClient s3AsyncClient) {

@ParameterizedTest(autoCloseArguments = false)
@MethodSource("s3AsyncClient")
void copy_ssecServerSideEncryption_shouldSucceed(S3AsyncClient s3AsyncClient) {
void copy_withSSECAndChecksum_shouldSucceed(S3AsyncClient s3AsyncClient) {
byte[] originalContent = randomBytes(OBJ_SIZE);
byte[] secretKey = generateSecretKey();
String b64Key = Base64.getEncoder().encodeToString(secretKey);
Expand All @@ -132,6 +142,8 @@ void copy_ssecServerSideEncryption_shouldSucceed(S3AsyncClient s3AsyncClient) {
.sseCustomerKeyMD5(b64KeyMd5),
AsyncRequestBody.fromBytes(originalContent)).join();

CAPTURING_INTERCEPTOR.reset();

CompletableFuture<CopyObjectResponse> future = s3AsyncClient.copyObject(c -> c
.sourceBucket(BUCKET)
.sourceKey(ORIGINAL_OBJ)
Expand All @@ -143,11 +155,13 @@ void copy_ssecServerSideEncryption_shouldSucceed(S3AsyncClient s3AsyncClient) {
.copySourceSSECustomerKey(b64Key)
.copySourceSSECustomerKeyMD5(b64KeyMd5)
.destinationBucket(BUCKET)
.destinationKey(COPIED_OBJ));
.destinationKey(COPIED_OBJ)
.checksumAlgorithm(ChecksumAlgorithm.CRC32));

CopyObjectResponse copyObjectResponse = future.join();
assertThat(copyObjectResponse.responseMetadata().requestId()).isNotNull();
assertThat(copyObjectResponse.sdkHttpResponse()).isNotNull();
verifyCopyContainsCrc32Header(s3AsyncClient);
}

private static byte[] generateSecretKey() {
Expand Down Expand Up @@ -180,6 +194,12 @@ private void copyObject(String original, String destination, S3AsyncClient s3Asy
assertThat(copyObjectResponse.sdkHttpResponse()).isNotNull();
}

/**
 * Asserts that the checksum header captured by {@code CAPTURING_INTERCEPTOR} is {@code "CRC32"}.
 * The check only applies when the client under test is a {@link MultipartS3AsyncClient};
 * for any other client type this method does nothing.
 *
 * @param s3AsyncClient the client that performed the copy
 */
private void verifyCopyContainsCrc32Header(S3AsyncClient s3AsyncClient) {
    if (!(s3AsyncClient instanceof MultipartS3AsyncClient)) {
        return;
    }
    assertThat(CAPTURING_INTERCEPTOR.checksumHeader).isEqualTo("CRC32");
}

private void validateCopiedObject(byte[] originalContent, String originalKey) {
ResponseBytes<GetObjectResponse> copiedObject = s3.getObject(r -> r.bucket(BUCKET)
.key(originalKey),
Expand All @@ -192,4 +212,24 @@ public static byte[] randomBytes(long size) {
ThreadLocalRandom.current().nextBytes(bytes);
return bytes;
}

/**
 * Execution interceptor that remembers the value of the {@code x-amz-checksum-algorithm}
 * header from outgoing requests so tests can assert on it afterwards.
 */
private static final class CapturingInterceptor implements ExecutionInterceptor {
    private String checksumHeader;

    /**
     * Inspects the outgoing HTTP request; when the checksum-algorithm header is present it must
     * carry exactly one value, which is then stored. Requests without the header leave the
     * previously stored value untouched.
     */
    @Override
    public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) {
        Map<String, List<String>> requestHeaders = context.httpRequest().headers();
        if (!requestHeaders.containsKey("x-amz-checksum-algorithm")) {
            return;
        }
        List<String> algorithmValues = requestHeaders.get("x-amz-checksum-algorithm");
        assertThat(algorithmValues).hasSize(1);
        checksumHeader = algorithmValues.get(0);
    }

    /** Forgets the stored header value. */
    public void reset() {
        checksumHeader = null;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,56 +17,68 @@

import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;
import static software.amazon.awssdk.services.s3.model.ServerSideEncryption.AES256;
import static software.amazon.awssdk.testutils.service.S3BucketUtils.temporaryBucketName;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import javax.crypto.KeyGenerator;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.reactivestreams.Subscriber;
import software.amazon.awssdk.core.ClientType;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.core.internal.async.FileAsyncRequestBody;
import software.amazon.awssdk.core.sync.ResponseTransformer;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3IntegrationTestBase;
import software.amazon.awssdk.services.s3.model.ChecksumAlgorithm;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.utils.ChecksumUtils;
import software.amazon.awssdk.utils.Md5Utils;

@Timeout(value = 30, unit = SECONDS)
public class S3MultipartClientPutObjectIntegrationTest extends S3IntegrationTestBase {

private static final String TEST_BUCKET = temporaryBucketName(S3MultipartClientPutObjectIntegrationTest.class);
private static final String TEST_KEY = "testfile.dat";
private static final int OBJ_SIZE = 19 * 1024 * 1024;

private static final CapturingInterceptor CAPTURING_INTERCEPTOR = new CapturingInterceptor();
private static final byte[] CONTENT = RandomStringUtils.randomAscii(OBJ_SIZE).getBytes(Charset.defaultCharset());
private static File testFile;
private static S3AsyncClient mpuS3Client;

@BeforeAll
public static void setup() throws Exception {
S3IntegrationTestBase.setUp();
S3IntegrationTestBase.createBucket(TEST_BUCKET);
byte[] CONTENT =
RandomStringUtils.randomAscii(OBJ_SIZE).getBytes(Charset.defaultCharset());

testFile = File.createTempFile("SplittingPublisherTest", UUID.randomUUID().toString());
Files.write(testFile.toPath(), CONTENT);
mpuS3Client = S3AsyncClient
.builder()
.region(DEFAULT_REGION)
.credentialsProvider(CREDENTIALS_PROVIDER_CHAIN)
.overrideConfiguration(o -> o.addExecutionInterceptor(
new UserAgentVerifyingExecutionInterceptor("NettyNio", ClientType.ASYNC)))
.overrideConfiguration(o -> o.addExecutionInterceptor(new UserAgentVerifyingExecutionInterceptor("NettyNio", ClientType.ASYNC))
.addExecutionInterceptor(CAPTURING_INTERCEPTOR))
.multipartEnabled(true)
.build();
}
Expand All @@ -78,11 +90,18 @@ public static void teardown() throws Exception {
deleteBucketAndAllContents(TEST_BUCKET);
}

/**
 * Resets the shared {@code CAPTURING_INTERCEPTOR} before every test so that a value captured
 * by one test is not observed by the next.
 */
@BeforeEach
public void reset() {
CAPTURING_INTERCEPTOR.reset();
}

@Test
void putObject_fileRequestBody_objectSentCorrectly() throws Exception {
AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath());
mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join();

assertThat(CAPTURING_INTERCEPTOR.checksumHeader).isEqualTo("CRC32");

ResponseInputStream<GetObjectResponse> objContent =
S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY),
ResponseTransformer.toInputStream());
Expand All @@ -98,6 +117,8 @@ void putObject_byteAsyncRequestBody_objectSentCorrectly() throws Exception {
AsyncRequestBody body = AsyncRequestBody.fromBytes(bytes);
mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join();

assertThat(CAPTURING_INTERCEPTOR.checksumHeader).isEqualTo("CRC32");

ResponseInputStream<GetObjectResponse> objContent =
S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY),
ResponseTransformer.toInputStream());
Expand All @@ -124,6 +145,8 @@ public void subscribe(Subscriber<? super ByteBuffer> s) {
}
}).get(30, SECONDS);

assertThat(CAPTURING_INTERCEPTOR.checksumHeader).isEqualTo("CRC32");

ResponseInputStream<GetObjectResponse> objContent =
S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY),
ResponseTransformer.toInputStream());
Expand All @@ -133,4 +156,102 @@ public void subscribe(Subscriber<? super ByteBuffer> s) {
assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum);
}

/**
 * Uploads the test file through the multipart client with SSE-C parameters and without an
 * explicit checksum setting, expects the captured checksum algorithm header to be
 * {@code CRC32}, then downloads the object with the same SSE-C parameters and compares its
 * length and checksum against the local file.
 */
@Test
void putObject_withSSECAndChecksum_objectSentCorrectly() throws Exception {
    byte[] secretKey = generateSecretKey();
    String b64Key = Base64.getEncoder().encodeToString(secretKey);
    String b64KeyMd5 = Md5Utils.md5AsBase64(secretKey);

    AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath());
    mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET)
                                .key(TEST_KEY)
                                .sseCustomerKey(b64Key)
                                .sseCustomerAlgorithm(AES256.name())
                                .sseCustomerKeyMD5(b64KeyMd5),
                          body).join();

    assertThat(CAPTURING_INTERCEPTOR.checksumHeader).isEqualTo("CRC32");

    // Both streams are closed once the comparison is done instead of being left open.
    try (ResponseInputStream<GetObjectResponse> objContent =
             S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET)
                                                      .key(TEST_KEY)
                                                      .sseCustomerKey(b64Key)
                                                      .sseCustomerAlgorithm(AES256.name())
                                                      .sseCustomerKeyMD5(b64KeyMd5),
                                                ResponseTransformer.toInputStream());
         InputStream expectedContent = Files.newInputStream(testFile.toPath())) {

        assertThat(objContent.response().contentLength()).isEqualTo(testFile.length());
        byte[] expectedSum = ChecksumUtils.computeCheckSum(expectedContent);
        assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum);
    }
}

/**
 * Uploads the test file with a caller-supplied SHA-1 checksum value and expects that value to
 * be present in the captured {@code x-amz-checksum-sha1} header while no checksum algorithm
 * header is captured. The object is then downloaded and compared with the local file by
 * length and checksum.
 */
@Test
void putObject_withUserSpecifiedChecksumValue_objectSentCorrectly() throws Exception {
    String sha1Val = calculateSHA1AsString();
    AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath());
    mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET)
                                .key(TEST_KEY)
                                .checksumSHA1(sha1Val),
                          body).join();

    assertThat(CAPTURING_INTERCEPTOR.headers.get("x-amz-checksum-sha1")).contains(sha1Val);
    assertThat(CAPTURING_INTERCEPTOR.checksumHeader).isNull();

    // Both streams are closed once the comparison is done instead of being left open.
    try (ResponseInputStream<GetObjectResponse> objContent =
             S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY),
                                                ResponseTransformer.toInputStream());
         InputStream expectedContent = Files.newInputStream(testFile.toPath())) {

        assertThat(objContent.response().contentLength()).isEqualTo(testFile.length());
        byte[] expectedSum = ChecksumUtils.computeCheckSum(expectedContent);
        assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum);
    }
}

/**
 * Uploads the test file while requesting the SHA1 checksum algorithm and expects the captured
 * checksum algorithm header to be {@code SHA1}.
 */
@Test
void putObject_withUserSpecifiedChecksumTypeOtherThanCrc32_shouldHonorChecksum() {
    AsyncRequestBody fileBody = AsyncRequestBody.fromFile(testFile.toPath());

    mpuS3Client.putObject(request -> request.bucket(TEST_BUCKET)
                                            .key(TEST_KEY)
                                            .checksumAlgorithm(ChecksumAlgorithm.SHA1),
                          fileBody)
               .join();

    String capturedAlgorithm = CAPTURING_INTERCEPTOR.checksumHeader;
    assertThat(capturedAlgorithm).isEqualTo("SHA1");
}

/**
 * Computes the SHA-1 digest of {@code CONTENT} and returns it Base64-encoded.
 *
 * @return the Base64 text of the SHA-1 digest
 * @throws Exception if the SHA-1 {@link MessageDigest} cannot be obtained
 */
private static String calculateSHA1AsString() throws Exception {
    byte[] sha1Digest = MessageDigest.getInstance("SHA-1").digest(CONTENT);
    return Base64.getEncoder().encodeToString(sha1Digest);
}

/**
 * Generates a fresh AES key using a 256-bit key size and a new {@link SecureRandom}.
 *
 * @return the encoded bytes of the generated key
 * @throws IllegalStateException if no {@code AES} key generator is available
 */
private static byte[] generateSecretKey() {
    try {
        KeyGenerator generator = KeyGenerator.getInstance("AES");
        generator.init(256, new SecureRandom());
        return generator.generateKey().getEncoded();
    } catch (NoSuchAlgorithmException e) {
        // Fail here with the real cause rather than returning null and letting the caller
        // hit an unrelated NullPointerException when it encodes the key.
        throw new IllegalStateException("Unable to obtain an AES key generator", e);
    }
}

/**
 * Execution interceptor that records the headers of the most recently transmitted request and
 * the value of its {@code x-amz-sdk-checksum-algorithm} header, so tests can assert on them.
 */
private static final class CapturingInterceptor implements ExecutionInterceptor {
    String checksumHeader;
    Map<String, List<String>> headers;

    /**
     * Stores the headers of the outgoing request. When the checksum algorithm header is
     * present it must carry exactly one value, which is stored as well; requests without the
     * header leave the previously stored algorithm untouched.
     */
    @Override
    public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) {
        SdkHttpRequest sdkHttpRequest = context.httpRequest();
        headers = sdkHttpRequest.headers();
        String checksumHeaderName = "x-amz-sdk-checksum-algorithm";
        if (headers.containsKey(checksumHeaderName)) {
            List<String> checksumHeaderVals = headers.get(checksumHeaderName);
            assertThat(checksumHeaderVals).hasSize(1);
            checksumHeader = checksumHeaderVals.get(0);
        }
    }

    /**
     * Clears all captured state. The headers are cleared along with the checksum algorithm so
     * that an assertion on headers cannot be satisfied by a request from an earlier test.
     */
    public void reset() {
        checksumHeader = null;
        headers = null;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,22 @@

package software.amazon.awssdk.services.s3.internal.multipart;

import static software.amazon.awssdk.services.s3.internal.multipart.SdkPojoConversionUtils.toCompleteMultipartUploadRequest;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.atomic.AtomicReferenceArray;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.IntStream;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.exception.SdkException;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest;
import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest;
import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse;
import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload;
import software.amazon.awssdk.services.s3.model.CompletedPart;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.S3Request;
import software.amazon.awssdk.services.s3.model.S3Response;
import software.amazon.awssdk.utils.Logger;
Expand Down Expand Up @@ -81,28 +81,13 @@ public int determinePartCount(long contentLength, long partSize) {
}

public CompletableFuture<CompleteMultipartUploadResponse> completeMultipartUpload(
RequestT request, String uploadId, CompletedPart[] parts) {
PutObjectRequest request, String uploadId, CompletedPart[] parts) {
log.debug(() -> String.format("Sending completeMultipartUploadRequest, uploadId: %s",
uploadId));
CompleteMultipartUploadRequest completeMultipartUploadRequest =
CompleteMultipartUploadRequest.builder()
.bucket(request.getValueForField("Bucket", String.class).get())
.key(request.getValueForField("Key", String.class).get())
.uploadId(uploadId)
.multipartUpload(CompletedMultipartUpload.builder()
.parts(parts)
.build())
.build();
return s3AsyncClient.completeMultipartUpload(completeMultipartUploadRequest);
}

public CompletableFuture<CompleteMultipartUploadResponse> completeMultipartUpload(
RequestT request, String uploadId, AtomicReferenceArray<CompletedPart> completedParts) {
CompletedPart[] parts =
IntStream.range(0, completedParts.length())
.mapToObj(completedParts::get)
.toArray(CompletedPart[]::new);
return completeMultipartUpload(request, uploadId, parts);
CompleteMultipartUploadRequest completeMultipartUploadRequest = toCompleteMultipartUploadRequest(request, uploadId,
parts);
return s3AsyncClient.completeMultipartUpload(completeMultipartUploadRequest);
}

public BiFunction<CompleteMultipartUploadResponse, Throwable, Void> handleExceptionOrResponse(
Expand Down
Loading

0 comments on commit 6411744

Please sign in to comment.