From 6411744f8e3e9c3b11a1383921ad4b8e4f822aee Mon Sep 17 00:00:00 2001 From: David Ho <70000000+davidh44@users.noreply.github.com> Date: Thu, 8 Feb 2024 15:56:22 -0800 Subject: [PATCH] Enable CRC32 for PUT for MultipartS3AsyncClient (#4898) * Fix bug to allow MpuS3Client to PUT COPY with SSE-C and Checksum * Enable CRC32 for Multipart PUT COPY * Address comments * add changelog and update checksum check * Address comments * update javadocs * Add unit tests --- .../next-release/bugfix-AmazonS3-04d48e6.json | 6 + .../S3ClientMultiPartCopyIntegrationTest.java | 46 +++++- ...ltipartClientPutObjectIntegrationTest.java | 133 +++++++++++++++- .../multipart/GenericMultipartHelper.java | 29 +--- .../multipart/MultipartS3AsyncClient.java | 44 +++++- .../multipart/SdkPojoConversionUtils.java | 10 +- .../codegen-resources/customization.config | 2 +- .../MultipartClientChecksumTest.java | 145 ++++++++++++++++++ .../multipart/SdkPojoConversionUtilsTest.java | 19 ++- 9 files changed, 397 insertions(+), 37 deletions(-) create mode 100644 .changes/next-release/bugfix-AmazonS3-04d48e6.json create mode 100644 services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartClientChecksumTest.java diff --git a/.changes/next-release/bugfix-AmazonS3-04d48e6.json b/.changes/next-release/bugfix-AmazonS3-04d48e6.json new file mode 100644 index 000000000000..ebb989119c2d --- /dev/null +++ b/.changes/next-release/bugfix-AmazonS3-04d48e6.json @@ -0,0 +1,6 @@ +{ + "category": "Amazon S3", + "contributor": "", + "type": "bugfix", + "description": "Fix bug where PUT fails when using SSE-C with Checksum when using S3AsyncClient with multipart enabled. Enable CRC32 for putObject when using multipart client if checksum validation is not disabled and checksum is not set by user" +} diff --git a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3ClientMultiPartCopyIntegrationTest.java b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3ClientMultiPartCopyIntegrationTest.java index fc4f31b76b1a..4d942d942e7f 100644 --- a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3ClientMultiPartCopyIntegrationTest.java +++ b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3ClientMultiPartCopyIntegrationTest.java @@ -24,6 +24,8 @@ import java.nio.ByteBuffer; import java.security.SecureRandom; import java.util.Base64; +import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; @@ -37,10 +39,16 @@ import software.amazon.awssdk.core.ClientType; import software.amazon.awssdk.core.ResponseBytes; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.interceptor.Context; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; import software.amazon.awssdk.core.sync.ResponseTransformer; +import software.amazon.awssdk.http.SdkHttpRequest; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3IntegrationTestBase; import software.amazon.awssdk.services.s3.internal.crt.S3CrtAsyncClient; +import software.amazon.awssdk.services.s3.internal.multipart.MultipartS3AsyncClient; +import software.amazon.awssdk.services.s3.model.ChecksumAlgorithm; import software.amazon.awssdk.services.s3.model.CopyObjectResponse; import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.s3.model.MetadataDirective; @@ -49,6 +57,7 @@ @Timeout(value = 3, unit = TimeUnit.MINUTES) public class S3ClientMultiPartCopyIntegrationTest extends S3IntegrationTestBase { private static final String BUCKET = temporaryBucketName(S3ClientMultiPartCopyIntegrationTest.class); + private static final CapturingInterceptor CAPTURING_INTERCEPTOR = new CapturingInterceptor(); private static final String ORIGINAL_OBJ = "test_file.dat"; private static final String COPIED_OBJ = "test_file_copy.dat"; private static final String ORIGINAL_OBJ_SPECIAL_CHARACTER = "original-special-chars-@$%"; @@ -70,7 +79,8 @@ public static void setUp() throws Exception { .region(DEFAULT_REGION) .credentialsProvider(CREDENTIALS_PROVIDER_CHAIN) .overrideConfiguration(o -> o.addExecutionInterceptor( - new UserAgentVerifyingExecutionInterceptor("NettyNio", ClientType.ASYNC))) + new UserAgentVerifyingExecutionInterceptor("NettyNio", ClientType.ASYNC)) + .addExecutionInterceptor(CAPTURING_INTERCEPTOR)) .multipartEnabled(true) .build(); } @@ -115,7 +125,7 @@ void copy_specialCharacters_hasSameContent(S3AsyncClient s3AsyncClient) { @ParameterizedTest(autoCloseArguments = false) @MethodSource("s3AsyncClient") - void copy_ssecServerSideEncryption_shouldSucceed(S3AsyncClient s3AsyncClient) { + void copy_withSSECAndChecksum_shouldSucceed(S3AsyncClient s3AsyncClient) { byte[] originalContent = randomBytes(OBJ_SIZE); byte[] secretKey = generateSecretKey(); String b64Key = Base64.getEncoder().encodeToString(secretKey); @@ -132,6 +142,8 @@ void copy_ssecServerSideEncryption_shouldSucceed(S3AsyncClient s3AsyncClient) { .sseCustomerKeyMD5(b64KeyMd5), AsyncRequestBody.fromBytes(originalContent)).join(); + CAPTURING_INTERCEPTOR.reset(); + CompletableFuture future = s3AsyncClient.copyObject(c -> c .sourceBucket(BUCKET) .sourceKey(ORIGINAL_OBJ) @@ -143,11 +155,13 @@ void copy_ssecServerSideEncryption_shouldSucceed(S3AsyncClient s3AsyncClient) { .copySourceSSECustomerKey(b64Key) .copySourceSSECustomerKeyMD5(b64KeyMd5) .destinationBucket(BUCKET) - .destinationKey(COPIED_OBJ)); + .destinationKey(COPIED_OBJ) + .checksumAlgorithm(ChecksumAlgorithm.CRC32)); CopyObjectResponse copyObjectResponse = future.join(); assertThat(copyObjectResponse.responseMetadata().requestId()).isNotNull(); assertThat(copyObjectResponse.sdkHttpResponse()).isNotNull(); + verifyCopyContainsCrc32Header(s3AsyncClient); } private static byte[] generateSecretKey() { @@ -180,6 +194,12 @@ private void copyObject(String original, String destination, S3AsyncClient s3Asy assertThat(copyObjectResponse.sdkHttpResponse()).isNotNull(); } + private void verifyCopyContainsCrc32Header(S3AsyncClient s3AsyncClient) { + if (s3AsyncClient instanceof MultipartS3AsyncClient) { + assertThat(CAPTURING_INTERCEPTOR.checksumHeader).isEqualTo("CRC32"); + } + } + private void validateCopiedObject(byte[] originalContent, String originalKey) { ResponseBytes copiedObject = s3.getObject(r -> r.bucket(BUCKET) .key(originalKey), @@ -192,4 +212,24 @@ public static byte[] randomBytes(long size) { ThreadLocalRandom.current().nextBytes(bytes); return bytes; } + + private static final class CapturingInterceptor implements ExecutionInterceptor { + private String checksumHeader; + + @Override + public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { + SdkHttpRequest sdkHttpRequest = context.httpRequest(); + Map> headers = sdkHttpRequest.headers(); + String checksumHeaderName = "x-amz-checksum-algorithm"; + if (headers.containsKey(checksumHeaderName)) { + List checksumHeaderVals = headers.get(checksumHeaderName); + assertThat(checksumHeaderVals).hasSize(1); + checksumHeader = checksumHeaderVals.get(0); + } + } + + public void reset() { + checksumHeader = null; + } + } } diff --git a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java index fa31b5453e5e..3e6811f69b3c 100644 --- a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java +++ b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java @@ -17,6 +17,7 @@ import static java.util.concurrent.TimeUnit.SECONDS; import static org.assertj.core.api.Assertions.assertThat; +import static software.amazon.awssdk.services.s3.model.ServerSideEncryption.AES256; import static software.amazon.awssdk.testutils.service.S3BucketUtils.temporaryBucketName; import java.io.ByteArrayInputStream; @@ -24,23 +25,36 @@ import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.nio.file.Files; +import java.security.MessageDigest; +import java.security.SecureRandom; +import java.util.Base64; +import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.UUID; +import javax.crypto.KeyGenerator; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.reactivestreams.Subscriber; import software.amazon.awssdk.core.ClientType; import software.amazon.awssdk.core.ResponseInputStream; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.interceptor.Context; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; import software.amazon.awssdk.core.internal.async.FileAsyncRequestBody; import software.amazon.awssdk.core.sync.ResponseTransformer; +import software.amazon.awssdk.http.SdkHttpRequest; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3IntegrationTestBase; +import software.amazon.awssdk.services.s3.model.ChecksumAlgorithm; import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.s3.utils.ChecksumUtils; +import software.amazon.awssdk.utils.Md5Utils; @Timeout(value = 30, unit = SECONDS) public class S3MultipartClientPutObjectIntegrationTest extends S3IntegrationTestBase { @@ -48,7 +62,8 @@ public class S3MultipartClientPutObjectIntegrationTest extends S3IntegrationTest private static final String TEST_BUCKET = temporaryBucketName(S3MultipartClientPutObjectIntegrationTest.class); private static final String TEST_KEY = "testfile.dat"; private static final int OBJ_SIZE = 19 * 1024 * 1024; - + private static final CapturingInterceptor CAPTURING_INTERCEPTOR = new CapturingInterceptor(); + private static final byte[] CONTENT = RandomStringUtils.randomAscii(OBJ_SIZE).getBytes(Charset.defaultCharset()); private static File testFile; private static S3AsyncClient mpuS3Client; @@ -56,17 +71,14 @@ public class S3MultipartClientPutObjectIntegrationTest extends S3IntegrationTest public static void setup() throws Exception { S3IntegrationTestBase.setUp(); S3IntegrationTestBase.createBucket(TEST_BUCKET); - byte[] CONTENT = - RandomStringUtils.randomAscii(OBJ_SIZE).getBytes(Charset.defaultCharset()); - testFile = File.createTempFile("SplittingPublisherTest", UUID.randomUUID().toString()); Files.write(testFile.toPath(), CONTENT); mpuS3Client = S3AsyncClient .builder() .region(DEFAULT_REGION) .credentialsProvider(CREDENTIALS_PROVIDER_CHAIN) - .overrideConfiguration(o -> o.addExecutionInterceptor( - new UserAgentVerifyingExecutionInterceptor("NettyNio", ClientType.ASYNC))) + .overrideConfiguration(o -> o.addExecutionInterceptor(new UserAgentVerifyingExecutionInterceptor("NettyNio", ClientType.ASYNC)) + .addExecutionInterceptor(CAPTURING_INTERCEPTOR)) .multipartEnabled(true) .build(); } @@ -78,11 +90,18 @@ public static void teardown() throws Exception { deleteBucketAndAllContents(TEST_BUCKET); } + @BeforeEach + public void reset() { + CAPTURING_INTERCEPTOR.reset(); + } + @Test void putObject_fileRequestBody_objectSentCorrectly() throws Exception { AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath()); mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join(); + assertThat(CAPTURING_INTERCEPTOR.checksumHeader).isEqualTo("CRC32"); + ResponseInputStream objContent = S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), ResponseTransformer.toInputStream()); @@ -98,6 +117,8 @@ void putObject_byteAsyncRequestBody_objectSentCorrectly() throws Exception { AsyncRequestBody body = AsyncRequestBody.fromBytes(bytes); mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join(); + assertThat(CAPTURING_INTERCEPTOR.checksumHeader).isEqualTo("CRC32"); + ResponseInputStream objContent = S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), ResponseTransformer.toInputStream()); @@ -124,6 +145,8 @@ public void subscribe(Subscriber s) { } }).get(30, SECONDS); + assertThat(CAPTURING_INTERCEPTOR.checksumHeader).isEqualTo("CRC32"); + ResponseInputStream objContent = S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), ResponseTransformer.toInputStream()); @@ -133,4 +156,102 @@ public void subscribe(Subscriber s) { assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum); } + @Test + void putObject_withSSECAndChecksum_objectSentCorrectly() throws Exception { + byte[] secretKey = generateSecretKey(); + String b64Key = Base64.getEncoder().encodeToString(secretKey); + String b64KeyMd5 = Md5Utils.md5AsBase64(secretKey); + + AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath()); + mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET) + .key(TEST_KEY) + .sseCustomerKey(b64Key) + .sseCustomerAlgorithm(AES256.name()) + .sseCustomerKeyMD5(b64KeyMd5), + body).join(); + + assertThat(CAPTURING_INTERCEPTOR.checksumHeader).isEqualTo("CRC32"); + + ResponseInputStream objContent = + S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET) + .key(TEST_KEY) + .sseCustomerKey(b64Key) + .sseCustomerAlgorithm(AES256.name()) + .sseCustomerKeyMD5(b64KeyMd5), + ResponseTransformer.toInputStream()); + + assertThat(objContent.response().contentLength()).isEqualTo(testFile.length()); + byte[] expectedSum = ChecksumUtils.computeCheckSum(Files.newInputStream(testFile.toPath())); + assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum); + } + + @Test + void putObject_withUserSpecifiedChecksumValue_objectSentCorrectly() throws Exception { + String sha1Val = calculateSHA1AsString(); + AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath()); + mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET) + .key(TEST_KEY) + .checksumSHA1(sha1Val), + body).join(); + + assertThat(CAPTURING_INTERCEPTOR.headers.get("x-amz-checksum-sha1")).contains(sha1Val); + assertThat(CAPTURING_INTERCEPTOR.checksumHeader).isNull(); + + ResponseInputStream objContent = + S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), + ResponseTransformer.toInputStream()); + + assertThat(objContent.response().contentLength()).isEqualTo(testFile.length()); + byte[] expectedSum = ChecksumUtils.computeCheckSum(Files.newInputStream(testFile.toPath())); + assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum); + } + + @Test + void putObject_withUserSpecifiedChecksumTypeOtherThanCrc32_shouldHonorChecksum() { + AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath()); + mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET) + .key(TEST_KEY) + .checksumAlgorithm(ChecksumAlgorithm.SHA1), + body).join(); + + assertThat(CAPTURING_INTERCEPTOR.checksumHeader).isEqualTo("SHA1"); + } + + private static String calculateSHA1AsString() throws Exception { + MessageDigest md = MessageDigest.getInstance("SHA-1"); + md.update(CONTENT); + byte[] checksum = md.digest(); + return Base64.getEncoder().encodeToString(checksum); + } + + private static byte[] generateSecretKey() { + KeyGenerator generator; + try { + generator = KeyGenerator.getInstance("AES"); + generator.init(256, new SecureRandom()); + return generator.generateKey().getEncoded(); + } catch (Exception e) { + return null; + } + } + + private static final class CapturingInterceptor implements ExecutionInterceptor { + String checksumHeader; + Map> headers; + @Override + public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { + SdkHttpRequest sdkHttpRequest = context.httpRequest(); + headers = sdkHttpRequest.headers(); + String checksumHeaderName = "x-amz-sdk-checksum-algorithm"; + if (headers.containsKey(checksumHeaderName)) { + List checksumHeaderVals = headers.get(checksumHeaderName); + assertThat(checksumHeaderVals).hasSize(1); + checksumHeader = checksumHeaderVals.get(0); + } + } + + public void reset() { + checksumHeader = null; + } + } } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java index 1d251ad69678..1906408a59b4 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java @@ -15,13 +15,13 @@ package software.amazon.awssdk.services.s3.internal.multipart; +import static software.amazon.awssdk.services.s3.internal.multipart.SdkPojoConversionUtils.toCompleteMultipartUploadRequest; + import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; -import java.util.concurrent.atomic.AtomicReferenceArray; import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; -import java.util.stream.IntStream; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.exception.SdkException; @@ -29,8 +29,8 @@ import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; -import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.S3Request; import software.amazon.awssdk.services.s3.model.S3Response; import software.amazon.awssdk.utils.Logger; @@ -81,28 +81,13 @@ public int determinePartCount(long contentLength, long partSize) { } public CompletableFuture completeMultipartUpload( - RequestT request, String uploadId, CompletedPart[] parts) { + PutObjectRequest request, String uploadId, CompletedPart[] parts) { log.debug(() -> String.format("Sending completeMultipartUploadRequest, uploadId: %s", uploadId)); - CompleteMultipartUploadRequest completeMultipartUploadRequest = - CompleteMultipartUploadRequest.builder() - .bucket(request.getValueForField("Bucket", String.class).get()) - .key(request.getValueForField("Key", String.class).get()) - .uploadId(uploadId) - .multipartUpload(CompletedMultipartUpload.builder() - .parts(parts) - .build()) - .build(); - return s3AsyncClient.completeMultipartUpload(completeMultipartUploadRequest); - } - public CompletableFuture completeMultipartUpload( - RequestT request, String uploadId, AtomicReferenceArray completedParts) { - CompletedPart[] parts = - IntStream.range(0, completedParts.length()) - .mapToObj(completedParts::get) - .toArray(CompletedPart[]::new); - return completeMultipartUpload(request, uploadId, parts); + CompleteMultipartUploadRequest completeMultipartUploadRequest = toCompleteMultipartUploadRequest(request, uploadId, + parts); + return s3AsyncClient.completeMultipartUpload(completeMultipartUploadRequest); } public BiFunction handleExceptionOrResponse( diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java index 8b53099b8683..3de5c114193b 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java @@ -18,13 +18,19 @@ import java.util.concurrent.CompletableFuture; import java.util.function.Function; +import java.util.stream.Stream; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute; import software.amazon.awssdk.core.ApiName; +import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.services.s3.DelegatingS3AsyncClient; import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.S3Configuration; import software.amazon.awssdk.services.s3.internal.UserAgentUtils; +import software.amazon.awssdk.services.s3.model.ChecksumAlgorithm; import software.amazon.awssdk.services.s3.model.CopyObjectRequest; import software.amazon.awssdk.services.s3.model.CopyObjectResponse; import software.amazon.awssdk.services.s3.model.GetObjectRequest; @@ -36,8 +42,9 @@ import software.amazon.awssdk.utils.Validate; /** - * An {@link S3AsyncClient} that automatically converts put, copy requests to their respective multipart call. Note: get is not - * yet supported. + * An {@link S3AsyncClient} that automatically converts PUT, COPY requests to their respective multipart call. CRC32 will be + * enabled for the PUT and COPY requests, unless the the checksum is specified or checksum validation is disabled. + * Note: GET is not yet supported. * * @see MultipartConfiguration */ @@ -62,9 +69,40 @@ private MultipartS3AsyncClient(S3AsyncClient delegate, MultipartConfiguration mu @Override public CompletableFuture putObject(PutObjectRequest putObjectRequest, AsyncRequestBody requestBody) { + if (shouldEnableCrc32(putObjectRequest)) { + putObjectRequest = putObjectRequest.toBuilder().checksumAlgorithm(ChecksumAlgorithm.CRC32).build(); + } + return mpuHelper.uploadObject(putObjectRequest, requestBody); } + private boolean shouldEnableCrc32(PutObjectRequest putObjectRequest) { + return !checksumSetOnRequest(putObjectRequest) && checksumEnabledPerConfig(putObjectRequest); + } + + private boolean checksumSetOnRequest(PutObjectRequest putObjectRequest) { + if (putObjectRequest.checksumAlgorithm() != null) { + return true; + } + + return Stream.of("ChecksumCRC32", "ChecksumCRC32C", "ChecksumSHA1", "ChecksumSHA256") + .anyMatch(s -> putObjectRequest.getValueForField(s, String.class).isPresent()); + } + + private boolean checksumEnabledPerConfig(PutObjectRequest putObjectRequest) { + ExecutionAttributes executionAttributes = + putObjectRequest.overrideConfiguration().map(RequestOverrideConfiguration::executionAttributes).orElse(null); + + if (executionAttributes == null) { + return true; + } + + S3Configuration serviceConfiguration = + (S3Configuration) executionAttributes.getAttribute(AwsSignerExecutionAttribute.SERVICE_CONFIG); + + return serviceConfiguration == null || serviceConfiguration.checksumValidationEnabled(); + } + @Override public CompletableFuture copyObject(CopyObjectRequest copyObjectRequest) { return copyObjectHelper.copyObject(copyObjectRequest); @@ -73,6 +111,8 @@ public CompletableFuture copyObject(CopyObjectRequest copyOb @Override public CompletableFuture getObject( GetObjectRequest getObjectRequest, AsyncResponseTransformer asyncResponseTransformer) { + // TODO uncomment once implemented + // getObjectRequest = getObjectRequest.toBuilder().checksumMode(ChecksumMode.ENABLED).build(); throw new UnsupportedOperationException( "Multipart download is not yet supported. Instead use the CRT based S3 client for multipart download."); } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtils.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtils.java index 25fde18cadaf..b29f176e6fb5 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtils.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtils.java @@ -25,6 +25,7 @@ import software.amazon.awssdk.core.SdkField; import software.amazon.awssdk.core.SdkPojo; import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.CompletedPart; import software.amazon.awssdk.services.s3.model.CopyObjectRequest; @@ -56,9 +57,7 @@ private SdkPojoConversionUtils() { public static UploadPartRequest toUploadPartRequest(PutObjectRequest putObjectRequest, int partNumber, String uploadId) { UploadPartRequest.Builder builder = UploadPartRequest.builder(); - setSdkFields(builder, putObjectRequest, PUT_OBJECT_REQUEST_TO_UPLOAD_PART_FIELDS_TO_IGNORE); - return builder.uploadId(uploadId).partNumber(partNumber).build(); } @@ -69,6 +68,13 @@ public static CreateMultipartUploadRequest toCreateMultipartUploadRequest(PutObj return builder.build(); } + public static CompleteMultipartUploadRequest toCompleteMultipartUploadRequest(PutObjectRequest putObjectRequest, + String uploadId, CompletedPart[] parts) { + CompleteMultipartUploadRequest.Builder builder = CompleteMultipartUploadRequest.builder(); + setSdkFields(builder, putObjectRequest); + return builder.uploadId(uploadId).multipartUpload(c -> c.parts(parts)).build(); + } + public static HeadObjectRequest toHeadObjectRequest(CopyObjectRequest copyObjectRequest) { // We can't set SdkFields directly because the fields in CopyObjectRequest do not match 100% with the ones in diff --git a/services/s3/src/main/resources/codegen-resources/customization.config b/services/s3/src/main/resources/codegen-resources/customization.config index 07c09308a2eb..1baddf167c98 100644 --- a/services/s3/src/main/resources/codegen-resources/customization.config +++ b/services/s3/src/main/resources/codegen-resources/customization.config @@ -240,7 +240,7 @@ "multipartCustomization": { "multipartConfigurationClass": "software.amazon.awssdk.services.s3.multipart.MultipartConfiguration", "multipartConfigMethodDoc": "Configuration for multipart operation of this client.", - "multipartEnableMethodDoc": "Enables automatic conversion of put and copy method to their equivalent multipart operation.", + "multipartEnableMethodDoc": "Enables automatic conversion of PUT and COPY methods to their equivalent multipart operation. CRC32 checksum will be enabled for PUT, unless the checksum is specified or checksum validation is disabled.", "contextParamEnabledKey": "S3AsyncClientDecorator.MULTIPART_ENABLED_KEY", "contextParamConfigKey": "S3AsyncClientDecorator.MULTIPART_CONFIGURATION_KEY" }, diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartClientChecksumTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartClientChecksumTest.java new file mode 100644 index 000000000000..351c1a750a60 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartClientChecksumTest.java @@ -0,0 +1,145 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.net.URI; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.interceptor.Context; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; +import software.amazon.awssdk.http.HttpExecuteResponse; +import software.amazon.awssdk.http.SdkHttpRequest; +import software.amazon.awssdk.http.SdkHttpResponse; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.ChecksumAlgorithm; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.testutils.service.http.MockAsyncHttpClient; + +class MultipartClientChecksumTest { + private MockAsyncHttpClient mockAsyncHttpClient; + private ChecksumCapturingInterceptor checksumCapturingInterceptor; + private S3AsyncClient s3Client; + + @BeforeEach + void init() { + this.mockAsyncHttpClient = new MockAsyncHttpClient(); + this.checksumCapturingInterceptor = new ChecksumCapturingInterceptor(); + s3Client = S3AsyncClient.builder() + .httpClient(mockAsyncHttpClient) + .endpointOverride(URI.create("http://localhost")) + .overrideConfiguration(c -> c.addExecutionInterceptor(checksumCapturingInterceptor)) + .multipartEnabled(true) + .region(Region.US_EAST_1) + .build(); + } + + @AfterEach + void reset() { + this.mockAsyncHttpClient.reset(); + } + + @Test + public void putObject_default_shouldAddCrc32() { + HttpExecuteResponse response = HttpExecuteResponse.builder() + .response(SdkHttpResponse.builder().statusCode(200).build()) + .build(); + mockAsyncHttpClient.stubResponses(response); + + PutObjectRequest putObjectRequest = putObjectRequestBuilder().build(); + + s3Client.putObject(putObjectRequest, AsyncRequestBody.fromString("hello world")); + assertThat(checksumCapturingInterceptor.checksumHeader).isEqualTo("CRC32"); + } + + @Test + public void putObject_withNonCrc32ChecksumType_shouldNotAddCrc32() { + HttpExecuteResponse response = HttpExecuteResponse.builder() + .response(SdkHttpResponse.builder().statusCode(200).build()) + .build(); + mockAsyncHttpClient.stubResponses(response); + + PutObjectRequest putObjectRequest = + putObjectRequestBuilder() + .checksumAlgorithm(ChecksumAlgorithm.SHA256) + .build(); + + s3Client.putObject(putObjectRequest, AsyncRequestBody.fromString("hello world")); + assertThat(checksumCapturingInterceptor.checksumHeader).isEqualTo("SHA256"); + } + + @Test + public void putObject_withNonCrc32ChecksumValue_shouldNotAddCrc32() { + HttpExecuteResponse response = HttpExecuteResponse.builder() + .response(SdkHttpResponse.builder().statusCode(200).build()) + .build(); + mockAsyncHttpClient.stubResponses(response); + + PutObjectRequest putObjectRequest = + putObjectRequestBuilder() + .checksumSHA256("checksumVal") + .build(); + + s3Client.putObject(putObjectRequest, AsyncRequestBody.fromString("hello world")); + assertThat(checksumCapturingInterceptor.checksumHeader).isNull(); + assertThat(checksumCapturingInterceptor.headers.get("x-amz-checksum-sha256")).contains("checksumVal"); + } + + @Test + public void putObject_withCrc32Value_shouldNotAddCrc32TypeHeader() { + HttpExecuteResponse response = HttpExecuteResponse.builder() + .response(SdkHttpResponse.builder().statusCode(200).build()) + .build(); + mockAsyncHttpClient.stubResponses(response); + + PutObjectRequest putObjectRequest = + putObjectRequestBuilder() + .checksumCRC32("checksumVal") + .build(); + + s3Client.putObject(putObjectRequest, AsyncRequestBody.fromString("hello world")); + assertThat(checksumCapturingInterceptor.checksumHeader).isNull(); + assertThat(checksumCapturingInterceptor.headers.get("x-amz-checksum-crc32")).contains("checksumVal"); + } + + private PutObjectRequest.Builder putObjectRequestBuilder() { + return PutObjectRequest.builder().bucket("bucket").key("key"); + } + + private static final class ChecksumCapturingInterceptor implements ExecutionInterceptor { + String checksumHeader; + Map> headers; + + @Override + public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { + SdkHttpRequest sdkHttpRequest = context.httpRequest(); + headers = sdkHttpRequest.headers(); + String checksumHeaderName = "x-amz-sdk-checksum-algorithm"; + if (headers.containsKey(checksumHeaderName)) { + List checksumHeaderVals = headers.get(checksumHeaderName); + assertThat(checksumHeaderVals).hasSize(1); + checksumHeader = checksumHeaderVals.get(0); + } + } + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtilsTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtilsTest.java index 4d5a333a51dd..0c7e79f2d2c5 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtilsTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtilsTest.java @@ -37,6 +37,7 @@ import software.amazon.awssdk.http.SdkHttpFullResponse; import software.amazon.awssdk.services.s3.internal.multipart.SdkPojoConversionUtils; import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.CompletedPart; import software.amazon.awssdk.services.s3.model.CopyObjectRequest; @@ -182,7 +183,6 @@ void toCreateMultipartUploadRequest_putObjectRequest_shouldCopyProperties() { PutObjectRequest randomObject = randomPutObjectRequest(); CreateMultipartUploadRequest convertedObject = SdkPojoConversionUtils.toCreateMultipartUploadRequest(randomObject); Set fieldsToIgnore = new HashSet<>(); - System.out.println(convertedObject); verifyFieldsAreCopied(randomObject, convertedObject, fieldsToIgnore, PutObjectRequest.builder().sdkFields(), CreateMultipartUploadRequest.builder().sdkFields()); @@ -201,6 +201,23 @@ void toCompletedPart_putObject_shouldCopyProperties() { assertThat(convertedCompletedPart.partNumber()).isEqualTo(1); } + @Test + void toCompleteMultipartUploadRequest_putObject_shouldCopyProperties() { + PutObjectRequest randomObject = randomPutObjectRequest(); + CompletedPart parts[] = new CompletedPart[1]; + CompletedPart completedPart = CompletedPart.builder().partNumber(1).build(); + parts[0] = completedPart; + CompleteMultipartUploadRequest convertedObject = + SdkPojoConversionUtils.toCompleteMultipartUploadRequest(randomObject, "uploadId", parts); + + Set fieldsToIgnore = new HashSet<>(); + verifyFieldsAreCopied(randomObject, convertedObject, fieldsToIgnore, + PutObjectRequest.builder().sdkFields(), + CompleteMultipartUploadRequest.builder().sdkFields()); + assertThat(convertedObject.uploadId()).isEqualTo("uploadId"); + assertThat(convertedObject.multipartUpload().parts()).contains(completedPart); + } + private static void verifyFieldsAreCopied(SdkPojo requestConvertedFrom, SdkPojo requestConvertedTo, Set fieldsToIgnore,