Skip to content

Commit

Permalink
Reduce memory usage for chunk-encoded streaming uploads, like those u…
Browse files Browse the repository at this point in the history
…sed by flexible checksums in S3. (#4858)

Before this change, our chunk encoding logic would copy customer data five times:
1. [From the customer's stream into a byte array.](https://github.com/aws/aws-sdk-java-v2/blob/6040b2be6731e4b5ef64e775a2cfffb07d76766c/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedInputStream.java#L106-L107)
2. [From the byte array into a slightly smaller byte array.](https://github.com/aws/aws-sdk-java-v2/blob/6040b2be6731e4b5ef64e775a2cfffb07d76766c/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedInputStream.java#L111)
3. [From the smaller byte array into a byte array output stream.](https://github.com/aws/aws-sdk-java-v2/blob/6040b2be6731e4b5ef64e775a2cfffb07d76766c/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedInputStream.java#L171)
4. [From the byte array output stream into an array.](https://github.com/aws/aws-sdk-java-v2/blob/6040b2be6731e4b5ef64e775a2cfffb07d76766c/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedInputStream.java#L149)
5. [From the array into the output array.](https://github.com/aws/aws-sdk-java-v2/blob/6040b2be6731e4b5ef64e775a2cfffb07d76766c/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedInputStream.java#L85)

After this change, the logic will copy the data twice:
1. From the customer's stream into a byte array.
2. From the byte array into the output array.

There's a path to make it only one copy, but it requires the chunk encoded input stream to know the length of the underlying stream so that it can detect when the last chunk will be encountered. This will require additional piping, so we can do it in a follow-up PR.
  • Loading branch information
millems authored Jan 30, 2024
1 parent 3b86c3b commit f697b8c
Show file tree
Hide file tree
Showing 11 changed files with 173 additions and 67 deletions.
6 changes: 6 additions & 0 deletions .changes/next-release/feature-AWSSDKforJavav2-a5aec87.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "feature",
"category": "AWS SDK for Java v2",
"contributor": "",
"description": "Reduce how many times input data is copied when writing to chunked encoded operations, like S3's PutObject."
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public ContentStreamProvider sign(ContentStreamProvider payload, V4aRequestSigni
.builder()
.inputStream(inputStream)
.chunkSize(chunkSize)
.header(chunk -> Integer.toHexString(chunk.length).getBytes(StandardCharsets.UTF_8));
.header(chunk -> Integer.toHexString(chunk.remaining()).getBytes(StandardCharsets.UTF_8));

preExistingTrailers.forEach(trailer -> chunkedEncodedInputStreamBuilder.addTrailer(() -> trailer));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

package software.amazon.awssdk.http.auth.aws.crt.internal.signer;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -45,14 +46,14 @@ public RollingSigner(byte[] seedSignature, AwsSigningConfig signingConfig) {
this.signingConfig = signingConfig;
}

private static byte[] signChunk(byte[] chunkBody, byte[] previousSignature, AwsSigningConfig signingConfig) {
private static byte[] signChunk(ByteBuffer chunkBody, byte[] previousSignature, AwsSigningConfig signingConfig) {
// All the config remains the same as signing config except the Signature Type.
AwsSigningConfig configCopy = signingConfig.clone();
configCopy.setSignatureType(AwsSigningConfig.AwsSignatureType.HTTP_REQUEST_CHUNK);
configCopy.setSignedBodyHeader(AwsSigningConfig.AwsSignedBodyHeaderType.NONE);
configCopy.setSignedBodyValue(null);

HttpRequestBodyStream crtBody = new CrtInputStream(() -> new ByteArrayInputStream(chunkBody));
HttpRequestBodyStream crtBody = new CrtInputStream(() -> new ByteBufferBackedInputStream(chunkBody));
return CompletableFutureUtils.joinLikeSync(AwsSigner.signChunk(crtBody, previousSignature, configCopy));
}

Expand All @@ -75,7 +76,7 @@ private static AwsSigningResult signTrailerHeaders(Map<String, List<String>> hea
/**
* Using a template that incorporates the previous calculated signature, sign the string and return it.
*/
public byte[] sign(byte[] chunkBody) {
public byte[] sign(ByteBuffer chunkBody) {
previousSignature = signChunk(chunkBody, previousSignature, signingConfig);
return previousSignature;
}
Expand All @@ -89,4 +90,29 @@ public byte[] sign(Map<String, List<String>> headerMap) {
public void reset() {
previousSignature = seedSignature;
}

/**
 * Minimal {@link InputStream} view over a {@link ByteBuffer}.
 *
 * <p>Reads consume the wrapped buffer directly: every byte returned advances the buffer's position, and the stream reports
 * end-of-stream once the buffer has no bytes remaining. The buffer is not copied.
 */
private static class ByteBufferBackedInputStream extends InputStream {
    private final ByteBuffer buf;

    private ByteBufferBackedInputStream(ByteBuffer buf) {
        this.buf = buf;
    }

    /**
     * Read the next byte from the buffer.
     *
     * @return the byte as an unsigned value (0-255), or -1 if the buffer has no bytes remaining.
     */
    @Override
    public int read() {
        if (!buf.hasRemaining()) {
            return -1;
        }
        // Mask so that bytes >= 0x80 are returned as 128-255 rather than as negative ints (which would look like EOF).
        return buf.get() & 0xFF;
    }

    /**
     * Read up to {@code len} bytes from the buffer into {@code bytes}, starting at index {@code off}.
     *
     * @return the number of bytes copied, 0 if {@code len} is zero, or -1 if the buffer has no bytes remaining.
     * @throws NullPointerException if {@code bytes} is null.
     * @throws IndexOutOfBoundsException if {@code off} or {@code len} is negative, or {@code len} exceeds the space
     * available in {@code bytes} after {@code off}.
     */
    @Override
    public int read(byte[] bytes, int off, int len) {
        // Validate up front so that invalid arguments fail the same way regardless of how much data is left.
        if (off < 0 || len < 0 || len > bytes.length - off) {
            throw new IndexOutOfBoundsException();
        }

        // InputStream contract: a zero-length read returns 0, even at end of stream.
        if (len == 0) {
            return 0;
        }

        if (!buf.hasRemaining()) {
            return -1;
        }

        int bytesToRead = Math.min(len, buf.remaining());
        buf.get(bytes, off, bytesToRead);
        return bytesToRead;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package software.amazon.awssdk.http.auth.aws.crt.internal.signer;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.http.auth.aws.internal.signer.CredentialScope;
Expand All @@ -38,11 +39,8 @@ public void reset() {
}

@Override
public Pair<byte[], byte[]> get(byte[] chunk) {
public Pair<byte[], byte[]> get(ByteBuffer chunk) {
byte[] chunkSig = signer.sign(chunk);
return Pair.of(
"chunk-signature".getBytes(StandardCharsets.UTF_8),
chunkSig
);
return Pair.of("chunk-signature".getBytes(StandardCharsets.UTF_8), chunkSig);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public ContentStreamProvider sign(ContentStreamProvider payload, V4RequestSignin
.builder()
.inputStream(payload.newStream())
.chunkSize(chunkSize)
.header(chunk -> Integer.toHexString(chunk.length).getBytes(StandardCharsets.UTF_8));
.header(chunk -> Integer.toHexString(chunk.remaining()).getBytes(StandardCharsets.UTF_8));

preExistingTrailers.forEach(trailer -> chunkedEncodedInputStreamBuilder.addTrailer(() -> trailer));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding;

import java.nio.ByteBuffer;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.utils.Pair;

Expand All @@ -32,5 +33,5 @@
@FunctionalInterface
@SdkInternalApi
public interface ChunkExtensionProvider extends Resettable {
Pair<byte[], byte[]> get(byte[] chunk);
Pair<byte[], byte[]> get(ByteBuffer chunk);
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding;

import java.nio.ByteBuffer;
import software.amazon.awssdk.annotations.SdkInternalApi;

/**
Expand All @@ -27,5 +28,5 @@
@FunctionalInterface
@SdkInternalApi
public interface ChunkHeaderProvider extends Resettable {
byte[] get(byte[] chunk);
byte[] get(ByteBuffer chunk);
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
package software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.SequenceInputStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.utils.Logger;
Expand Down Expand Up @@ -52,6 +53,10 @@ public final class ChunkedEncodedInputStream extends InputStream {
private static final Logger LOG = Logger.loggerFor(ChunkedEncodedInputStream.class);
private static final byte[] CRLF = {'\r', '\n'};
private static final byte[] END = {};
private static final byte[] SEMICOLON = {';'};
private static final byte[] EQUALS = {'='};
private static final byte[] COLON = {':'};
private static final byte[] COMMA = {','};

private final InputStream inputStream;
private final int chunkSize;
Expand Down Expand Up @@ -101,14 +106,14 @@ private Chunk getChunk(InputStream stream) throws IOException {
if (currentChunk != null) {
currentChunk.close();
}
// we *have* to read from the backing stream in order to figure out if it's the end or not
// TODO(sra-identity-and-auth): We can likely optimize this by not copying the entire chunk of data into memory

// We have to read from the input stream into a format that can be used for signing and headers.
byte[] chunkData = new byte[chunkSize];
int read = read(stream, chunkData, chunkSize);

if (read > 0) {
// set the current chunk to the newly written chunk
return getNextChunk(Arrays.copyOf(chunkData, read));
return getNextChunk(ByteBuffer.wrap(chunkData, 0, read));
}

LOG.debug(() -> "End of backing stream reached. Reading final chunk.");
Expand Down Expand Up @@ -142,58 +147,71 @@ private int read(InputStream inputStream, byte[] buf, int maxBytesToRead) throws
* Create a chunk from a byte buffer, which includes the header, the extensions, and the chunk data. The buffer's remaining
* bytes are treated as the chunk data, so the buffer should be limited to exactly the bytes that belong to this chunk.
*/
private Chunk getNextChunk(byte[] data) throws IOException {
ByteArrayOutputStream chunkStream = new ByteArrayOutputStream();
writeChunk(data, chunkStream);
chunkStream.write(CRLF);
byte[] newChunkData = chunkStream.toByteArray();

return Chunk.create(new ByteArrayInputStream(newChunkData), newChunkData.length);
private Chunk getNextChunk(ByteBuffer data) {
    // An encoded (non-final) chunk is the chunk stream (header, extensions, CRLF, data) followed by a closing CRLF.
    LengthAwareSequenceInputStream.Builder chunkBuilder = LengthAwareSequenceInputStream.builder();
    chunkBuilder.add(createChunkStream(data));
    chunkBuilder.add(CRLF);

    LengthAwareSequenceInputStream encodedChunk = chunkBuilder.build();

    // The sequence stream tracks its own total length, so no intermediate byte array is needed to size the chunk.
    return Chunk.create(encodedChunk, encodedChunk.size);
}

/**
* Create the final chunk, which includes the header, the extensions, the chunk (if applicable), and the trailer
*/
private Chunk getFinalChunk() throws IOException {
ByteArrayOutputStream chunkStream = new ByteArrayOutputStream();
writeChunk(END, chunkStream);
writeTrailers(chunkStream);
chunkStream.write(CRLF);
byte[] newChunkData = chunkStream.toByteArray();

return Chunk.create(new ByteArrayInputStream(newChunkData), newChunkData.length);
LengthAwareSequenceInputStream chunkData =
LengthAwareSequenceInputStream.builder()
.add(createChunkStream(ByteBuffer.wrap(END)))
.add(createTrailerStream())
.add(CRLF)
.build();

return Chunk.create(chunkData, chunkData.size);
}

private void writeChunk(byte[] chunk, ByteArrayOutputStream outputStream) throws IOException {
writeHeader(chunk, outputStream);
writeExtensions(chunk, outputStream);
outputStream.write(CRLF);
outputStream.write(chunk);
/**
 * Build the stream for a single chunk: the header, the extensions, a CRLF, and then the chunk data itself.
 *
 * <p>The chunk data is streamed straight out of the buffer's backing array rather than being copied, so the buffer must be
 * array-backed (for example, created with {@link ByteBuffer#wrap}).
 *
 * @param chunkData the chunk's bytes; the bytes between the buffer's position and limit are emitted as the chunk body.
 */
private LengthAwareSequenceInputStream createChunkStream(ByteBuffer chunkData) {
    // The header and extension providers each get their own read-only view, so reading from the view cannot move the
    // position of the buffer that the chunk body is taken from.
    ByteArrayInputStream headerStream = createHeaderStream(chunkData.asReadOnlyBuffer());
    LengthAwareSequenceInputStream extensionsStream = createExtensionsStream(chunkData.asReadOnlyBuffer());

    // Buffer position p maps to backing-array index p + arrayOffset(). Starting at arrayOffset() alone would emit the
    // wrong bytes for any buffer whose position is not zero, while the header and extensions above are computed over
    // the remaining bytes.
    ByteArrayInputStream bodyStream = new ByteArrayInputStream(chunkData.array(),
                                                               chunkData.arrayOffset() + chunkData.position(),
                                                               chunkData.remaining());

    return LengthAwareSequenceInputStream.builder()
                                         .add(headerStream)
                                         .add(extensionsStream)
                                         .add(CRLF)
                                         .add(bodyStream)
                                         .build();
}

private void writeHeader(byte[] chunk, ByteArrayOutputStream outputStream) throws IOException {
byte[] hdr = header.get(chunk);
outputStream.write(hdr);
/**
 * Ask the configured header provider for this chunk's header bytes and expose them as a stream.
 */
private ByteArrayInputStream createHeaderStream(ByteBuffer chunkData) {
    byte[] headerBytes = header.get(chunkData);
    return new ByteArrayInputStream(headerBytes);
}

private void writeExtensions(byte[] chunk, ByteArrayOutputStream outputStream) throws IOException {
private LengthAwareSequenceInputStream createExtensionsStream(ByteBuffer chunkData) {
LengthAwareSequenceInputStream.Builder result = LengthAwareSequenceInputStream.builder();
for (ChunkExtensionProvider chunkExtensionProvider : extensions) {
Pair<byte[], byte[]> ext = chunkExtensionProvider.get(chunk);
outputStream.write((byte) ';');
outputStream.write(ext.left());
outputStream.write((byte) '=');
outputStream.write(ext.right());
Pair<byte[], byte[]> ext = chunkExtensionProvider.get(chunkData);
result.add(SEMICOLON);
result.add(ext.left());
result.add(EQUALS);
result.add(ext.right());
}
return result.build();
}

private void writeTrailers(ByteArrayOutputStream outputStream) throws IOException {
private LengthAwareSequenceInputStream createTrailerStream() throws IOException {
LengthAwareSequenceInputStream.Builder result = LengthAwareSequenceInputStream.builder();
for (TrailerProvider trailer : trailers) {
Pair<String, List<String>> tlr = trailer.get();
outputStream.write(tlr.left().getBytes(StandardCharsets.UTF_8));
outputStream.write((byte) ':');
outputStream.write(String.join(",", tlr.right()).getBytes(StandardCharsets.UTF_8));
outputStream.write(CRLF);
result.add(tlr.left().getBytes(StandardCharsets.UTF_8));
result.add(COLON);
for (String trailerValue : tlr.right()) {
result.add(trailerValue.getBytes(StandardCharsets.UTF_8));
result.add(COMMA);
}

// Replace trailing comma with CRLF
result.replaceLast(new ByteArrayInputStream(CRLF), COMMA.length);
}
return result.build();
}

@Override
Expand All @@ -216,7 +234,8 @@ public static class Builder {
private final List<TrailerProvider> trailers = new ArrayList<>();
private InputStream inputStream;
private int chunkSize;
private ChunkHeaderProvider header = chunk -> Integer.toHexString(chunk.length).getBytes(StandardCharsets.UTF_8);
private ChunkHeaderProvider header =
chunk -> Integer.toHexString(chunk.remaining()).getBytes(StandardCharsets.UTF_8);

public InputStream inputStream() {
return this.inputStream;
Expand Down Expand Up @@ -267,5 +286,51 @@ public ChunkedEncodedInputStream build() {
return new ChunkedEncodedInputStream(this);
}
}


/**
 * A {@link SequenceInputStream} that also remembers the combined length of the streams it was built from.
 *
 * <p>The length is accumulated by the {@link Builder} as parts are added: {@code available()} for each
 * {@link ByteArrayInputStream} at the time it is added, and the recorded {@code size} for each nested
 * {@code LengthAwareSequenceInputStream}.
 */
private static class LengthAwareSequenceInputStream extends SequenceInputStream {
    private final int size;

    private LengthAwareSequenceInputStream(Builder builder) {
        super(Collections.enumeration(builder.parts));
        this.size = builder.totalLength;
    }

    private static Builder builder() {
        return new Builder();
    }

    /**
     * Collects the parts of the sequence, in order, while keeping a running total of their lengths.
     */
    private static class Builder {
        private final List<InputStream> parts = new ArrayList<>();
        private int totalLength = 0;

        /**
         * Append an in-memory stream; its currently available byte count is added to the total.
         */
        public Builder add(ByteArrayInputStream stream) {
            parts.add(stream);
            totalLength = totalLength + stream.available();
            return this;
        }

        /**
         * Append a byte array by wrapping it (without copying) in a {@link ByteArrayInputStream}.
         */
        public Builder add(byte[] stream) {
            ByteArrayInputStream wrapped = new ByteArrayInputStream(stream);
            return add(wrapped);
        }

        /**
         * Append an already-built sequence; its recorded size is added to the total.
         */
        public Builder add(LengthAwareSequenceInputStream stream) {
            parts.add(stream);
            totalLength = totalLength + stream.size;
            return this;
        }

        /**
         * Swap the most recently added part for {@code stream}.
         *
         * @param lastLength the length that was counted for the part being replaced; it is subtracted from the total
         * before the replacement's available byte count is added.
         */
        public Builder replaceLast(ByteArrayInputStream stream, int lastLength) {
            int lastIndex = parts.size() - 1;
            parts.set(lastIndex, stream);
            totalLength = totalLength - lastLength + stream.available();
            return this;
        }

        public LengthAwareSequenceInputStream build() {
            return new LengthAwareSequenceInputStream(this);
        }
    }
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerUtils.hash;
import static software.amazon.awssdk.utils.BinaryUtils.toHex;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.http.auth.aws.internal.signer.CredentialScope;
Expand All @@ -42,7 +43,7 @@ public void reset() {
signer.reset();
}

private String getStringToSign(String previousSignature, byte[] chunk) {
private String getStringToSign(String previousSignature, ByteBuffer chunk) {
// build the string-to-sign template for the rolling-signer to sign
return String.join("\n",
"AWS4-HMAC-SHA256-PAYLOAD",
Expand All @@ -55,11 +56,9 @@ private String getStringToSign(String previousSignature, byte[] chunk) {
}

@Override
public Pair<byte[], byte[]> get(byte[] chunk) {
public Pair<byte[], byte[]> get(ByteBuffer chunk) {
String chunkSig = signer.sign(previousSig -> getStringToSign(previousSig, chunk));
return Pair.of(
"chunk-signature".getBytes(StandardCharsets.UTF_8),
chunkSig.getBytes(StandardCharsets.UTF_8)
);
return Pair.of("chunk-signature".getBytes(StandardCharsets.UTF_8),
chunkSig.getBytes(StandardCharsets.UTF_8));
}
}
Loading

0 comments on commit f697b8c

Please sign in to comment.