Skip to content

Commit

Permalink
Refactor read context streams to async streams
Browse files Browse the repository at this point in the history
Signed-off-by: Kunal Kotwani <[email protected]>
  • Loading branch information
kotwanikunal authored and andrross committed Oct 3, 2023
1 parent bddf0d3 commit 2c51a10
Show file tree
Hide file tree
Showing 15 changed files with 94 additions and 137 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -254,24 +254,7 @@ public void readBlobAsync(String blobName, ActionListener<ReadContext> listener)
blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, partNumber));
}
}

CompletableFuture.allOf(blobPartInputStreamFutures.toArray(CompletableFuture[]::new))
.whenComplete((unused, partThrowable) -> {
if (partThrowable == null) {
listener.onResponse(
new ReadContext(
blobSize,
blobPartInputStreamFutures.stream().map(CompletableFuture::join).collect(Collectors.toList()),
blobChecksum
)
);
} else {
Exception ex = partThrowable.getCause() instanceof Exception
? (Exception) partThrowable.getCause()
: new Exception(partThrowable.getCause());
listener.onFailure(ex);
}
});
listener.onResponse(new ReadContext(blobSize, blobPartInputStreamFutures, blobChecksum));
});
} catch (Exception ex) {
listener.onFailure(SdkException.create("Error occurred while fetching blob parts from the repository", ex));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,7 @@ public void testReadBlobAsyncMultiPart() throws Exception {
assertEquals(objectSize, readContext.getBlobSize());

for (int partNumber = 1; partNumber < objectPartCount; partNumber++) {
InputStreamContainer inputStreamContainer = readContext.getPartStreams().get(partNumber);
InputStreamContainer inputStreamContainer = readContext.getPartStreams().get(partNumber).get();
final int offset = partNumber * partSize;
assertEquals(partSize, inputStreamContainer.getContentLength());
assertEquals(offset, inputStreamContainer.getOffset());
Expand Down Expand Up @@ -1024,7 +1024,7 @@ public void testReadBlobAsyncSinglePart() throws Exception {
assertEquals(checksum, readContext.getBlobChecksum());
assertEquals(objectSize, readContext.getBlobSize());

InputStreamContainer inputStreamContainer = readContext.getPartStreams().stream().findFirst().get();
InputStreamContainer inputStreamContainer = readContext.getPartStreams().stream().findFirst().get().get();
assertEquals(objectSize, inputStreamContainer.getContentLength());
assertEquals(0, inputStreamContainer.getOffset());
assertEquals(objectSize, inputStreamContainer.getInputStream().readAllBytes().length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
Expand Down Expand Up @@ -124,11 +125,11 @@ public void readBlobAsync(String blobName, ActionListener<ReadContext> listener)
long contentLength = listBlobs().get(blobName).length();
long partSize = contentLength / 10;
int numberOfParts = (int) ((contentLength % partSize) == 0 ? contentLength / partSize : (contentLength / partSize) + 1);
List<InputStreamContainer> blobPartStreams = new ArrayList<>();
List<CompletableFuture<InputStreamContainer>> blobPartStreams = new ArrayList<>();
for (int partNumber = 0; partNumber < numberOfParts; partNumber++) {
long offset = partNumber * partSize;
InputStreamContainer blobPartStream = new InputStreamContainer(readBlob(blobName, offset, partSize), partSize, offset);
blobPartStreams.add(blobPartStream);
blobPartStreams.add(CompletableFuture.completedFuture(blobPartStream));
}
ReadContext blobReadContext = new ReadContext(contentLength, blobPartStreams, null);
listener.onResponse(blobReadContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.opensearch.common.blobstore.stream.read.listener.ReadContextListener;
import org.opensearch.common.blobstore.stream.write.WriteContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
import java.nio.file.Path;
Expand Down Expand Up @@ -49,12 +48,11 @@ public interface AsyncMultiStreamBlobContainer extends BlobContainer {
* Asynchronously downloads the blob to the specified location using an executor from the thread pool.
* @param blobName The name of the blob for which needs to be downloaded.
* @param fileLocation The path on local disk where the blob needs to be downloaded.
* @param threadPool The threadpool instance which will provide the executor for performing a multipart download.
* @param completionListener Listener which will be notified when the download is complete.
*/
@ExperimentalApi
default void asyncBlobDownload(String blobName, Path fileLocation, ThreadPool threadPool, ActionListener<String> completionListener) {
ReadContextListener readContextListener = new ReadContextListener(blobName, fileLocation, threadPool, completionListener);
default void asyncBlobDownload(String blobName, Path fileLocation, ActionListener<String> completionListener) {
ReadContextListener readContextListener = new ReadContextListener(blobName, fileLocation, completionListener);
readBlobAsync(blobName, readContextListener);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

/**
Expand Down Expand Up @@ -144,8 +145,10 @@ public long getBlobSize() {
}

@Override
public List<InputStreamContainer> getPartStreams() {
return super.getPartStreams().stream().map(this::decryptInputStreamContainer).collect(Collectors.toList());
public List<CompletableFuture<InputStreamContainer>> getPartStreams() {
return super.getPartStreams().stream()
.map(cf -> cf.thenApply(this::decryptInputStreamContainer))
.collect(Collectors.toUnmodifiableList());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,26 @@
import org.opensearch.common.io.InputStreamContainer;

import java.util.List;
import java.util.concurrent.CompletableFuture;

/**
* ReadContext is used to encapsulate all data needed by <code>BlobContainer#readBlobAsync</code>
*/
@ExperimentalApi
public class ReadContext {
private final long blobSize;
private final List<InputStreamContainer> partStreams;
private final List<CompletableFuture<InputStreamContainer>> asyncPartStreams;
private final String blobChecksum;

public ReadContext(long blobSize, List<InputStreamContainer> partStreams, String blobChecksum) {
public ReadContext(long blobSize, List<CompletableFuture<InputStreamContainer>> asyncPartStreams, String blobChecksum) {
this.blobSize = blobSize;
this.partStreams = partStreams;
this.asyncPartStreams = asyncPartStreams;
this.blobChecksum = blobChecksum;
}

public ReadContext(ReadContext readContext) {
this.blobSize = readContext.blobSize;
this.partStreams = readContext.partStreams;
this.asyncPartStreams = readContext.asyncPartStreams;
this.blobChecksum = readContext.blobChecksum;
}

Expand All @@ -39,14 +40,14 @@ public String getBlobChecksum() {
}

public int getNumberOfParts() {
return partStreams.size();
return asyncPartStreams.size();
}

public long getBlobSize() {
return blobSize;
}

public List<InputStreamContainer> getPartStreams() {
return partStreams;
public List<CompletableFuture<InputStreamContainer>> getPartStreams() {
return asyncPartStreams;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;

/**
* FilePartWriter transfers the provided stream into the specified file path using a {@link FileChannel}
* instance. It performs offset based writes to the file and notifies the {@link FileCompletionListener} on completion.
*/
@InternalApi
class FilePartWriter implements Runnable {
class FilePartWriter implements BiConsumer<InputStreamContainer, Throwable> {

private final int partNumber;
private final InputStreamContainer blobPartStreamContainer;
private final Path fileLocation;
private final AtomicBoolean anyPartStreamFailed;
private final ActionListener<Integer> fileCompletionListener;
Expand All @@ -42,20 +42,26 @@ class FilePartWriter implements Runnable {

public FilePartWriter(
int partNumber,
InputStreamContainer blobPartStreamContainer,
Path fileLocation,
AtomicBoolean anyPartStreamFailed,
ActionListener<Integer> fileCompletionListener
) {
this.partNumber = partNumber;
this.blobPartStreamContainer = blobPartStreamContainer;
this.fileLocation = fileLocation;
this.anyPartStreamFailed = anyPartStreamFailed;
this.fileCompletionListener = fileCompletionListener;
}

@Override
public void run() {
public void accept(InputStreamContainer blobPartStreamContainer, Throwable throwable) {
if (throwable != null) {
if (throwable instanceof Exception) {
processFailure((Exception) throwable);
} else {
processFailure(new Exception(throwable));
}
return;
}
// Ensures no writes to the file if any stream fails.
if (anyPartStreamFailed.get() == false) {
try (FileChannel outputFileChannel = FileChannel.open(fileLocation, StandardOpenOption.WRITE, StandardOpenOption.CREATE)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,25 @@
import org.opensearch.common.annotation.InternalApi;
import org.opensearch.common.blobstore.stream.read.ReadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.threadpool.ThreadPool;

import java.nio.file.Path;
import java.util.concurrent.atomic.AtomicBoolean;

/**
* ReadContextListener orchestrates the async file fetch from the {@link org.opensearch.common.blobstore.BlobContainer}
* using a {@link ReadContext} callback. On response, it spawns off the download using multiple streams which are
* spread across a {@link ThreadPool} executor.
* using a {@link ReadContext} callback. On response, it spawns off the download using multiple streams.
*/
@InternalApi
public class ReadContextListener implements ActionListener<ReadContext> {

private final String fileName;
private final Path fileLocation;
private final ThreadPool threadPool;
private final ActionListener<String> completionListener;
private static final Logger logger = LogManager.getLogger(ReadContextListener.class);

public ReadContextListener(String fileName, Path fileLocation, ThreadPool threadPool, ActionListener<String> completionListener) {
public ReadContextListener(String fileName, Path fileLocation, ActionListener<String> completionListener) {
this.fileName = fileName;
this.fileLocation = fileLocation;
this.threadPool = threadPool;
this.completionListener = completionListener;
}

Expand All @@ -47,14 +43,9 @@ public void onResponse(ReadContext readContext) {
FileCompletionListener fileCompletionListener = new FileCompletionListener(numParts, fileName, completionListener);

for (int partNumber = 0; partNumber < numParts; partNumber++) {
FilePartWriter filePartWriter = new FilePartWriter(
partNumber,
readContext.getPartStreams().get(partNumber),
fileLocation,
anyPartStreamFailed,
fileCompletionListener
);
threadPool.executor(ThreadPool.Names.GENERIC).submit(filePartWriter);
readContext.getPartStreams()
.get(partNumber)
.whenComplete(new FilePartWriter(partNumber, fileLocation, anyPartStreamFailed, fileCompletionListener));
}
}

Expand Down
24 changes: 8 additions & 16 deletions server/src/main/java/org/opensearch/index/shard/IndexShard.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
import org.opensearch.action.admin.indices.flush.FlushRequest;
import org.opensearch.action.admin.indices.forcemerge.ForceMergeRequest;
import org.opensearch.action.admin.indices.upgrade.post.UpgradeRequest;
import org.opensearch.action.support.GroupedActionListener;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.action.support.replication.PendingReplicationActions;
import org.opensearch.action.support.replication.ReplicationResponse;
Expand Down Expand Up @@ -4916,24 +4915,17 @@ private void downloadSegments(
RemoteSegmentStoreDirectory targetRemoteDirectory,
Set<String> toDownloadSegments,
final Runnable onFileSync
) {
final PlainActionFuture<Void> completionListener = PlainActionFuture.newFuture();
final GroupedActionListener<Void> batchDownloadListener = new GroupedActionListener<>(
ActionListener.map(completionListener, v -> null),
toDownloadSegments.size()
);

final ActionListener<String> segmentsDownloadListener = ActionListener.map(batchDownloadListener, fileName -> {
) throws IOException {
final Path indexPath = store.shardPath() == null ? null : store.shardPath().resolveIndex();
for (String segment : toDownloadSegments) {
final PlainActionFuture<String> segmentListener = PlainActionFuture.newFuture();
sourceRemoteDirectory.copyTo(segment, storeDirectory, indexPath, segmentListener);
segmentListener.actionGet();
onFileSync.run();
if (targetRemoteDirectory != null) {
targetRemoteDirectory.copyFrom(storeDirectory, fileName, fileName, IOContext.DEFAULT);
targetRemoteDirectory.copyFrom(storeDirectory, segment, segment, IOContext.DEFAULT);
}
return null;
});

final Path indexPath = store.shardPath() == null ? null : store.shardPath().resolveIndex();
toDownloadSegments.forEach(file -> { sourceRemoteDirectory.copyTo(file, storeDirectory, indexPath, segmentsDownloadListener); });
completionListener.actionGet();
}
}

private boolean localDirectoryContains(Directory localDirectory, String file, long checksum) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ public void copyTo(String source, Directory destinationDirectory, Path destinati
if (destinationPath != null && remoteDataDirectory.getBlobContainer() instanceof AsyncMultiStreamBlobContainer) {
final AsyncMultiStreamBlobContainer blobContainer = (AsyncMultiStreamBlobContainer) remoteDataDirectory.getBlobContainer();
final Path destinationFilePath = destinationPath.resolve(source);
blobContainer.asyncBlobDownload(blobName, destinationFilePath, threadPool, fileCompletionListener);
blobContainer.asyncBlobDownload(blobName, destinationFilePath, fileCompletionListener);
} else {
// Fallback to older mechanism of downloading the file
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FilterDirectory;
import org.apache.lucene.util.Version;
import org.opensearch.action.support.GroupedActionListener;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.concurrent.GatedCloseable;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.shard.IndexShard;
Expand Down Expand Up @@ -141,14 +141,12 @@ private void downloadSegments(
ActionListener<GetSegmentFilesResponse> completionListener
) {
final Path indexPath = shardPath == null ? null : shardPath.resolveIndex();
final GroupedActionListener<Void> batchDownloadListener = new GroupedActionListener<>(
ActionListener.map(completionListener, v -> new GetSegmentFilesResponse(toDownloadSegments)),
toDownloadSegments.size()
);
ActionListener<String> segmentsDownloadListener = ActionListener.map(batchDownloadListener, result -> null);
toDownloadSegments.forEach(
fileMetadata -> remoteStoreDirectory.copyTo(fileMetadata.name(), storeDirectory, indexPath, segmentsDownloadListener)
);
for (StoreFileMetadata storeFileMetadata : toDownloadSegments) {
final PlainActionFuture<String> segmentListener = PlainActionFuture.newFuture();
remoteStoreDirectory.copyTo(storeFileMetadata.name(), storeDirectory, indexPath, segmentListener);
segmentListener.actionGet();
}
completionListener.onResponse(new GetSegmentFilesResponse(toDownloadSegments));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.UnaryOperator;

import org.mockito.Mockito;
Expand Down Expand Up @@ -51,10 +52,12 @@ public void testReadBlobAsync() throws Exception {
// Objects needed for API call
final byte[] data = new byte[size];
Randomness.get().nextBytes(data);

final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0);
final ListenerTestUtils.CountingCompletionListener<ReadContext> completionListener =
new ListenerTestUtils.CountingCompletionListener<>();
final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null);
final CompletableFuture<InputStreamContainer> streamContainerFuture = CompletableFuture.completedFuture(inputStreamContainer);
final ReadContext readContext = new ReadContext(size, List.of(streamContainerFuture), null);

Mockito.doAnswer(invocation -> {
ActionListener<ReadContext> readContextActionListener = invocation.getArgument(1);
Expand All @@ -76,7 +79,7 @@ public void testReadBlobAsync() throws Exception {
assertEquals(1, response.getNumberOfParts());
assertEquals(size, response.getBlobSize());

InputStreamContainer responseContainer = response.getPartStreams().get(0);
InputStreamContainer responseContainer = response.getPartStreams().get(0).get();
assertEquals(0, responseContainer.getOffset());
assertEquals(size, responseContainer.getContentLength());
assertEquals(100, responseContainer.getInputStream().available());
Expand All @@ -99,7 +102,8 @@ public void testReadBlobAsyncException() throws Exception {
final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0);
final ListenerTestUtils.CountingCompletionListener<ReadContext> completionListener =
new ListenerTestUtils.CountingCompletionListener<>();
final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null);
final CompletableFuture<InputStreamContainer> streamContainerFuture = CompletableFuture.completedFuture(inputStreamContainer);
final ReadContext readContext = new ReadContext(size, List.of(streamContainerFuture), null);

Mockito.doAnswer(invocation -> {
ActionListener<ReadContext> readContextActionListener = invocation.getArgument(1);
Expand Down
Loading

0 comments on commit 2c51a10

Please sign in to comment.