Skip to content

Commit

Permalink
Staging in parallel (#157)
Browse files Browse the repository at this point in the history
* Staging in parallel

Signed-off-by: Hongxin Liang <[email protected]>
  • Loading branch information
honnix authored Dec 12, 2022
1 parent eb7b497 commit 072aa7b
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 22 deletions.
54 changes: 34 additions & 20 deletions jflyte/src/main/java/org/flyte/jflyte/ArtifactStager.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,13 @@
import java.net.URISyntaxException;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import org.flyte.jflyte.api.FileSystem;
import org.flyte.jflyte.api.Manifest;
import org.slf4j.Logger;
Expand All @@ -58,13 +61,16 @@ class ArtifactStager {

private final String stagingLocation;
private final FileSystem fileSystem;
private final ExecutorService executorService;

ArtifactStager(String stagingLocation, FileSystem fileSystem) {
ArtifactStager(String stagingLocation, FileSystem fileSystem, ExecutorService executorService) {
this.stagingLocation = stagingLocation;
this.fileSystem = fileSystem;
this.executorService = executorService;
}

static ArtifactStager create(Config config, Collection<ClassLoader> modules) {
static ArtifactStager create(
Config config, Collection<ClassLoader> modules, ExecutorService executorService) {
try {
String stagingLocation = config.stagingLocation();

Expand All @@ -77,36 +83,43 @@ static ArtifactStager create(Config config, Collection<ClassLoader> modules) {
Map<String, FileSystem> fileSystems = FileSystemLoader.loadFileSystems(modules);
FileSystem stagingFileSystem = FileSystemLoader.getFileSystem(fileSystems, stagingUri);

return new ArtifactStager(stagingLocation, stagingFileSystem);
return new ArtifactStager(stagingLocation, stagingFileSystem, executorService);
} catch (URISyntaxException e) {
throw new IllegalArgumentException("Failed to parse stagingLocation", e);
}
}

List<Artifact> stageFiles(List<String> files) {
List<Artifact> artifacts = new ArrayList<>();
List<Artifact> stageFiles(List<String> filePaths) {
List<File> files =
filePaths.stream().map(ArtifactStager::toFileAndVerify).collect(Collectors.toList());
List<CompletionStage<Artifact>> stages =
files.stream().map(this::getArtifactForFile).collect(Collectors.toList());

// TODO use multiple threads for better throughput
for (String filePath : files) {
File file = new File(filePath);

verify(file.exists(), "file doesn't exist [%s]", filePath);
verify(!file.isDirectory(), "directories aren't supported [%s]", filePath);

Artifact artifact = getArtifactForFile(file, stagingLocation);
stageArtifact(artifact, Files.asByteSource(file));
return CompletableFutures.getAll(stages);
}

artifacts.add(artifact);
}
private static File toFileAndVerify(String filePath) {
File file = new File(filePath);
verify(file.exists(), "file doesn't exist [%s]", filePath);
verify(!file.isDirectory(), "directories aren't supported [%s]", filePath);
return file;
}

return artifacts;
private CompletionStage<Artifact> getArtifactForFile(File file) {
return CompletableFuture.supplyAsync(
() -> {
Artifact artifact = getArtifactForFile(file, stagingLocation);
stageArtifact(artifact, Files.asByteSource(file));
return artifact;
},
executorService);
}

void stageArtifact(Artifact artifact, ByteSource content) {
LOG.info("Staging [{}] to [{}]", artifact.name(), artifact.location());

Manifest manifest = fileSystem.getManifest(artifact.location());
if (manifest == null) {
LOG.info("Staging [{}] to [{}]", artifact.name(), artifact.location());

// TODO writer API should accept crc32c as an option to pass it to underlying implementation
// that is going to double-check it once blob is uploaded

Expand All @@ -116,6 +129,7 @@ void stageArtifact(Artifact artifact, ByteSource content) {
throw new UncheckedIOException(e);
}
} else {
LOG.info("[{}] already staged to [{}]", artifact.name(), artifact.location());
// TODO check that crc32c matches
}
}
Expand Down
46 changes: 46 additions & 0 deletions jflyte/src/main/java/org/flyte/jflyte/CompletableFutures.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright 2021 Flyte Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.jflyte;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutionException;

class CompletableFutures {

static <T> List<T> getAll(List<CompletionStage<T>> stages) {
List<T> result = new ArrayList<>(stages.size());

for (int i = 0; i < stages.size(); ++i) {
try {
result.add(stages.get(i).toCompletableFuture().get());
} catch (InterruptedException | ExecutionException e) {
if (e instanceof InterruptedException) {
Thread.currentThread().interrupt();
}
for (int j = i; j < stages.size(); ++j) {
stages.get(j).toCompletableFuture().cancel(true);
}

throw new RuntimeException(e);
}
}

return result;
}
}
4 changes: 3 additions & 1 deletion jflyte/src/main/java/org/flyte/jflyte/RegisterWorkflows.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.util.Collection;
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinPool;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.flyte.jflyte.api.TokenSource;
Expand Down Expand Up @@ -73,7 +74,8 @@ public Integer call() {

try (FlyteAdminClient adminClient =
FlyteAdminClient.create(config.platformUrl(), config.platformInsecure(), tokenSource)) {
Supplier<ArtifactStager> stagerSupplier = () -> ArtifactStager.create(config, modules);
Supplier<ArtifactStager> stagerSupplier =
() -> ArtifactStager.create(config, modules, new ForkJoinPool());

ExecutionConfig executionConfig =
ExecutionConfig.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.io.UncheckedIOException;
import java.util.Collection;
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinPool;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -69,7 +70,8 @@ public Integer call() {

try (FlyteAdminClient adminClient =
FlyteAdminClient.create(config.platformUrl(), config.platformInsecure(), tokenSource)) {
Supplier<ArtifactStager> stagerSupplier = () -> ArtifactStager.create(config, modules);
Supplier<ArtifactStager> stagerSupplier =
() -> ArtifactStager.create(config, modules, new ForkJoinPool());
ExecutionConfig executionConfig =
ExecutionConfig.builder()
.domain(DOMAIN_PLACEHOLDER)
Expand Down
91 changes: 91 additions & 0 deletions jflyte/src/test/java/org/flyte/jflyte/CompletableFuturesTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright 2021 Flyte Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.jflyte;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutionException;
import org.junit.jupiter.api.Test;

class CompletableFuturesTest {

@Test
void testGetAll() {
List<CompletionStage<String>> stages = new ArrayList<>();
stages.add(CompletableFuture.completedFuture("foo"));
stages.add(CompletableFuture.completedFuture("bar"));

assertThat(CompletableFutures.getAll(stages), contains("foo", "bar"));
}

@Test
void testGetAllCancelled() {
RuntimeException expectedException = new RuntimeException();
CompletableFuture<String> failedFuture =
spy(
CompletableFuture.supplyAsync(
() -> {
throw expectedException;
}));
CompletableFuture<String> shouldBeCancelledFuture =
spy(CompletableFuture.supplyAsync(() -> "foo"));
List<CompletionStage<String>> stages = new ArrayList<>();
stages.add(failedFuture);
stages.add(shouldBeCancelledFuture);

RuntimeException exception =
assertThrows(RuntimeException.class, () -> CompletableFutures.getAll(stages));
assertThat(exception.getCause().getCause(), is(expectedException));
verify(failedFuture).cancel(true);
verify(shouldBeCancelledFuture).cancel(true);
}

@SuppressWarnings("unchecked")
@Test
void testGetAllInterruptedAndCancelled() throws ExecutionException, InterruptedException {
InterruptedException expectedException = new InterruptedException();
CompletableFuture<String> failedFuture = mock(CompletableFuture.class);
CompletionStage<String> failedStage = mock(CompletionStage.class);
when(failedFuture.get()).thenThrow(expectedException);
when(failedStage.toCompletableFuture()).thenReturn(failedFuture);

CompletableFuture<String> shouldBeCancelledFuture =
spy(CompletableFuture.supplyAsync(() -> "foo"));
List<CompletionStage<String>> stages = new ArrayList<>();
stages.add(failedStage);
stages.add(shouldBeCancelledFuture);

RuntimeException exception =
assertThrows(RuntimeException.class, () -> CompletableFutures.getAll(stages));
assertThat(exception.getCause(), is(expectedException));
verify(failedFuture).cancel(true);
verify(shouldBeCancelledFuture).cancel(true);

assertThat(Thread.currentThread().isInterrupted(), is(true));
}
}

0 comments on commit 072aa7b

Please sign in to comment.