diff --git a/jflyte/src/main/java/org/flyte/jflyte/ArtifactStager.java b/jflyte/src/main/java/org/flyte/jflyte/ArtifactStager.java index 5f7299f4e..73434ee4b 100644 --- a/jflyte/src/main/java/org/flyte/jflyte/ArtifactStager.java +++ b/jflyte/src/main/java/org/flyte/jflyte/ArtifactStager.java @@ -33,10 +33,13 @@ import java.net.URISyntaxException; import java.nio.channels.Channels; import java.nio.channels.WritableByteChannel; -import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.ExecutorService; +import java.util.stream.Collectors; import org.flyte.jflyte.api.FileSystem; import org.flyte.jflyte.api.Manifest; import org.slf4j.Logger; @@ -58,13 +61,16 @@ class ArtifactStager { private final String stagingLocation; private final FileSystem fileSystem; + private final ExecutorService executorService; - ArtifactStager(String stagingLocation, FileSystem fileSystem) { + ArtifactStager(String stagingLocation, FileSystem fileSystem, ExecutorService executorService) { this.stagingLocation = stagingLocation; this.fileSystem = fileSystem; + this.executorService = executorService; } - static ArtifactStager create(Config config, Collection modules) { + static ArtifactStager create( + Config config, Collection modules, ExecutorService executorService) { try { String stagingLocation = config.stagingLocation(); @@ -77,36 +83,43 @@ static ArtifactStager create(Config config, Collection modules) { Map fileSystems = FileSystemLoader.loadFileSystems(modules); FileSystem stagingFileSystem = FileSystemLoader.getFileSystem(fileSystems, stagingUri); - return new ArtifactStager(stagingLocation, stagingFileSystem); + return new ArtifactStager(stagingLocation, stagingFileSystem, executorService); } catch (URISyntaxException e) { throw new IllegalArgumentException("Failed to parse stagingLocation", e); } } - List stageFiles(List files) { - List artifacts = new ArrayList<>(); + List stageFiles(List filePaths) { + List files = + filePaths.stream().map(ArtifactStager::toFileAndVerify).collect(Collectors.toList()); + List> stages = + files.stream().map(this::getArtifactForFile).collect(Collectors.toList()); - // TODO use multiple threads for better throughput - for (String filePath : files) { - File file = new File(filePath); - - verify(file.exists(), "file doesn't exist [%s]", filePath); - verify(!file.isDirectory(), "directories aren't supported [%s]", filePath); - - Artifact artifact = getArtifactForFile(file, stagingLocation); - stageArtifact(artifact, Files.asByteSource(file)); + return CompletableFutures.getAll(stages); + } - artifacts.add(artifact); - } + private static File toFileAndVerify(String filePath) { + File file = new File(filePath); + verify(file.exists(), "file doesn't exist [%s]", filePath); + verify(!file.isDirectory(), "directories aren't supported [%s]", filePath); + return file; + } - return artifacts; + private CompletionStage getArtifactForFile(File file) { + return CompletableFuture.supplyAsync( + () -> { + Artifact artifact = getArtifactForFile(file, stagingLocation); + stageArtifact(artifact, Files.asByteSource(file)); + return artifact; + }, + executorService); } void stageArtifact(Artifact artifact, ByteSource content) { - LOG.info("Staging [{}] to [{}]", artifact.name(), artifact.location()); - Manifest manifest = fileSystem.getManifest(artifact.location()); if (manifest == null) { + LOG.info("Staging [{}] to [{}]", artifact.name(), artifact.location()); + // TODO writer API should accept crc32c as an option to pass it to underlying implementation // that is going to double-check it once blob is uploaded @@ -116,6 +129,7 @@ void stageArtifact(Artifact artifact, ByteSource content) { throw new UncheckedIOException(e); } } else { + LOG.info("[{}] already staged to [{}]", artifact.name(), artifact.location()); // TODO check that crc32c matches } } diff --git a/jflyte/src/main/java/org/flyte/jflyte/CompletableFutures.java b/jflyte/src/main/java/org/flyte/jflyte/CompletableFutures.java new file mode 100644 index 000000000..ccaab795e --- /dev/null +++ b/jflyte/src/main/java/org/flyte/jflyte/CompletableFutures.java @@ -0,0 +1,46 @@ +/* + * Copyright 2021 Flyte Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.jflyte; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.ExecutionException; + +class CompletableFutures { + + static List getAll(List> stages) { + List result = new ArrayList<>(stages.size()); + + for (int i = 0; i < stages.size(); ++i) { + try { + result.add(stages.get(i).toCompletableFuture().get()); + } catch (InterruptedException | ExecutionException e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + for (int j = i; j < stages.size(); ++j) { + stages.get(j).toCompletableFuture().cancel(true); + } + + throw new RuntimeException(e); + } + } + + return result; + } +} diff --git a/jflyte/src/main/java/org/flyte/jflyte/RegisterWorkflows.java b/jflyte/src/main/java/org/flyte/jflyte/RegisterWorkflows.java index c65f1906d..d124ee0cc 100644 --- a/jflyte/src/main/java/org/flyte/jflyte/RegisterWorkflows.java +++ b/jflyte/src/main/java/org/flyte/jflyte/RegisterWorkflows.java @@ -20,6 +20,7 @@ import java.util.Collection; import java.util.concurrent.Callable; +import java.util.concurrent.ForkJoinPool; import java.util.function.Supplier; import javax.annotation.Nullable; import org.flyte.jflyte.api.TokenSource; @@ -73,7 +74,8 @@ public Integer call() { try (FlyteAdminClient adminClient = FlyteAdminClient.create(config.platformUrl(), config.platformInsecure(), tokenSource)) { - Supplier stagerSupplier = () -> ArtifactStager.create(config, modules); + Supplier stagerSupplier = + () -> ArtifactStager.create(config, modules, new ForkJoinPool()); ExecutionConfig executionConfig = ExecutionConfig.builder() diff --git a/jflyte/src/main/java/org/flyte/jflyte/SerializeWorkflows.java b/jflyte/src/main/java/org/flyte/jflyte/SerializeWorkflows.java index 4c60db80f..a743fbb65 100644 --- a/jflyte/src/main/java/org/flyte/jflyte/SerializeWorkflows.java +++ b/jflyte/src/main/java/org/flyte/jflyte/SerializeWorkflows.java @@ -26,6 +26,7 @@ import java.io.UncheckedIOException; import java.util.Collection; import java.util.concurrent.Callable; +import java.util.concurrent.ForkJoinPool; import java.util.function.BiConsumer; import java.util.function.Supplier; import javax.annotation.Nullable; @@ -69,7 +70,8 @@ public Integer call() { try (FlyteAdminClient adminClient = FlyteAdminClient.create(config.platformUrl(), config.platformInsecure(), tokenSource)) { - Supplier stagerSupplier = () -> ArtifactStager.create(config, modules); + Supplier stagerSupplier = + () -> ArtifactStager.create(config, modules, new ForkJoinPool()); ExecutionConfig executionConfig = ExecutionConfig.builder() .domain(DOMAIN_PLACEHOLDER) diff --git a/jflyte/src/test/java/org/flyte/jflyte/CompletableFuturesTest.java b/jflyte/src/test/java/org/flyte/jflyte/CompletableFuturesTest.java new file mode 100644 index 000000000..bb660cf7f --- /dev/null +++ b/jflyte/src/test/java/org/flyte/jflyte/CompletableFuturesTest.java @@ -0,0 +1,91 @@ +/* + * Copyright 2021 Flyte Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.jflyte; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.is; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.ExecutionException; +import org.junit.jupiter.api.Test; + +class CompletableFuturesTest { + + @Test + void testGetAll() { + List> stages = new ArrayList<>(); + stages.add(CompletableFuture.completedFuture("foo")); + stages.add(CompletableFuture.completedFuture("bar")); + + assertThat(CompletableFutures.getAll(stages), contains("foo", "bar")); + } + + @Test + void testGetAllCancelled() { + RuntimeException expectedException = new RuntimeException(); + CompletableFuture failedFuture = + spy( + CompletableFuture.supplyAsync( + () -> { + throw expectedException; + })); + CompletableFuture shouldBeCancelledFuture = + spy(CompletableFuture.supplyAsync(() -> "foo")); + List> stages = new ArrayList<>(); + stages.add(failedFuture); + stages.add(shouldBeCancelledFuture); + + RuntimeException exception = + assertThrows(RuntimeException.class, () -> CompletableFutures.getAll(stages)); + assertThat(exception.getCause().getCause(), is(expectedException)); + verify(failedFuture).cancel(true); + verify(shouldBeCancelledFuture).cancel(true); + } + + @SuppressWarnings("unchecked") + @Test + void testGetAllInterruptedAndCancelled() throws ExecutionException, InterruptedException { + InterruptedException expectedException = new InterruptedException(); + CompletableFuture failedFuture = mock(CompletableFuture.class); + CompletionStage failedStage = mock(CompletionStage.class); + when(failedFuture.get()).thenThrow(expectedException); + when(failedStage.toCompletableFuture()).thenReturn(failedFuture); + + CompletableFuture shouldBeCancelledFuture = + spy(CompletableFuture.supplyAsync(() -> "foo")); + List> stages = new ArrayList<>(); + stages.add(failedStage); + stages.add(shouldBeCancelledFuture); + + RuntimeException exception = + assertThrows(RuntimeException.class, () -> CompletableFutures.getAll(stages)); + assertThat(exception.getCause(), is(expectedException)); + verify(failedFuture).cancel(true); + verify(shouldBeCancelledFuture).cancel(true); + + assertThat(Thread.currentThread().isInterrupted(), is(true)); + } +}