diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java index 75f5990a1ccf2..fcfccf50ad326 100644 --- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java @@ -241,17 +241,20 @@ public void readBlobAsync(String blobName, ActionListener listener) return; } - final List> blobPartInputStreamFutures = new ArrayList<>(); + final List blobPartInputStreamFutures = new ArrayList<>(); final long blobSize = blobMetadata.objectSize(); final Integer numberOfParts = blobMetadata.objectParts() == null ? null : blobMetadata.objectParts().totalPartsCount(); final String blobChecksum = blobMetadata.checksum().checksumCRC32(); if (numberOfParts == null) { - blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, null)); + blobPartInputStreamFutures.add(() -> getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, null)); } else { // S3 multipart files use 1 to n indexing for (int partNumber = 1; partNumber <= numberOfParts; partNumber++) { - blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, partNumber)); + final int innerPartNumber = partNumber; + blobPartInputStreamFutures.add( + () -> getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, innerPartNumber) + ); } } listener.onResponse(new ReadContext(blobSize, blobPartInputStreamFutures, blobChecksum)); diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java index 055a882885065..2e54705e9cd78 100644 --- a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java +++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java @@ -969,7 +969,7 @@ public void testReadBlobAsyncMultiPart() throws Exception { assertEquals(objectSize, readContext.getBlobSize()); for (int partNumber = 1; partNumber < objectPartCount; partNumber++) { - InputStreamContainer inputStreamContainer = readContext.getPartStreams().get(partNumber).get(); + InputStreamContainer inputStreamContainer = readContext.getPartStreams().get(partNumber).get().join(); final int offset = partNumber * partSize; assertEquals(partSize, inputStreamContainer.getContentLength()); assertEquals(offset, inputStreamContainer.getOffset()); @@ -1024,7 +1024,7 @@ public void testReadBlobAsyncSinglePart() throws Exception { assertEquals(checksum, readContext.getBlobChecksum()); assertEquals(objectSize, readContext.getBlobSize()); - InputStreamContainer inputStreamContainer = readContext.getPartStreams().stream().findFirst().get().get(); + InputStreamContainer inputStreamContainer = readContext.getPartStreams().stream().findFirst().get().get().join(); assertEquals(objectSize, inputStreamContainer.getContentLength()); assertEquals(0, inputStreamContainer.getOffset()); assertEquals(objectSize, inputStreamContainer.getInputStream().readAllBytes().length); diff --git a/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsAsyncBlobContainer.java b/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsAsyncBlobContainer.java index 934656cdf702b..36987ac2d4991 100644 --- a/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsAsyncBlobContainer.java +++ b/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsAsyncBlobContainer.java @@ -125,11 +125,11 @@ public void readBlobAsync(String blobName, ActionListener listener) long contentLength = listBlobs().get(blobName).length(); long partSize = contentLength / 10; int numberOfParts = (int) ((contentLength % partSize) == 0 ? contentLength / partSize : (contentLength / partSize) + 1); - List> blobPartStreams = new ArrayList<>(); + List blobPartStreams = new ArrayList<>(); for (int partNumber = 0; partNumber < numberOfParts; partNumber++) { long offset = partNumber * partSize; InputStreamContainer blobPartStream = new InputStreamContainer(readBlob(blobName, offset, partSize), partSize, offset); - blobPartStreams.add(CompletableFuture.completedFuture(blobPartStream)); + blobPartStreams.add(() -> CompletableFuture.completedFuture(blobPartStream)); } ReadContext blobReadContext = new ReadContext(contentLength, blobPartStreams, null); listener.onResponse(blobReadContext); diff --git a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamBlobContainer.java index 220178d587ac4..97f304d776f5c 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamBlobContainer.java +++ b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamBlobContainer.java @@ -10,12 +10,10 @@ import org.opensearch.common.annotation.ExperimentalApi; import org.opensearch.common.blobstore.stream.read.ReadContext; -import org.opensearch.common.blobstore.stream.read.listener.ReadContextListener; import org.opensearch.common.blobstore.stream.write.WriteContext; import org.opensearch.core.action.ActionListener; import java.io.IOException; -import java.nio.file.Path; /** * An extension of {@link BlobContainer} that adds {@link AsyncMultiStreamBlobContainer#asyncBlobUpload} to allow @@ -44,18 +42,6 @@ public interface AsyncMultiStreamBlobContainer extends BlobContainer { @ExperimentalApi void readBlobAsync(String blobName, ActionListener listener); - /** - * Asynchronously downloads the blob to the specified location using an executor from the thread pool. - * @param blobName The name of the blob for which needs to be downloaded. - * @param fileLocation The path on local disk where the blob needs to be downloaded. - * @param completionListener Listener which will be notified when the download is complete. - */ - @ExperimentalApi - default void asyncBlobDownload(String blobName, Path fileLocation, ActionListener completionListener) { - ReadContextListener readContextListener = new ReadContextListener(blobName, fileLocation, completionListener); - readBlobAsync(blobName, readContextListener); - } - /* * Wether underlying blobContainer can verify integrity of data after transfer. If true and if expected * checksum is provided in WriteContext, then the checksum of transferred data is compared with expected checksum diff --git a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java index 5637326915746..82bc7a0baed50 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java +++ b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java @@ -19,7 +19,6 @@ import java.io.IOException; import java.io.InputStream; import java.util.List; -import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; /** @@ -145,9 +144,9 @@ public long getBlobSize() { } @Override - public List> getPartStreams() { + public List getPartStreams() { return super.getPartStreams().stream() - .map(cf -> cf.thenApply(this::decryptInputStreamContainer)) + .map(supplier -> (StreamPartCreator) () -> supplier.get().thenApply(this::decryptInputStreamContainer)) .collect(Collectors.toUnmodifiableList()); } diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java index dc3e2e931c7d3..4bdce11ff4f9a 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java @@ -13,6 +13,7 @@ import java.util.List; import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; /** * ReadContext is used to encapsulate all data needed by BlobContainer#readBlobAsync @@ -20,10 +21,10 @@ @ExperimentalApi public class ReadContext { private final long blobSize; - private final List> asyncPartStreams; + private final List asyncPartStreams; private final String blobChecksum; - public ReadContext(long blobSize, List> asyncPartStreams, String blobChecksum) { + public ReadContext(long blobSize, List asyncPartStreams, String blobChecksum) { this.blobSize = blobSize; this.asyncPartStreams = asyncPartStreams; this.blobChecksum = blobChecksum; @@ -47,7 +48,23 @@ public long getBlobSize() { return blobSize; } - public List> getPartStreams() { + public List getPartStreams() { return asyncPartStreams; } + + /** + * Functional interface defining an instance that can create an async action + * to create a part of an object represented as an InputStreamContainer. + */ + @FunctionalInterface + public interface StreamPartCreator extends Supplier> { + /** + * Kicks off a async process to start streaming. + * + * @return When the returned future is completed, streaming has + * just begun. Clients must fully consume the resulting stream. + */ + @Override + CompletableFuture get(); + } } diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java index 0eae22220ea82..1bd5ec6985fa2 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java @@ -8,89 +8,37 @@ package org.opensearch.common.blobstore.stream.read.listener; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.common.annotation.InternalApi; import org.opensearch.common.io.Channels; import org.opensearch.common.io.InputStreamContainer; -import org.opensearch.core.action.ActionListener; import java.io.IOException; import java.io.InputStream; import java.nio.channels.FileChannel; -import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardOpenOption; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.BiConsumer; +import java.util.function.UnaryOperator; /** * FilePartWriter transfers the provided stream into the specified file path using a {@link FileChannel} * instance. It performs offset based writes to the file and notifies the {@link FileCompletionListener} on completion. */ @InternalApi -class FilePartWriter implements BiConsumer { - - private final int partNumber; - private final Path fileLocation; - private final AtomicBoolean anyPartStreamFailed; - private final ActionListener fileCompletionListener; - private static final Logger logger = LogManager.getLogger(FilePartWriter.class); - +class FilePartWriter { // 8 MB buffer for transfer private static final int BUFFER_SIZE = 8 * 1024 * 2024; - public FilePartWriter( - int partNumber, - Path fileLocation, - AtomicBoolean anyPartStreamFailed, - ActionListener fileCompletionListener - ) { - this.partNumber = partNumber; - this.fileLocation = fileLocation; - this.anyPartStreamFailed = anyPartStreamFailed; - this.fileCompletionListener = fileCompletionListener; - } - - @Override - public void accept(InputStreamContainer blobPartStreamContainer, Throwable throwable) { - if (throwable != null) { - if (throwable instanceof Exception) { - processFailure((Exception) throwable); - } else { - processFailure(new Exception(throwable)); - } - return; - } - // Ensures no writes to the file if any stream fails. - if (anyPartStreamFailed.get() == false) { - try (FileChannel outputFileChannel = FileChannel.open(fileLocation, StandardOpenOption.WRITE, StandardOpenOption.CREATE)) { - try (InputStream inputStream = blobPartStreamContainer.getInputStream()) { - long streamOffset = blobPartStreamContainer.getOffset(); - final byte[] buffer = new byte[BUFFER_SIZE]; - int bytesRead; - while ((bytesRead = inputStream.read(buffer)) != -1) { - Channels.writeToChannel(buffer, 0, bytesRead, outputFileChannel, streamOffset); - streamOffset += bytesRead; - } + public static void write(Path fileLocation, InputStreamContainer stream, UnaryOperator rateLimiter) throws IOException { + try (FileChannel outputFileChannel = FileChannel.open(fileLocation, StandardOpenOption.WRITE, StandardOpenOption.CREATE)) { + try (InputStream inputStream = rateLimiter.apply(stream.getInputStream())) { + long streamOffset = stream.getOffset(); + final byte[] buffer = new byte[BUFFER_SIZE]; + int bytesRead; + while ((bytesRead = inputStream.read(buffer)) != -1) { + Channels.writeToChannel(buffer, 0, bytesRead, outputFileChannel, streamOffset); + streamOffset += bytesRead; } - } catch (IOException e) { - processFailure(e); - return; } - fileCompletionListener.onResponse(partNumber); - } - } - - void processFailure(Exception e) { - try { - Files.deleteIfExists(fileLocation); - } catch (IOException ex) { - // Die silently - logger.info("Failed to delete file {} on stream failure: {}", fileLocation, ex); - } - if (anyPartStreamFailed.getAndSet(true) == false) { - fileCompletionListener.onFailure(e); } } } diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java index 4aa028fd6e7cc..4946b3b684c23 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java @@ -10,12 +10,21 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.GroupedActionListener; import org.opensearch.common.annotation.InternalApi; import org.opensearch.common.blobstore.stream.read.ReadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.threadpool.ThreadPool; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; import java.nio.file.Path; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.UnaryOperator; /** * ReadContextListener orchestrates the async file fetch from the {@link org.opensearch.common.blobstore.BlobContainer} @@ -23,29 +32,53 @@ */ @InternalApi public class ReadContextListener implements ActionListener { + private static final Logger logger = LogManager.getLogger(ReadContextListener.class); - private final String fileName; + private final String blobName; private final Path fileLocation; private final ActionListener completionListener; - private static final Logger logger = LogManager.getLogger(ReadContextListener.class); + private final ThreadPool threadPool; + private final UnaryOperator rateLimiter; - public ReadContextListener(String fileName, Path fileLocation, ActionListener completionListener) { - this.fileName = fileName; + public ReadContextListener( + String blobName, + Path fileLocation, + ActionListener completionListener, + ThreadPool threadPool, + UnaryOperator rateLimter + ) { + this.blobName = blobName; this.fileLocation = fileLocation; this.completionListener = completionListener; + this.threadPool = threadPool; + this.rateLimiter = rateLimter; } @Override public void onResponse(ReadContext readContext) { - logger.trace("Streams received for blob {}", fileName); + logger.debug("Received {} parts for blob {}", readContext.getNumberOfParts(), blobName); final int numParts = readContext.getNumberOfParts(); - final AtomicBoolean anyPartStreamFailed = new AtomicBoolean(); - FileCompletionListener fileCompletionListener = new FileCompletionListener(numParts, fileName, completionListener); - - for (int partNumber = 0; partNumber < numParts; partNumber++) { - readContext.getPartStreams() - .get(partNumber) - .whenComplete(new FilePartWriter(partNumber, fileLocation, anyPartStreamFailed, fileCompletionListener)); + final AtomicBoolean anyPartStreamFailed = new AtomicBoolean(false); + final GroupedActionListener groupedListener = new GroupedActionListener<>( + ActionListener.wrap(r -> completionListener.onResponse(blobName), completionListener::onFailure), + numParts + ); + final Queue queue = new ConcurrentLinkedQueue<>(readContext.getPartStreams()); + final StreamPartProcessor processor = new StreamPartProcessor( + queue, + anyPartStreamFailed, + fileLocation, + groupedListener, + threadPool.executor(ThreadPool.Names.REMOTE_RECOVERY), + rateLimiter + ); + // We have a dedicated thread pool for draining streams and writing them to a file on the + // local disk. However, we want more concurrent requests to fetch stream parts to be + // happening asynchronously so that there is always a queue of streams ready to be consumed + // by the REMOTE_RECOVERY thread pool. + final int concurrentStreamRequests = 20; + for (int i = 0; i < Math.min(concurrentStreamRequests, queue.size()); i++) { + processor.process(queue.poll()); } } @@ -53,4 +86,79 @@ public void onResponse(ReadContext readContext) { public void onFailure(Exception e) { completionListener.onFailure(e); } + + private static class StreamPartProcessor { + private static final RuntimeException CANCELED_PART_EXCEPTION = new RuntimeException( + "Canceled part download due to previous failure" + ); + private final Queue queue; + private final AtomicBoolean anyPartStreamFailed; + private final Path fileLocation; + private final GroupedActionListener completionListener; + private final Executor executor; + private final UnaryOperator rateLimiter; + + private StreamPartProcessor( + Queue queue, + AtomicBoolean anyPartStreamFailed, + Path fileLocation, + GroupedActionListener completionListener, + Executor executor, + UnaryOperator rateLimiter + ) { + this.queue = queue; + this.anyPartStreamFailed = anyPartStreamFailed; + this.fileLocation = fileLocation; + this.completionListener = completionListener; + this.executor = executor; + this.rateLimiter = rateLimiter; + } + + private void process(ReadContext.StreamPartCreator supplier) { + if (supplier == null) { + return; + } + supplier.get().whenCompleteAsync((blobPartStreamContainer, throwable) -> { + if (throwable != null) { + processFailure(throwable instanceof Exception ? (Exception) throwable : new RuntimeException(throwable)); + } else if (anyPartStreamFailed.get()) { + processFailure(CANCELED_PART_EXCEPTION); + } else { + try { + FilePartWriter.write(fileLocation, blobPartStreamContainer, rateLimiter); + completionListener.onResponse(fileLocation.toString()); + + // Upon successfully completing a file part, pull another + // file part off the queue to trigger asynchronous processing + process(queue.poll()); + } catch (Exception e) { + processFailure(e); + } + } + }, executor); + } + + private void processFailure(Exception e) { + if (anyPartStreamFailed.getAndSet(true) == false) { + completionListener.onFailure(e); + + // Drain the queue of pending part downloads. These can be discarded + // since they haven't started any work yet, but the listener must be + // notified for each part. + Object item = queue.poll(); + while (item != null) { + completionListener.onFailure(CANCELED_PART_EXCEPTION); + item = queue.poll(); + } + } else { + completionListener.onFailure(e); + } + try { + Files.deleteIfExists(fileLocation); + } catch (IOException ex) { + // Die silently + logger.info("Failed to delete file {} on stream failure: {}", fileLocation, ex); + } + } + } } diff --git a/server/src/main/java/org/opensearch/index/store/RemoteDirectory.java b/server/src/main/java/org/opensearch/index/store/RemoteDirectory.java index 594b7f99cd85a..eb75c39532d71 100644 --- a/server/src/main/java/org/opensearch/index/store/RemoteDirectory.java +++ b/server/src/main/java/org/opensearch/index/store/RemoteDirectory.java @@ -62,9 +62,9 @@ public class RemoteDirectory extends Directory { protected final BlobContainer blobContainer; private static final Logger logger = LogManager.getLogger(RemoteDirectory.class); - protected final UnaryOperator uploadRateLimiter; + private final UnaryOperator uploadRateLimiter; - protected final UnaryOperator downloadRateLimiter; + private final UnaryOperator downloadRateLimiter; /** * Number of bytes in the segment file to store checksum @@ -333,6 +333,10 @@ public boolean copyFrom( return false; } + protected UnaryOperator getDownloadRateLimiter() { + return downloadRateLimiter; + } + private void uploadBlob( Directory from, String src, diff --git a/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectory.java b/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectory.java index d347479a3bf25..2b3b495ca9d1b 100644 --- a/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectory.java +++ b/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectory.java @@ -25,6 +25,7 @@ import org.apache.lucene.util.Version; import org.opensearch.common.UUIDs; import org.opensearch.common.blobstore.AsyncMultiStreamBlobContainer; +import org.opensearch.common.blobstore.stream.read.listener.ReadContextListener; import org.opensearch.common.collect.Tuple; import org.opensearch.common.io.VersionedCodecStreamWrapper; import org.opensearch.common.logging.Loggers; @@ -488,7 +489,14 @@ public void copyTo(String source, Directory destinationDirectory, Path destinati if (destinationPath != null && remoteDataDirectory.getBlobContainer() instanceof AsyncMultiStreamBlobContainer) { final AsyncMultiStreamBlobContainer blobContainer = (AsyncMultiStreamBlobContainer) remoteDataDirectory.getBlobContainer(); final Path destinationFilePath = destinationPath.resolve(source); - blobContainer.asyncBlobDownload(blobName, destinationFilePath, fileCompletionListener); + ReadContextListener readContextListener = new ReadContextListener( + blobName, + destinationFilePath, + fileCompletionListener, + threadPool, + remoteDataDirectory.getDownloadRateLimiter() + ); + blobContainer.readBlobAsync(blobName, readContextListener); } else { // Fallback to older mechanism of downloading the file try { diff --git a/server/src/main/java/org/opensearch/threadpool/ThreadPool.java b/server/src/main/java/org/opensearch/threadpool/ThreadPool.java index 8375ac34972af..ecb5b2cef58ac 100644 --- a/server/src/main/java/org/opensearch/threadpool/ThreadPool.java +++ b/server/src/main/java/org/opensearch/threadpool/ThreadPool.java @@ -115,6 +115,7 @@ public static class Names { public static final String TRANSLOG_SYNC = "translog_sync"; public static final String REMOTE_PURGE = "remote_purge"; public static final String REMOTE_REFRESH_RETRY = "remote_refresh_retry"; + public static final String REMOTE_RECOVERY = "remote_recovery"; public static final String INDEX_SEARCHER = "index_searcher"; } @@ -184,6 +185,7 @@ public static ThreadPoolType fromType(String type) { map.put(Names.TRANSLOG_SYNC, ThreadPoolType.FIXED); map.put(Names.REMOTE_PURGE, ThreadPoolType.SCALING); map.put(Names.REMOTE_REFRESH_RETRY, ThreadPoolType.SCALING); + map.put(Names.REMOTE_RECOVERY, ThreadPoolType.SCALING); if (FeatureFlags.isEnabled(FeatureFlags.CONCURRENT_SEGMENT_SEARCH)) { map.put(Names.INDEX_SEARCHER, ThreadPoolType.RESIZABLE); } @@ -269,6 +271,10 @@ public ThreadPool( Names.REMOTE_REFRESH_RETRY, new ScalingExecutorBuilder(Names.REMOTE_REFRESH_RETRY, 1, halfProcMaxAt10, TimeValue.timeValueMinutes(5)) ); + builders.put( + Names.REMOTE_RECOVERY, + new ScalingExecutorBuilder(Names.REMOTE_RECOVERY, 1, halfProcMaxAt10, TimeValue.timeValueMinutes(5)) + ); if (FeatureFlags.isEnabled(FeatureFlags.CONCURRENT_SEGMENT_SEARCH)) { builders.put( Names.INDEX_SEARCHER, diff --git a/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java b/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java index 99b21fb26ea2c..1780819390052 100644 --- a/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java +++ b/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java @@ -57,7 +57,7 @@ public void testReadBlobAsync() throws Exception { final ListenerTestUtils.CountingCompletionListener completionListener = new ListenerTestUtils.CountingCompletionListener<>(); final CompletableFuture streamContainerFuture = CompletableFuture.completedFuture(inputStreamContainer); - final ReadContext readContext = new ReadContext(size, List.of(streamContainerFuture), null); + final ReadContext readContext = new ReadContext(size, List.of(() -> streamContainerFuture), null); Mockito.doAnswer(invocation -> { ActionListener readContextActionListener = invocation.getArgument(1); @@ -79,7 +79,7 @@ public void testReadBlobAsync() throws Exception { assertEquals(1, response.getNumberOfParts()); assertEquals(size, response.getBlobSize()); - InputStreamContainer responseContainer = response.getPartStreams().get(0).get(); + InputStreamContainer responseContainer = response.getPartStreams().get(0).get().join(); assertEquals(0, responseContainer.getOffset()); assertEquals(size, responseContainer.getContentLength()); assertEquals(100, responseContainer.getInputStream().available()); @@ -103,7 +103,7 @@ public void testReadBlobAsyncException() throws Exception { final ListenerTestUtils.CountingCompletionListener completionListener = new ListenerTestUtils.CountingCompletionListener<>(); final CompletableFuture streamContainerFuture = CompletableFuture.completedFuture(inputStreamContainer); - final ReadContext readContext = new ReadContext(size, List.of(streamContainerFuture), null); + final ReadContext readContext = new ReadContext(size, List.of(() -> streamContainerFuture), null); Mockito.doAnswer(invocation -> { ActionListener readContextActionListener = invocation.getArgument(1); diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListenerTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListenerTests.java index fa13d90f42fa6..f56e97a3b0c7c 100644 --- a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListenerTests.java +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListenerTests.java @@ -10,49 +10,6 @@ import org.opensearch.test.OpenSearchTestCase; -import java.io.IOException; - -import static org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils.CountingCompletionListener; - public class FileCompletionListenerTests extends OpenSearchTestCase { - public void testFileCompletionListener() { - int numStreams = 10; - String fileName = "test_segment_file"; - CountingCompletionListener completionListener = new CountingCompletionListener(); - FileCompletionListener fileCompletionListener = new FileCompletionListener(numStreams, fileName, completionListener); - - for (int stream = 0; stream < numStreams; stream++) { - // Ensure completion listener called only when all streams are completed - assertEquals(0, completionListener.getResponseCount()); - fileCompletionListener.onResponse(null); - } - - assertEquals(1, completionListener.getResponseCount()); - assertEquals(fileName, completionListener.getResponse()); - } - - public void testFileCompletionListenerFailure() { - int numStreams = 10; - String fileName = "test_segment_file"; - CountingCompletionListener completionListener = new CountingCompletionListener(); - FileCompletionListener fileCompletionListener = new FileCompletionListener(numStreams, fileName, completionListener); - - // Fail the listener initially - IOException exception = new IOException(); - fileCompletionListener.onFailure(exception); - - for (int stream = 0; stream < numStreams - 1; stream++) { - assertEquals(0, completionListener.getResponseCount()); - fileCompletionListener.onResponse(null); - } - - assertEquals(1, completionListener.getFailureCount()); - assertEquals(exception, completionListener.getException()); - assertEquals(0, completionListener.getResponseCount()); - - fileCompletionListener.onFailure(exception); - assertEquals(2, completionListener.getFailureCount()); - assertEquals(exception, completionListener.getException()); - } } diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java index a2917feb00b10..f2a758b9bbe10 100644 --- a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java @@ -13,14 +13,11 @@ import org.junit.Before; import java.io.ByteArrayInputStream; -import java.io.IOException; import java.io.InputStream; import java.nio.file.Files; import java.nio.file.Path; import java.util.UUID; -import java.util.concurrent.atomic.AtomicBoolean; - -import static org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils.CountingCompletionListener; +import java.util.function.UnaryOperator; public class FilePartWriterTests extends OpenSearchTestCase { @@ -34,97 +31,37 @@ public void init() throws Exception { public void testFilePartWriter() throws Exception { Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); int contentLength = 100; - int partNumber = 1; InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength)); InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, inputStream.available(), 0); - AtomicBoolean anyStreamFailed = new AtomicBoolean(); - CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); - FilePartWriter filePartWriter = new FilePartWriter(partNumber, segmentFilePath, anyStreamFailed, fileCompletionListener); - filePartWriter.accept(inputStreamContainer, null); + FilePartWriter.write(segmentFilePath, inputStreamContainer, UnaryOperator.identity()); assertTrue(Files.exists(segmentFilePath)); assertEquals(contentLength, Files.size(segmentFilePath)); - assertEquals(1, fileCompletionListener.getResponseCount()); - assertEquals(Integer.valueOf(partNumber), fileCompletionListener.getResponse()); } public void testFilePartWriterWithOffset() throws Exception { Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); int contentLength = 100; int offset = 10; - int partNumber = 1; InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength)); InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, inputStream.available(), offset); - AtomicBoolean anyStreamFailed = new AtomicBoolean(); - CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); - FilePartWriter filePartWriter = new FilePartWriter(partNumber, segmentFilePath, anyStreamFailed, fileCompletionListener); - filePartWriter.accept(inputStreamContainer, null); + FilePartWriter.write(segmentFilePath, inputStreamContainer, UnaryOperator.identity()); assertTrue(Files.exists(segmentFilePath)); assertEquals(contentLength + offset, Files.size(segmentFilePath)); - assertEquals(1, fileCompletionListener.getResponseCount()); - assertEquals(Integer.valueOf(partNumber), fileCompletionListener.getResponse()); } public void testFilePartWriterLargeInput() throws Exception { Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); int contentLength = 20 * 1024 * 1024; - int partNumber = 1; InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength)); InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, contentLength, 0); - AtomicBoolean anyStreamFailed = new AtomicBoolean(); - CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); - FilePartWriter filePartWriter = new FilePartWriter(partNumber, segmentFilePath, anyStreamFailed, fileCompletionListener); - filePartWriter.accept(inputStreamContainer, null); + FilePartWriter.write(segmentFilePath, inputStreamContainer, UnaryOperator.identity()); assertTrue(Files.exists(segmentFilePath)); assertEquals(contentLength, Files.size(segmentFilePath)); - - assertEquals(1, fileCompletionListener.getResponseCount()); - assertEquals(Integer.valueOf(partNumber), fileCompletionListener.getResponse()); - } - - public void testFilePartWriterException() throws Exception { - Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); - int partNumber = 1; - AtomicBoolean anyStreamFailed = new AtomicBoolean(); - CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); - - IOException ioException = new IOException(); - FilePartWriter filePartWriter = new FilePartWriter(partNumber, segmentFilePath, anyStreamFailed, fileCompletionListener); - assertFalse(anyStreamFailed.get()); - filePartWriter.processFailure(ioException); - - assertTrue(anyStreamFailed.get()); - assertFalse(Files.exists(segmentFilePath)); - - // Fail stream again to simulate another stream failure for same file - filePartWriter.processFailure(ioException); - - assertTrue(anyStreamFailed.get()); - assertFalse(Files.exists(segmentFilePath)); - - assertEquals(0, fileCompletionListener.getResponseCount()); - assertEquals(1, fileCompletionListener.getFailureCount()); - assertEquals(ioException, fileCompletionListener.getException()); - } - - public void testFilePartWriterStreamFailed() throws Exception { - Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); - int contentLength = 100; - int partNumber = 1; - InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength)); - InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, inputStream.available(), 0); - AtomicBoolean anyStreamFailed = new AtomicBoolean(true); - CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); - - FilePartWriter filePartWriter = new FilePartWriter(partNumber, segmentFilePath, anyStreamFailed, fileCompletionListener); - filePartWriter.accept(inputStreamContainer, null); - - assertFalse(Files.exists(segmentFilePath)); - assertEquals(0, fileCompletionListener.getResponseCount()); } } diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java index b2ae0d20e7486..74f22abfd7713 100644 --- a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java @@ -31,6 +31,7 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; +import java.util.function.UnaryOperator; import static org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils.CountingCompletionListener; @@ -65,10 +66,16 @@ public void init() throws Exception { public void testReadContextListener() throws InterruptedException, IOException { Path fileLocation = path.resolve(UUID.randomUUID().toString()); - List> blobPartStreams = initializeBlobPartStreams(); + List blobPartStreams = initializeBlobPartStreams(); CountDownLatch countDownLatch = new CountDownLatch(1); ActionListener completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch); - ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, completionListener); + ReadContextListener readContextListener = new ReadContextListener( + TEST_SEGMENT_FILE, + fileLocation, + completionListener, + threadPool, + UnaryOperator.identity() + ); ReadContext readContext = new ReadContext((long) PART_SIZE * NUMBER_OF_PARTS, blobPartStreams, null); readContextListener.onResponse(readContext); @@ -80,10 +87,16 @@ public void testReadContextListener() throws InterruptedException, IOException { public void testReadContextListenerFailure() throws Exception { Path fileLocation = path.resolve(UUID.randomUUID().toString()); - List> blobPartStreams = initializeBlobPartStreams(); + List blobPartStreams = initializeBlobPartStreams(); CountDownLatch countDownLatch = new CountDownLatch(1); ActionListener completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch); - ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, completionListener); + ReadContextListener readContextListener = new ReadContextListener( + TEST_SEGMENT_FILE, + fileLocation, + completionListener, + threadPool, + UnaryOperator.identity() + ); InputStream badInputStream = new InputStream() { @Override @@ -104,7 +117,7 @@ public int available() { blobPartStreams.add( NUMBER_OF_PARTS, - CompletableFuture.supplyAsync( + () -> CompletableFuture.supplyAsync( () -> new InputStreamContainer(badInputStream, PART_SIZE, PART_SIZE * NUMBER_OF_PARTS), threadPool.generic() ) @@ -119,20 +132,26 @@ public int available() { public void testReadContextListenerException() { Path fileLocation = path.resolve(UUID.randomUUID().toString()); CountingCompletionListener listener = new CountingCompletionListener(); - ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, listener); + ReadContextListener readContextListener = new ReadContextListener( + TEST_SEGMENT_FILE, + fileLocation, + listener, + threadPool, + UnaryOperator.identity() + ); IOException exception = new IOException(); readContextListener.onFailure(exception); assertEquals(1, listener.getFailureCount()); assertEquals(exception, listener.getException()); } - private List> initializeBlobPartStreams() { - List> blobPartStreams = new ArrayList<>(); + private List initializeBlobPartStreams() { + List blobPartStreams = new ArrayList<>(); for (int partNumber = 0; partNumber < NUMBER_OF_PARTS; partNumber++) { InputStream testStream = new ByteArrayInputStream(randomByteArrayOfLength(PART_SIZE)); int finalPartNumber = partNumber; blobPartStreams.add( - CompletableFuture.supplyAsync( + () -> CompletableFuture.supplyAsync( () -> new InputStreamContainer(testStream, PART_SIZE, (long) finalPartNumber * PART_SIZE), threadPool.generic() ) diff --git a/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryTests.java b/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryTests.java index e5666c58b039f..77ebcea3dfe32 100644 --- a/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryTests.java +++ b/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryTests.java @@ -25,12 +25,13 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.UUIDs; import org.opensearch.common.blobstore.AsyncMultiStreamBlobContainer; +import org.opensearch.common.blobstore.stream.read.ReadContext; import org.opensearch.common.blobstore.stream.write.WriteContext; +import org.opensearch.common.io.InputStreamContainer; import org.opensearch.common.io.VersionedCodecStreamWrapper; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.lucene.store.ByteArrayIndexInput; import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.FeatureFlags; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; @@ -56,9 +57,11 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; +import java.util.function.UnaryOperator; import org.mockito.Mockito; @@ -151,7 +154,9 @@ public void setup() throws IOException { segmentInfos = store.readLastCommittedSegmentsInfo(); } + when(remoteDataDirectory.getDownloadRateLimiter()).thenReturn(UnaryOperator.identity()); when(threadPool.executor(ThreadPool.Names.REMOTE_PURGE)).thenReturn(executorService); + when(threadPool.executor(ThreadPool.Names.REMOTE_RECOVERY)).thenReturn(executorService); } @After @@ -562,9 +567,6 @@ public void onFailure(Exception e) {} } public void testCopyFilesToMultipart() throws Exception { - Settings settings = Settings.builder().build(); - FeatureFlags.initializeFeatureFlags(settings); - String filename = "_0.cfe"; populateMetadata(); remoteSegmentStoreDirectory.init(); @@ -574,13 +576,15 @@ public void testCopyFilesToMultipart() throws Exception { when(remoteDataDirectory.getBlobContainer()).thenReturn(blobContainer); Mockito.doAnswer(invocation -> { - ActionListener completionListener = invocation.getArgument(2); - completionListener.onResponse(invocation.getArgument(0)); + ActionListener completionListener = invocation.getArgument(1); + final CompletableFuture future = new CompletableFuture<>(); + future.complete(new InputStreamContainer(new ByteArrayInputStream(new byte[] { 42 }), 0, 1)); + completionListener.onResponse(new ReadContext(1, List.of(() -> future), "")); return null; - }).when(blobContainer).asyncBlobDownload(any(), any(), any()); + }).when(blobContainer).readBlobAsync(any(), any()); CountDownLatch downloadLatch = new CountDownLatch(1); - ActionListener completionListener = new ActionListener() { + ActionListener completionListener = new ActionListener<>() { @Override public void onResponse(String unused) { downloadLatch.countDown(); @@ -592,7 +596,7 @@ public void onFailure(Exception e) {} Path path = createTempDir(); remoteSegmentStoreDirectory.copyTo(filename, storeDirectory, path, completionListener); assertTrue(downloadLatch.await(5000, TimeUnit.SECONDS)); - verify(blobContainer, times(1)).asyncBlobDownload(contains(filename), eq(path.resolve(filename)), any()); + verify(blobContainer, times(1)).readBlobAsync(contains(filename), any()); verify(storeDirectory, times(0)).copyFrom(any(), any(), any(), any()); }