Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resume partial download from S3 on connection drop #46589

Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import com.amazonaws.AmazonClientException;
import com.amazonaws.services.s3.model.AbortMultipartUploadRequest;
import com.amazonaws.services.s3.model.AmazonS3Exception;
import com.amazonaws.services.s3.model.CompleteMultipartUploadRequest;
import com.amazonaws.services.s3.model.DeleteObjectsRequest;
import com.amazonaws.services.s3.model.InitiateMultipartUploadRequest;
Expand All @@ -31,7 +30,6 @@
import com.amazonaws.services.s3.model.ObjectMetadata;
import com.amazonaws.services.s3.model.PartETag;
import com.amazonaws.services.s3.model.PutObjectRequest;
import com.amazonaws.services.s3.model.S3Object;
import com.amazonaws.services.s3.model.UploadPartRequest;
import com.amazonaws.services.s3.model.UploadPartResult;
import org.apache.lucene.util.SetOnce;
Expand All @@ -48,7 +46,6 @@

import java.io.IOException;
import java.io.InputStream;
import java.nio.file.NoSuchFileException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -81,18 +78,7 @@ class S3BlobContainer extends AbstractBlobContainer {

/**
 * Opens the named blob for reading.
 *
 * The returned stream is an {@link S3RetryingInputStream} over the blob's full key; opening the underlying S3 object, translating a
 * 404 response into a {@link java.nio.file.NoSuchFileException} and re-opening the object after a failed read are all handled by
 * that class rather than here.
 *
 * @param blobName name of the blob, relative to this container
 * @return a stream over the blob's content
 * @throws IOException if the blob cannot be opened
 */
@Override
public InputStream readBlob(String blobName) throws IOException {
    return new S3RetryingInputStream(blobStore, buildKey(blobName));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ public AmazonS3Reference clientReference() {
return service.client(repositoryMetaData);
}

/**
 * Returns the {@code maxRetries} value of the client settings that the service resolves for this store's repository metadata.
 */
int getMaxRetries() {
    return this.service.settings(this.repositoryMetaData).maxRetries;
}

/**
 * Returns the bucket held by this blob store.
 */
public String bucket() {
    return this.bucket;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.repositories.s3;

import com.amazonaws.AmazonClientException;
import com.amazonaws.services.s3.model.AmazonS3Exception;
import com.amazonaws.services.s3.model.GetObjectRequest;
import com.amazonaws.services.s3.model.S3Object;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.core.internal.io.IOUtils;

import java.io.IOException;
import java.io.InputStream;
import java.nio.file.NoSuchFileException;

class S3RetryingInputStream extends InputStream {

private static final Logger logger = LogManager.getLogger(S3RetryingInputStream.class);

private final S3BlobStore blobStore;
private final String blobKey;
private final int maxAttempts;

private InputStream currentStream;
private long currentOffset;

S3RetryingInputStream(S3BlobStore blobStore, String blobKey) throws IOException {
this.blobStore = blobStore;
this.blobKey = blobKey;
this.maxAttempts = blobStore.getMaxRetries() + 1;
currentStream = openStream();
}

private InputStream openStream() throws IOException {
try (AmazonS3Reference clientReference = blobStore.clientReference()) {
final GetObjectRequest getObjectRequest = new GetObjectRequest(blobStore.bucket(), blobKey);
if (currentOffset > 0) {
getObjectRequest.setRange(currentOffset);
}
final S3Object s3Object = SocketAccess.doPrivileged(() -> clientReference.client().getObject(getObjectRequest));
return s3Object.getObjectContent();
} catch (final AmazonClientException e) {
if (e instanceof AmazonS3Exception) {
if (404 == ((AmazonS3Exception) e).getStatusCode()) {
throw new NoSuchFileException("Blob object [" + blobKey + "] not found: " + e.getMessage());
}
}
throw e;
}
}

@Override
public int read() throws IOException {
int attempt = 0;
while (true) {
attempt += 1;
try {
final int result = currentStream.read();
currentOffset += 1;
return result;
} catch (IOException e) {
reopenStreamOrFail(attempt, e);
}
}
}

@Override
public int read(byte[] b, int off, int len) throws IOException {
int attempt = 0;
while (true) {
attempt += 1;
try {
final int bytesRead = currentStream.read(b, off, len);
if (bytesRead == -1) {
return -1;
}
currentOffset += bytesRead;
return bytesRead;
} catch (IOException e) {
reopenStreamOrFail(attempt, e);
}
}
}

private void reopenStreamOrFail(int attempt, IOException e) throws IOException {
if (attempt >= maxAttempts) {
throw e;
}
logger.debug(new ParameterizedMessage("failed reading [{}/{}] at offset [{}], attempt [{}] of [{}], retrying",
blobStore.bucket(), blobKey, currentOffset, attempt, maxAttempts), e);
IOUtils.closeWhileHandlingException(currentStream);
currentStream = openStream();
}

@Override
public void close() throws IOException {
currentStream.close();
DaveCTurner marked this conversation as resolved.
Show resolved Hide resolved
}

@Override
public long skip(long n) {
throw new UnsupportedOperationException("S3RetryingInputStream does not support seeking");
}

@Override
public synchronized void reset() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to make this synchronized

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, not sure where that came from. Fixed in 3f8c20e.

throw new UnsupportedOperationException("S3RetryingInputStream does not support seeking");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public AmazonS3Reference client(RepositoryMetaData repositoryMetaData) {
* @param repositoryMetaData Repository Metadata
* @return S3ClientSettings
*/
private S3ClientSettings settings(RepositoryMetaData repositoryMetaData) {
S3ClientSettings settings(RepositoryMetaData repositoryMetaData) {
final String clientName = S3Repository.CLIENT_NAME.get(repositoryMetaData.settings());
final S3ClientSettings staticSettings = staticClientSettings.get(clientName);
if (staticSettings != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@
import com.amazonaws.SdkClientException;
import com.amazonaws.services.s3.internal.MD5DigestCalculatingInputStream;
import com.amazonaws.util.Base16;
import com.sun.net.httpserver.HttpExchange;
import com.sun.net.httpserver.HttpServer;
import org.apache.http.ConnectionClosedException;
import org.apache.http.HttpStatus;
import org.apache.http.NoHttpResponseException;
import org.elasticsearch.cluster.metadata.RepositoryMetaData;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.SuppressForbidden;
Expand Down Expand Up @@ -51,12 +54,15 @@
import java.net.InetSocketAddress;
import java.net.SocketTimeoutException;
import java.nio.charset.StandardCharsets;
import java.nio.file.NoSuchFileException;
import java.util.Arrays;
import java.util.Locale;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static org.elasticsearch.repositories.s3.S3ClientSettings.DISABLE_CHUNKED_ENCODING;
import static org.elasticsearch.repositories.s3.S3ClientSettings.ENDPOINT_SETTING;
Expand All @@ -67,6 +73,7 @@
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;

/**
* This class tests how a {@link S3BlobContainer} and its underlying AWS S3 client are retrying requests when reading or writing blobs.
Expand Down Expand Up @@ -130,35 +137,51 @@ private BlobContainer createBlobContainer(final @Nullable Integer maxRetries,
repositoryMetaData));
}

public void testReadNonexistentBlobThrowsNoSuchFileException() {
final BlobContainer blobContainer = createBlobContainer(between(1, 5), null, null, null);
final Exception exception = expectThrows(NoSuchFileException.class, () -> blobContainer.readBlob("read_nonexistent_blob"));
assertThat(exception.getMessage().toLowerCase(Locale.ROOT),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can fit on the same line

containsString("blob object [read_nonexistent_blob] not found"));
}

/**
 * Reads a blob from a server that misbehaves on every request except the last permitted one, and checks that the bytes finally
 * returned to the caller are exactly the blob's content.
 */
public void testReadBlobWithRetries() throws Exception {
    final int maxRetries = randomInt(5);
    final CountDown countDown = new CountDown(maxRetries + 1);

    final byte[] bytes = randomBlobContent();
    httpServer.createContext("/bucket/read_blob_max_retries", exchange -> {
        Streams.readFully(exchange.getRequestBody());
        if (countDown.countDown()) {
            // last permitted request: serve everything from the requested range start onwards
            final int rangeStart = getRangeStart(exchange);
            assertThat(rangeStart, lessThan(bytes.length));
            exchange.getResponseHeaders().add("Content-Type", "text/plain; charset=utf-8");
            exchange.sendResponseHeaders(HttpStatus.SC_OK, bytes.length - rangeStart);
            exchange.getResponseBody().write(bytes, rangeStart, bytes.length - rangeStart);
            exchange.close();
            return;
        }
        // otherwise fail in a random way: an error status, a truncated body, or no response at all
        if (randomBoolean()) {
            exchange.sendResponseHeaders(randomFrom(HttpStatus.SC_INTERNAL_SERVER_ERROR, HttpStatus.SC_BAD_GATEWAY,
                HttpStatus.SC_SERVICE_UNAVAILABLE, HttpStatus.SC_GATEWAY_TIMEOUT), -1);
        } else if (randomBoolean()) {
            sendIncompleteContent(exchange, bytes);
        }
        if (randomBoolean()) {
            exchange.close();
        }
    });

    // the exchange is not always closed above, so the client needs a read timeout to notice a stalled response
    final TimeValue readTimeout = TimeValue.timeValueMillis(between(100, 500));
    final BlobContainer blobContainer = createBlobContainer(maxRetries, readTimeout, null, null);
    try (InputStream inputStream = blobContainer.readBlob("read_blob_max_retries")) {
        assertArrayEquals(bytes, BytesReference.toBytes(Streams.readFully(inputStream)));
        assertThat(countDown.isCountedDown(), is(true));
    }
}

public void testReadBlobWithReadTimeouts() {
final TimeValue readTimeout = TimeValue.timeValueMillis(randomIntBetween(100, 500));
final BlobContainer blobContainer = createBlobContainer(1, readTimeout, null, null);
final TimeValue readTimeout = TimeValue.timeValueMillis(between(100, 200));
final BlobContainer blobContainer = createBlobContainer(between(1, 5), readTimeout, null, null);

// HTTP server does not send a response
httpServer.createContext("/bucket/read_blob_unresponsive", exchange -> {});
Expand All @@ -168,15 +191,8 @@ public void testReadBlobWithReadTimeouts() {
assertThat(exception.getCause(), instanceOf(SocketTimeoutException.class));

// HTTP server sends a partial response
final byte[] bytes = randomByteArrayOfLength(randomIntBetween(10, 128));
httpServer.createContext("/bucket/read_blob_incomplete", exchange -> {
exchange.getResponseHeaders().add("Content-Type", "text/plain; charset=utf-8");
exchange.sendResponseHeaders(HttpStatus.SC_OK, bytes.length);
exchange.getResponseBody().write(bytes, 0, randomIntBetween(1, bytes.length - 1));
if (randomBoolean()) {
exchange.getResponseBody().flush();
}
});
final byte[] bytes = randomBlobContent();
httpServer.createContext("/bucket/read_blob_incomplete", exchange -> sendIncompleteContent(exchange, bytes));

exception = expectThrows(SocketTimeoutException.class, () -> {
try (InputStream stream = blobContainer.readBlob("read_blob_incomplete")) {
Expand All @@ -186,11 +202,62 @@ public void testReadBlobWithReadTimeouts() {
assertThat(exception.getMessage().toLowerCase(Locale.ROOT), containsString("read timed out"));
}

/**
 * Covers two ways the server can drop the connection early: before any response is sent, and part-way through the body. In both
 * cases the failure must reach the caller with the expected exception type and message.
 */
public void testReadBlobWithPrematureConnectionClose() {
    final BlobContainer container = createBlobContainer(between(1, 5), null, null, null);

    // case 1: the server closes the exchange without answering
    httpServer.createContext("/bucket/read_blob_no_response", exchange -> exchange.close());

    final SdkClientException noResponse = expectThrows(SdkClientException.class, () -> container.readBlob("read_blob_no_response"));
    final String noResponseMessage = noResponse.getMessage().toLowerCase(Locale.ROOT);
    assertThat(noResponseMessage, containsString("the target server failed to respond"));
    assertThat(noResponse.getCause(), instanceOf(NoHttpResponseException.class));

    // case 2: the server sends only part of the body and then closes the exchange
    final byte[] content = randomBlobContent();
    httpServer.createContext("/bucket/read_blob_incomplete", exchange -> {
        sendIncompleteContent(exchange, content);
        exchange.close();
    });

    final ConnectionClosedException truncated = expectThrows(ConnectionClosedException.class, () -> {
        try (InputStream in = container.readBlob("read_blob_incomplete")) {
            Streams.readFully(in);
        }
    });
    final String truncatedMessage = truncated.getMessage().toLowerCase(Locale.ROOT);
    assertThat(truncatedMessage, containsString("premature end of content-length delimited message body"));
}

private static int getRangeStart(HttpExchange exchange) {
DaveCTurner marked this conversation as resolved.
Show resolved Hide resolved
final String rangeHeader = exchange.getRequestHeaders().getFirst("Range");
if (rangeHeader == null) {
return 0;
}

final Matcher matcher = Pattern.compile("^bytes=([0-9]+)-9223372036854775806$").matcher(rangeHeader);
assertTrue(rangeHeader + " matches expected pattern", matcher.matches());
return Math.toIntExact(Long.parseLong(matcher.group(1)));
}

private void sendIncompleteContent(HttpExchange exchange, byte[] bytes) throws IOException {
DaveCTurner marked this conversation as resolved.
Show resolved Hide resolved
final int rangeStart = getRangeStart(exchange);
assertThat(rangeStart, lessThan(bytes.length));
exchange.getResponseHeaders().add("Content-Type", "text/plain; charset=utf-8");
exchange.sendResponseHeaders(HttpStatus.SC_OK, bytes.length - rangeStart);
final int bytesToSend = randomIntBetween(0, bytes.length - rangeStart - 1);
if (bytesToSend > 0) {
exchange.getResponseBody().write(bytes, rangeStart, bytesToSend);
}
if (randomBoolean()) {
exchange.getResponseBody().flush();
}
}

public void testWriteBlobWithRetries() throws Exception {
final int maxRetries = randomInt(5);
final CountDown countDown = new CountDown(maxRetries + 1);

final byte[] bytes = randomByteArrayOfLength(randomIntBetween(1, frequently() ? 512 : 1 << 20)); // rarely up to 1mb
final byte[] bytes = randomBlobContent();
httpServer.createContext("/bucket/write_blob_max_retries", exchange -> {
if ("PUT".equals(exchange.getRequestMethod()) && exchange.getRequestURI().getQuery() == null) {
if (countDown.countDown()) {
Expand Down Expand Up @@ -224,6 +291,10 @@ public void testWriteBlobWithRetries() throws Exception {
assertThat(countDown.isCountedDown(), is(true));
}

private byte[] randomBlobContent() {
DaveCTurner marked this conversation as resolved.
Show resolved Hide resolved
return randomByteArrayOfLength(randomIntBetween(1, frequently() ? 512 : 1 << 20)); // rarely up to 1mb
}

public void testWriteBlobWithReadTimeouts() {
final byte[] bytes = randomByteArrayOfLength(randomIntBetween(10, 128));
final TimeValue readTimeout = TimeValue.timeValueMillis(randomIntBetween(100, 500));
Expand Down