From 9561bcbb13218bea8b301228ac31a8b5e9b5c583 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nelson=20Arap=C3=A9?= Date: Mon, 23 Jan 2023 21:07:14 +0100 Subject: [PATCH] Fix conditionals and more MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Nelson Arapé --- .../org/flyte/examples/AddQuestionTask.java | 1 - .../org/flyte/examples/AllInputsTask.java | 4 -- .../org/flyte/examples/AllInputsWorkflow.java | 9 +-- .../org/flyte/examples/BatchLookUpTask.java | 3 +- .../examples/ConditionalGreetingWorkflow.java | 5 +- .../examples/DynamicFibonacciWorkflow.java | 16 +++-- .../org/flyte/examples/FibonacciWorkflow.java | 22 ++++-- .../org/flyte/examples/PrintMessageTask.java | 1 - .../main/java/org/flyte/examples/SumTask.java | 1 - .../java/org/flyte/examples/UberWorkflow.java | 1 - .../org/flyte/examples/WelcomeWorkflow.java | 15 +++-- .../flyte/flytekit/SdkAppliedTransform.java | 67 +++++++++++++++++++ .../org/flyte/flytekit/SdkBranchNode.java | 17 +++-- .../java/org/flyte/flytekit/SdkCondition.java | 46 +++++++------ .../org/flyte/flytekit/SdkConditionCase.java | 8 +-- .../org/flyte/flytekit/SdkConditions.java | 21 ++++-- .../org/flyte/flytekit/SdkNodeNamePolicy.java | 5 +- .../java/org/flyte/flytekit/SdkTaskNode.java | 2 + .../main/java/org/flyte/flytekit/SdkType.java | 5 ++ .../java/org/flyte/flytekit/SdkTypes.java | 10 ++- .../flyte/flytekit/SdkWorkflowBuilder.java | 1 + .../org/flyte/flytekit/SdkWorkflowNode.java | 2 + .../flytekit/SdkWorkflowBuilderTest.java | 16 ++++- .../flyte/flytekitscala/SdkScalaType.scala | 4 +- .../flytekitscala/SdkScalaWorkflow.scala | 39 ++++++----- 25 files changed, 228 insertions(+), 93 deletions(-) create mode 100644 flytekit-java/src/main/java/org/flyte/flytekit/SdkAppliedTransform.java diff --git a/flytekit-examples/src/main/java/org/flyte/examples/AddQuestionTask.java b/flytekit-examples/src/main/java/org/flyte/examples/AddQuestionTask.java index 7a83e8037..9c3803901 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/AddQuestionTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/AddQuestionTask.java @@ -20,7 +20,6 @@ import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkRunnableTask; -import org.flyte.flytekit.SdkTransform; import org.flyte.flytekit.jackson.JacksonSdkType; /** diff --git a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java index 00448b33b..46849fa7a 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java @@ -16,9 +16,6 @@ */ package org.flyte.examples; -import static java.util.Collections.emptyList; -import static java.util.Collections.emptyMap; - import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import java.time.Duration; @@ -27,7 +24,6 @@ import java.util.Map; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkRunnableTask; -import org.flyte.flytekit.SdkTransform; import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkRunnableTask.class) diff --git a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java index 94918b1be..51bd1d53f 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java @@ -33,7 +33,8 @@ import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkWorkflow.class) -public class AllInputsWorkflow extends SdkWorkflow { +public class AllInputsWorkflow + extends SdkWorkflow { public AllInputsWorkflow() { super(SdkTypes.nulls(), JacksonSdkType.of(AllInputsWorkflow.AllInputsWorkflowOutput.class)); @@ -47,7 +48,8 @@ public void expand(SdkWorkflowBuilder builder) { SdkNode apply = builder.apply( "all-inputs", - new AllInputsTask(), AllInputsTask.AutoAllInputsInput.create( + new AllInputsTask(), + AllInputsTask.AutoAllInputsInput.create( SdkBindingData.ofInteger(1L), SdkBindingData.ofFloat(2), SdkBindingData.ofString("test"), @@ -57,8 +59,7 @@ public void expand(SdkWorkflowBuilder builder) { SdkBindingData.ofStringCollection(Arrays.asList("foo", "bar")), SdkBindingData.ofStringMap(Map.of("test", "test")), SdkBindingData.ofStringCollection(Collections.emptyList()), - SdkBindingData.ofIntegerMap(Collections.emptyMap())) - ); + SdkBindingData.ofIntegerMap(Collections.emptyMap()))); AllInputsTask.AutoAllInputsOutput outputs = apply.getOutputs(); diff --git a/flytekit-examples/src/main/java/org/flyte/examples/BatchLookUpTask.java b/flytekit-examples/src/main/java/org/flyte/examples/BatchLookUpTask.java index ef08001ca..bd0889a75 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/BatchLookUpTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/BatchLookUpTask.java @@ -55,7 +55,8 @@ public abstract static class Input { public abstract SdkBindingData> searchKeys(); - public static Input create(SdkBindingData> keyValues, SdkBindingData> searchKeys) { + public static Input create( + SdkBindingData> keyValues, SdkBindingData> searchKeys) { return new AutoValue_BatchLookUpTask_Input(keyValues, searchKeys); } } diff --git a/flytekit-examples/src/main/java/org/flyte/examples/ConditionalGreetingWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/ConditionalGreetingWorkflow.java index a7d722e78..6d08b8d0f 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/ConditionalGreetingWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/ConditionalGreetingWorkflow.java @@ -40,7 +40,10 @@ public void expand(SdkWorkflowBuilder builder) { .apply( "decide", SdkConditions.when( - "when-empty", eq(name, ofString("")), new GreetTask(), GreetTask.Input.create(ofString("World"))) + "when-empty", + eq(name, ofString("")), + new GreetTask(), + GreetTask.Input.create(ofString("World"))) .otherwise("when-not-empty", new GreetTask(), GreetTask.Input.create(name))) .getOutputs() .greeting(); diff --git a/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflow.java index 9d626b2d4..c438f08c9 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflow.java @@ -18,16 +18,14 @@ import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; -import java.util.List; -import java.util.Map; -import org.flyte.examples.BatchLookUpTask.Input; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkWorkflow.class) -public class DynamicFibonacciWorkflow extends SdkWorkflow { +public class DynamicFibonacciWorkflow + extends SdkWorkflow { @AutoValue public abstract static class Input { public abstract SdkBindingData n(); @@ -36,8 +34,11 @@ public static DynamicFibonacciWorkflow.Input create(SdkBindingData n) { return new AutoValue_DynamicFibonacciWorkflow_Input(n); } } + public DynamicFibonacciWorkflow() { - super(JacksonSdkType.of(DynamicFibonacciWorkflow.Input.class), JacksonSdkType.of(DynamicFibonacciWorkflowTask.Output.class)); + super( + JacksonSdkType.of(DynamicFibonacciWorkflow.Input.class), + JacksonSdkType.of(DynamicFibonacciWorkflowTask.Output.class)); } @Override @@ -46,7 +47,10 @@ public void expand(SdkWorkflowBuilder builder) { SdkBindingData fibOutput = builder - .apply("fibonacci", new DynamicFibonacciWorkflowTask(), DynamicFibonacciWorkflowTask.Input.create(n)) + .apply( + "fibonacci", + new DynamicFibonacciWorkflowTask(), + DynamicFibonacciWorkflowTask.Input.create(n)) .getOutputs() .output(); diff --git a/flytekit-examples/src/main/java/org/flyte/examples/FibonacciWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/FibonacciWorkflow.java index 45d121ee7..dc5cecbe1 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/FibonacciWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/FibonacciWorkflow.java @@ -25,10 +25,13 @@ import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkWorkflow.class) -public class FibonacciWorkflow extends SdkWorkflow { +public class FibonacciWorkflow + extends SdkWorkflow { public FibonacciWorkflow() { - super(JacksonSdkType.of(FibonacciWorkflow.Input.class), JacksonSdkType.of(FibonacciWorkflow.Output.class)); + super( + JacksonSdkType.of(FibonacciWorkflow.Input.class), + JacksonSdkType.of(FibonacciWorkflow.Output.class)); } @Override @@ -36,12 +39,16 @@ public void expand(SdkWorkflowBuilder builder) { SdkBindingData fib0 = builder.inputOfInteger("fib0", "Value for Fib0"); SdkBindingData fib1 = builder.inputOfInteger("fib1", "Value for Fib1"); - SdkNode apply = builder.apply("fib-2", new SumTask(), SumTask.SumInput.create(fib1, fib0)); + SdkNode apply = + builder.apply("fib-2", new SumTask(), SumTask.SumInput.create(fib1, fib0)); SumTask.SumOutput outputs = apply.getOutputs(); SdkBindingData fib2 = outputs.c(); - SdkBindingData fib3 = builder.apply("fib-3", new SumTask(), SumTask.SumInput.create(fib1, fib2)).getOutputs().c(); - SdkBindingData fib4 = builder.apply("fib-4", new SumTask(), SumTask.SumInput.create(fib2, fib3)).getOutputs().c(); - SdkBindingData fib5 = builder.apply("fib-5", new SumTask(), SumTask.SumInput.create(fib3, fib4)).getOutputs().c(); + SdkBindingData fib3 = + builder.apply("fib-3", new SumTask(), SumTask.SumInput.create(fib1, fib2)).getOutputs().c(); + SdkBindingData fib4 = + builder.apply("fib-4", new SumTask(), SumTask.SumInput.create(fib2, fib3)).getOutputs().c(); + SdkBindingData fib5 = + builder.apply("fib-5", new SumTask(), SumTask.SumInput.create(fib3, fib4)).getOutputs().c(); builder.output("fib5", fib5, "Value for Fib5"); } @@ -52,7 +59,8 @@ public abstract static class Input { public abstract SdkBindingData fib1(); - public static FibonacciWorkflow.Output create(SdkBindingData fib0, SdkBindingData fib1) { + public static FibonacciWorkflow.Output create( + SdkBindingData fib0, SdkBindingData fib1) { return new AutoValue_FibonacciWorkflow_Input(fib0, fib1); } } diff --git a/flytekit-examples/src/main/java/org/flyte/examples/PrintMessageTask.java b/flytekit-examples/src/main/java/org/flyte/examples/PrintMessageTask.java index 811da0385..66ca34c48 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/PrintMessageTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/PrintMessageTask.java @@ -20,7 +20,6 @@ import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkRunnableTask; -import org.flyte.flytekit.SdkTransform; import org.flyte.flytekit.SdkTypes; import org.flyte.flytekit.jackson.JacksonSdkType; diff --git a/flytekit-examples/src/main/java/org/flyte/examples/SumTask.java b/flytekit-examples/src/main/java/org/flyte/examples/SumTask.java index 1595693fa..55845567e 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/SumTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/SumTask.java @@ -20,7 +20,6 @@ import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkRunnableTask; -import org.flyte.flytekit.SdkTransform; import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkRunnableTask.class) diff --git a/flytekit-examples/src/main/java/org/flyte/examples/UberWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/UberWorkflow.java index 9895ebfb7..ebe04a19d 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/UberWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/UberWorkflow.java @@ -18,7 +18,6 @@ import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; -import org.flyte.examples.SumTask.SumInput; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; diff --git a/flytekit-examples/src/main/java/org/flyte/examples/WelcomeWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/WelcomeWorkflow.java index 0cbc914ab..61174ad60 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/WelcomeWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/WelcomeWorkflow.java @@ -18,7 +18,6 @@ import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; -import org.flyte.examples.SumTask.SumInput; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; @@ -38,7 +37,9 @@ public static WelcomeWorkflow.Input create(SdkBindingData name) { } public WelcomeWorkflow() { - super(JacksonSdkType.of(WelcomeWorkflow.Input.class), JacksonSdkType.of(AddQuestionTask.Output.class)); + super( + JacksonSdkType.of(WelcomeWorkflow.Input.class), + JacksonSdkType.of(AddQuestionTask.Output.class)); } @Override @@ -48,11 +49,17 @@ public void expand(SdkWorkflowBuilder builder) { // uses the workflow input as the task input of the GreetTask SdkBindingData greeting = - builder.apply("greet", new GreetTask(), GreetTask.Input.create(name)).getOutputs().greeting(); + builder + .apply("greet", new GreetTask(), GreetTask.Input.create(name)) + .getOutputs() + .greeting(); // uses the output of the GreetTask as the task input of the AddQuestionTask SdkBindingData greetingWithQuestion = - builder.apply("add-question", new AddQuestionTask(), AddQuestionTask.Input.create(greeting)).getOutputs().greeting(); + builder + .apply("add-question", new AddQuestionTask(), AddQuestionTask.Input.create(greeting)) + .getOutputs() + .greeting(); // uses the task output of the AddQuestionTask as the output of the workflow builder.output("greeting", greetingWithQuestion, "Welcome message"); diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkAppliedTransform.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkAppliedTransform.java new file mode 100644 index 000000000..2811788bf --- /dev/null +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkAppliedTransform.java @@ -0,0 +1,67 @@ +/* + * Copyright 2021 Flyte Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekit; + +import java.util.List; +import java.util.Set; +import javax.annotation.Nullable; + +class SdkAppliedTransform extends SdkTransform { + private final SdkTransform transform; + private final OriginalInputT appliedInputs; + + SdkAppliedTransform( + SdkTransform transform, @Nullable OriginalInputT appliedInputs) { + checkNotNull(transform, appliedInputs); + this.transform = transform; + this.appliedInputs = appliedInputs; + } + + static void checkNotNull(SdkTransform transform, @Nullable InputT inputs) { + Set variableNames = transform.getInputType().variableNames(); + if (inputs == null && !variableNames.isEmpty()) { + throw new IllegalArgumentException( + String.format( + "Null supplied as input for a transform with %s properties", variableNames)); + } + } + + @Override + public SdkType getInputType() { + return SdkTypes.nulls(); + } + + @Override + public SdkType getOutputType() { + return transform.getOutputType(); + } + + @Override + public String getName() { + return transform.getName(); + } + + @Override + public SdkNode apply( + SdkWorkflowBuilder builder, + String nodeId, + List upstreamNodeIds, + @Nullable SdkNodeMetadata metadata, + @Nullable Void inputs) { + return transform.apply(builder, nodeId, upstreamNodeIds, metadata, appliedInputs); + } +} diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBranchNode.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBranchNode.java index 53032b86e..56677d221 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBranchNode.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBranchNode.java @@ -16,7 +16,7 @@ */ package org.flyte.flytekit; -import static org.flyte.flytekit.MoreCollectors.toUnmodifiableList; +import static java.util.stream.Collectors.toUnmodifiableList; import static org.flyte.flytekit.MoreCollectors.toUnmodifiableMap; import com.google.errorprone.annotations.CanIgnoreReturnValue; @@ -26,6 +26,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import org.flyte.api.v1.Binding; import org.flyte.api.v1.BranchNode; import org.flyte.api.v1.IfElseBlock; @@ -89,15 +90,21 @@ public Node toIdl() { ifElseBlock = ifElseBlock.toBuilder().error(nodeError).build(); } + // inputs in var order for predictability + List inputs = + extraInputs.entrySet().stream() + .sorted(Entry.comparingByKey()) + .map(Entry::getValue) + .collect(toUnmodifiableList()); return Node.builder() .id(nodeId) .branchNode(BranchNode.builder().ifElse(ifElseBlock).build()) - .inputs(List.copyOf(extraInputs.values())) + .inputs(inputs) .upstreamNodeIds(upstreamNodeIds) .build(); } - static class Builder { + static class Builder { private final SdkWorkflowBuilder builder; private final SdkType outputType; @@ -113,7 +120,7 @@ static class Builder { } @CanIgnoreReturnValue - Builder addCase(SdkConditionCase case_) { + Builder addCase(SdkConditionCase case_) { SdkNode sdkNode = case_ .then() @@ -147,7 +154,7 @@ Builder addCase(SdkConditionCase case_) { } @CanIgnoreReturnValue - Builder addOtherwise(String name, SdkTransform otherwise) { + Builder addOtherwise(String name, SdkTransform otherwise) { if (elseNode != null) { throw new IllegalArgumentException("Duplicate otherwise clause"); } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkCondition.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkCondition.java index d3fa842e2..c1908fb62 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkCondition.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkCondition.java @@ -20,17 +20,16 @@ import java.util.List; import javax.annotation.Nullable; -public class SdkCondition extends SdkTransform { - private final SdkType inputType; +public class SdkCondition extends SdkTransform { private final SdkType outputType; - private final List> cases; + private final List> cases; private final String otherwiseName; - private final SdkTransform otherwise; + private final SdkTransform otherwise; SdkCondition( - List> cases, + List> cases, String otherwiseName, - SdkTransform otherwise) { + SdkTransform otherwise) { if (cases.isEmpty()) { throw new IllegalArgumentException("Empty cases on SdkCondition"); } @@ -39,21 +38,26 @@ public class SdkCondition extends SdkTransform this.otherwise = otherwise; var firstCase = cases.get(0); - this.inputType = firstCase.then().getInputType(); this.outputType = firstCase.then().getOutputType(); } - public SdkCondition when( - String name, SdkBooleanExpression condition, SdkTransform then) { - - List> newCases = new ArrayList<>(cases); + public SdkCondition when( + String name, SdkBooleanExpression condition, SdkTransform then) { + var newCases = new ArrayList<>(cases); newCases.add(SdkConditionCase.create(name, condition, then)); return new SdkCondition<>(newCases, this.otherwiseName, this.otherwise); } - public SdkCondition otherwise( - String name, SdkTransform otherwise) { + public SdkCondition when( + String name, + SdkBooleanExpression condition, + SdkTransform then, + InputT inputs) { + return when(name, condition, new SdkAppliedTransform<>(then, inputs)); + } + + public SdkCondition otherwise(String name, SdkTransform otherwise) { if (this.otherwise != null) { throw new IllegalStateException("Can't set 'otherwise' because it's already set"); } @@ -61,9 +65,14 @@ public SdkCondition otherwise( return new SdkCondition<>(this.cases, name, otherwise); } + public SdkCondition otherwise( + String name, SdkTransform otherwise, InputT inputs) { + return otherwise(name, new SdkAppliedTransform<>(otherwise, inputs)); + } + @Override - public SdkType getInputType() { - return inputType; + public SdkType getInputType() { + return SdkTypes.nulls(); } @Override @@ -77,11 +86,10 @@ public SdkNode apply( String nodeId, List upstreamNodeIds, @Nullable SdkNodeMetadata metadata, - @Nullable InputT inputs) { - SdkBranchNode.Builder nodeBuilder = - new SdkBranchNode.Builder<>(builder, outputType); + @Nullable Void noInputs) { + SdkBranchNode.Builder nodeBuilder = new SdkBranchNode.Builder<>(builder, outputType); - for (SdkConditionCase case_ : cases) { + for (SdkConditionCase case_ : cases) { nodeBuilder.addCase(case_); } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkConditionCase.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkConditionCase.java index dbbceb651..5459a4a1b 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkConditionCase.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkConditionCase.java @@ -19,15 +19,15 @@ import com.google.auto.value.AutoValue; @AutoValue -abstract class SdkConditionCase { +abstract class SdkConditionCase { abstract String name(); abstract SdkBooleanExpression condition(); - abstract SdkTransform then(); + abstract SdkTransform then(); - static SdkConditionCase create( - String name, SdkBooleanExpression condition, SdkTransform then) { + static SdkConditionCase create( + String name, SdkBooleanExpression condition, SdkTransform then) { return new AutoValue_SdkConditionCase<>(name, condition, then); } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkConditions.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkConditions.java index 97470eb7d..23003e31e 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkConditions.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkConditions.java @@ -16,19 +16,30 @@ */ package org.flyte.flytekit; -import static java.util.Collections.singletonList; import static org.flyte.flytekit.SdkBooleanExpression.ofComparison; +import java.util.List; import org.flyte.api.v1.ComparisonExpression; public class SdkConditions { private SdkConditions() {} - public static SdkCondition when( - String name, SdkBooleanExpression condition, SdkTransform then) { - SdkConditionCase case_ = SdkConditionCase.create(name, condition, then); + public static SdkCondition when( + String name, SdkBooleanExpression condition, SdkTransform then) { + SdkConditionCase case_ = SdkConditionCase.create(name, condition, then); - return new SdkCondition<>(singletonList(case_), null, null); + return new SdkCondition<>(List.of(case_), null, null); + } + + public static SdkCondition when( + String name, + SdkBooleanExpression condition, + SdkTransform then, + InputT inputs) { + SdkConditionCase case_ = + SdkConditionCase.create(name, condition, new SdkAppliedTransform<>(then, inputs)); + + return new SdkCondition<>(List.of(case_), null, null); } public static SdkBooleanExpression eq(SdkBindingData left, SdkBindingData right) { diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkNodeNamePolicy.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkNodeNamePolicy.java index 4dee35f66..6200532c9 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkNodeNamePolicy.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkNodeNamePolicy.java @@ -17,7 +17,6 @@ package org.flyte.flytekit; import java.util.Locale; -import java.util.Map; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicInteger; import java.util.regex.Pattern; @@ -25,8 +24,8 @@ /** * Controls the default node id and node name policy when applying {@link SdkTransform} to {@link * SdkWorkflowBuilder}. When using {@link SdkWorkflowBuilder#apply(SdkTransform)} or {@link - * SdkWorkflowBuilder#apply(SdkTransform, Map)} then the node id used would be the one returned by - * {@link #nextNodeId()}. Also, if a node name haven't been set by the user, then {@link + * SdkWorkflowBuilder#apply(SdkTransform, Object)} then the node id used would be the one returned + * by {@link #nextNodeId()}. Also, if a node name haven't been set by the user, then {@link * #toNodeName(String)} would be used. */ class SdkNodeNamePolicy { diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkTaskNode.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkTaskNode.java index 7d28ce37a..ad4a08f07 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkTaskNode.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkTaskNode.java @@ -83,8 +83,10 @@ public String getNodeId() { public Node toIdl() { TaskNode taskNode = TaskNode.builder().referenceId(taskId).build(); + // inputs in var order for predictability List bindings = inputs.entrySet().stream() + .sorted(Map.Entry.comparingByKey()) .map(x -> toBinding(x.getKey(), x.getValue())) .collect(toUnmodifiableList()); diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkType.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkType.java index 72e3dd517..929eaee9e 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkType.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkType.java @@ -17,6 +17,7 @@ package org.flyte.flytekit; import java.util.Map; +import java.util.Set; import org.flyte.api.v1.Literal; import org.flyte.api.v1.Variable; @@ -30,5 +31,9 @@ public abstract class SdkType { public abstract Map getVariableMap(); + public Set variableNames() { + return Set.copyOf(getVariableMap().keySet()); + } + public abstract Map> toSdkBindingMap(T value); } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkTypes.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkTypes.java index 06308a964..eafd299ab 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkTypes.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkTypes.java @@ -22,10 +22,13 @@ /** A utility class for creating {@link SdkType} objects for different types. */ public class SdkTypes { + + private static final VoidSdkType VOID_SDK_TYPE = new VoidSdkType(); + private SdkTypes() {} public static SdkType nulls() { - return new VoidSdkType(); + return VOID_SDK_TYPE; } private static class VoidSdkType extends SdkType { @@ -54,10 +57,5 @@ public Map getVariableMap() { public Map> toSdkBindingMap(Void value) { return Map.of(); } - - @Override - public Map> toSdkBindingMap(Void value) { - return Collections.emptyMap(); - } } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowBuilder.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowBuilder.java index a6feff443..91bb1a69f 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowBuilder.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowBuilder.java @@ -81,6 +81,7 @@ protected SdkNode applyInternal( SdkTransform transform, List upstreamNodeIds, @Nullable InputT inputs) { + SdkAppliedTransform.checkNotNull(transform, inputs); String actualNodeId = Objects.requireNonNullElseGet(nodeId, sdkNodeNamePolicy::nextNodeId); diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowNode.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowNode.java index cb38f1eb9..151f1094c 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowNode.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowNode.java @@ -71,8 +71,10 @@ public String getNodeId() { @Override public Node toIdl() { + // inputs in var order for predictability List inputBindings = this.inputBindings.entrySet().stream() + .sorted(Map.Entry.comparingByKey()) .map(x -> toBinding(x.getKey(), x.getValue())) .collect(toUnmodifiableList()); diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowBuilderTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowBuilderTest.java index e0f85d7d9..827b67bbd 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowBuilderTest.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowBuilderTest.java @@ -135,7 +135,14 @@ void testTimes4WorkflowIdl() { .nodes(List.of(node0, node1)) .build(); - assertEquals(expected, builder.toIdlTemplate()); + WorkflowTemplate actual = builder.toIdlTemplate(); + assertEquals(expected.interface_(), actual.interface_()); + assertEquals(expected.metadata(), actual.metadata()); + assertEquals(expected.outputs(), actual.outputs()); + assertEquals(expected.nodes().get(0), actual.nodes().get(0)); + assertEquals(expected.nodes().get(1), actual.nodes().get(1)); + assertEquals(expected.nodes(), actual.nodes()); + assertEquals(expected, actual); } @Test @@ -471,8 +478,11 @@ public void expand(SdkWorkflowBuilder builder) { SdkNode out = builder.apply( "square", - SdkConditions.when("neq", SdkConditions.neq(in, two), new MultiplicationTask()), - TestPairIntegerInput.create(in, two)); + SdkConditions.when( + "neq", + SdkConditions.neq(in, two), + new MultiplicationTask(), + TestPairIntegerInput.create(in, two))); builder.output("o", out.getOutputs().o()); } diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala index 9fb59e140..e466da62a 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala @@ -17,7 +17,7 @@ package org.flyte.flytekitscala import java.time.{Duration, Instant} -import java.{util, util => ju} +import java.{util => ju} import magnolia.{CaseClass, Magnolia, Param, SealedTrait} import org.flyte.api.v1._ import org.flyte.flytekit.{SdkType, SdkBindingData => SdkJavaBindinigData} @@ -311,7 +311,7 @@ private object SdkUnitType extends SdkScalaProductType[Unit] { def promiseFor(nodeId: String): Unit = () - def toSdkBindingMap(value: Unit): util.Map[String, SdkJavaBindinigData[_]] = + def toSdkBindingMap(value: Unit): ju.Map[String, SdkJavaBindinigData[_]] = ju.Map.of() } diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaWorkflow.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaWorkflow.scala index d5a33f592..048104f7e 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaWorkflow.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaWorkflow.scala @@ -29,8 +29,10 @@ import org.flyte.flytekit.{ import java.time.{Duration, Instant} import scala.collection.JavaConverters._ -abstract class SdkScalaWorkflow[T](outputType: SdkType[T]) - extends SdkWorkflow[T](outputType) { +abstract class SdkScalaWorkflow[InputT, OutputT]( + inputType: SdkType[InputT], + outputType: SdkType[OutputT] +) extends SdkWorkflow[InputT, OutputT](inputType, outputType) { final override def expand(builder: SdkWorkflowBuilder): Unit = { expand(new SdkScalaWorkflowBuilder(builder)) } @@ -104,26 +106,33 @@ class SdkScalaWorkflowBuilder(builder: SdkWorkflowBuilder) { def getOutputDescription(name: String): String = builder.getOutputDescription(name) - def output(name: String, value: SdkJavaBindingData[_], help: String = "") = + def output( + name: String, + value: SdkJavaBindingData[_], + help: String = "" + ): Unit = builder.output(name, value, help) def toIdlTemplate: WorkflowTemplate = builder.toIdlTemplate - def apply[T](nodeId: String, transform: SdkTransform[T]): SdkNode[T] = - builder.apply(nodeId, transform) + def apply[OutputT]( + nodeId: String, + transform: SdkTransform[Unit, OutputT] + ): SdkNode[OutputT] = + builder.apply(nodeId, transform, ()) - def apply[T]( + def apply[InputT, OutputT]( nodeId: String, - transform: SdkTransform[T], - inputs: Map[String, SdkJavaBindingData[_]] - ): SdkNode[T] = builder.apply(nodeId, transform, inputs.asJava) + transform: SdkTransform[InputT, OutputT], + inputs: InputT + ): SdkNode[OutputT] = builder.apply(nodeId, transform, inputs) - def apply[T](transform: SdkTransform[T]): SdkNode[T] = - builder.apply(transform) + def apply[OutputT](transform: SdkTransform[Unit, OutputT]): SdkNode[OutputT] = + builder.apply(transform, ()) - def apply[T]( - transform: SdkTransform[T], - inputs: Map[String, SdkJavaBindingData[_]] - ): SdkNode[T] = builder.apply(transform, inputs.asJava) + def apply[InputT, OutputT]( + transform: SdkTransform[InputT, OutputT], + inputs: InputT + ): SdkNode[OutputT] = builder.apply(transform, inputs) }