diff --git a/flyteidl-protos/pom.xml b/flyteidl-protos/pom.xml index e4a7062bd..ee2e00475 100644 --- a/flyteidl-protos/pom.xml +++ b/flyteidl-protos/pom.xml @@ -21,7 +21,7 @@ org.flyte flytekit-parent - 0.3.29-SNAPSHOT + 0.4.0-SNAPSHOT flyteidl-protos diff --git a/flytekit-api/pom.xml b/flytekit-api/pom.xml index 411e824fc..0ecf24fe2 100644 --- a/flytekit-api/pom.xml +++ b/flytekit-api/pom.xml @@ -21,7 +21,7 @@ org.flyte flytekit-parent - 0.3.29-SNAPSHOT + 0.4.0-SNAPSHOT flytekit-api diff --git a/flytekit-examples-scala/pom.xml b/flytekit-examples-scala/pom.xml index e1ff978d1..d73556786 100644 --- a/flytekit-examples-scala/pom.xml +++ b/flytekit-examples-scala/pom.xml @@ -21,7 +21,7 @@ org.flyte flytekit-parent - 0.3.29-SNAPSHOT + 0.4.0-SNAPSHOT flytekit-examples-scala diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/AddQuestionTask.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/AddQuestionTask.scala index ea2b87d1f..96f6cad86 100644 --- a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/AddQuestionTask.scala +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/AddQuestionTask.scala @@ -46,18 +46,3 @@ class AddQuestionTask override def run(input: AddQuestionTaskInput): AddQuestionTaskOutput = AddQuestionTaskOutput(ofString(s"${input.greeting.get} How are you?")) } - -object AddQuestionTask { - - /** Binds input data to this task - * - * @param greeting - * the input greeting message - * @return - * a transformed instance of this class with input data - */ - def apply( - greeting: SdkBindingData[String] - ): SdkTransform[AddQuestionTaskOutput] = - new AddQuestionTask().withInput("greeting", greeting) -} diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/DynamicFibonacciWorkflow.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/DynamicFibonacciWorkflow.scala index a2d26aa30..a468cd42f 100644 --- a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/DynamicFibonacciWorkflow.scala +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/DynamicFibonacciWorkflow.scala @@ -23,9 +23,14 @@ import org.flyte.flytekitscala.{ SdkScalaWorkflowBuilder } +case class DynamicFibonacciWorkflowInput(n: SdkBindingData[Long]) case class DynamicFibonacciWorkflowOutput(output: SdkBindingData[Long]) class DynamicFibonacciWorkflow - extends SdkScalaWorkflow[DynamicFibonacciWorkflowOutput]( + extends SdkScalaWorkflow[ + DynamicFibonacciWorkflowInput, + DynamicFibonacciWorkflowOutput + ]( + SdkScalaType[DynamicFibonacciWorkflowInput], SdkScalaType[DynamicFibonacciWorkflowOutput] ) { @@ -34,7 +39,8 @@ class DynamicFibonacciWorkflow val fibonacci = builder.apply( "fibonacci", - new DynamicFibonacciWorkflowTask().withInput("n", n) + new DynamicFibonacciWorkflowTask(), + DynamicFibonacciWorkflowTaskInput(n) ) builder.output("output", fibonacci.getOutputs.output) diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/DynamicFibonacciWorkflowTask.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/DynamicFibonacciWorkflowTask.scala index bda8db902..3a1e9a4cc 100644 --- a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/DynamicFibonacciWorkflowTask.scala +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/DynamicFibonacciWorkflowTask.scala @@ -53,7 +53,11 @@ class DynamicFibonacciWorkflowTask else fib( n + 1, - builder(s"fib-${n + 1}", SumTask(value, prev)).getOutputs.c, + builder( + s"fib-${n + 1}", + new SumTask(), + SumTaskInput(value, prev) + ).getOutputs.c, value ) } diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/FibonacciWorkflow.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/FibonacciWorkflow.scala index 637603c8e..b4a405a15 100644 --- a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/FibonacciWorkflow.scala +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/FibonacciWorkflow.scala @@ -16,17 +16,22 @@ */ package org.flyte.examples.flytekitscala -import org.flyte.flytekit.{SdkBindingData, SdkWorkflow, SdkWorkflowBuilder} +import org.flyte.flytekit.SdkBindingData import org.flyte.flytekitscala.{ SdkScalaWorkflowBuilder, SdkScalaType, SdkScalaWorkflow } +case class FibonacciWorkflowInput( + fib0: SdkBindingData[Long], + fib1: SdkBindingData[Long] +) case class FibonacciWorkflowOutput(fib5: SdkBindingData[Long]) class FibonacciWorkflow - extends SdkScalaWorkflow[FibonacciWorkflowOutput]( + extends SdkScalaWorkflow[FibonacciWorkflowInput, FibonacciWorkflowOutput]( + SdkScalaType[FibonacciWorkflowInput], SdkScalaType[FibonacciWorkflowOutput] ) { @@ -34,10 +39,22 @@ class FibonacciWorkflow val fib0 = builder.inputOfInteger("fib0", "Value for Fib0") val fib1 = builder.inputOfInteger("fib1", "Value for Fib1") - val fib2 = builder.apply("fib-2", SumTask(fib0, fib1)).getOutputs.c - val fib3 = builder.apply("fib-3", SumTask(fib1, fib2)).getOutputs.c - val fib4 = builder.apply("fib-4", SumTask(fib2, fib3)).getOutputs.c - val fib5 = builder.apply("fib-5", SumTask(fib3, fib4)).getOutputs.c + val fib2 = builder + .apply("fib-2", new SumTask(), SumTaskInput(fib0, fib1)) + .getOutputs + .c + val fib3 = builder + .apply("fib-3", new SumTask(), SumTaskInput(fib1, fib2)) + .getOutputs + .c + val fib4 = builder + .apply("fib-4", new SumTask(), SumTaskInput(fib2, fib3)) + .getOutputs + .c + val fib5 = builder + .apply("fib-5", new SumTask(), SumTaskInput(fib3, fib4)) + .getOutputs + .c builder.output("fib5", fib5, "Value for Fib5") } diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/GreetTask.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/GreetTask.scala index b2c72ab6a..58ff7af5d 100644 --- a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/GreetTask.scala +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/GreetTask.scala @@ -46,16 +46,3 @@ class GreetTask override def run(input: GreetTaskInput): GreetTaskOutput = GreetTaskOutput(ofString(s"Welcome, ${input.name.get()}!")) } - -object GreetTask { - - /** Binds input data to this task - * - * @param name - * the input name - * @return - * a transformed instance of this class with input data - */ - def apply(name: SdkBindingData[String]): SdkTransform[GreetTaskOutput] = - new GreetTask().withInput("name", name) -} diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/SumTask.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/SumTask.scala index 401cad51d..96d536f7e 100644 --- a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/SumTask.scala +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/SumTask.scala @@ -43,11 +43,3 @@ class SumTask override def isCacheSerializable: Boolean = true } - -object SumTask { - def apply( - a: SdkBindingData[Long], - b: SdkBindingData[Long] - ): SdkTransform[SumTaskOutput] = - new SumTask().withInput("a", a).withInput("b", b) -} diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/WelcomeWorkflow.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/WelcomeWorkflow.scala index 8c9e8d135..a75f63e1f 100644 --- a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/WelcomeWorkflow.scala +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/WelcomeWorkflow.scala @@ -16,7 +16,7 @@ */ package org.flyte.examples.flytekitscala -import org.flyte.flytekit.{SdkBindingData, SdkWorkflow, SdkWorkflowBuilder} +import org.flyte.flytekit.SdkBindingData import org.flyte.flytekitscala.{ SdkScalaType, SdkScalaWorkflow, @@ -49,10 +49,12 @@ import org.flyte.flytekitscala.{ * | output: greeting(string) | */ +case class WelcomeWorkflowInput(name: SdkBindingData[String]) case class WelcomeWorkflowOutput(greeting: SdkBindingData[String]) class WelcomeWorkflow - extends SdkScalaWorkflow[WelcomeWorkflowOutput]( + extends SdkScalaWorkflow[WelcomeWorkflowInput, WelcomeWorkflowOutput]( + SdkScalaType[WelcomeWorkflowInput], SdkScalaType[WelcomeWorkflowOutput] ) { @@ -61,11 +63,18 @@ class WelcomeWorkflow val name = builder.inputOfString("name", "The name for the welcome message") // uses the workflow input as the task input of the GreetTask - val greeting = builder.apply("greet", GreetTask(name)).getOutputs.greeting + val greeting = builder + .apply("greet", new GreetTask(), GreetTaskInput(name)) + .getOutputs + .greeting // uses the output of the GreetTask as the task input of the AddQuestionTask val greetingWithQuestion = builder - .apply("add-question", AddQuestionTask(greeting)) + .apply( + "add-question", + new AddQuestionTask(), + AddQuestionTaskInput(greeting) + ) .getOutputs .greeting diff --git a/flytekit-examples/pom.xml b/flytekit-examples/pom.xml index 5f4ee4eb0..c82b45e91 100644 --- a/flytekit-examples/pom.xml +++ b/flytekit-examples/pom.xml @@ -21,7 +21,7 @@ org.flyte flytekit-parent - 0.3.29-SNAPSHOT + 0.4.0-SNAPSHOT flytekit-examples diff --git a/flytekit-examples/src/main/java/org/flyte/examples/AddQuestionTask.java b/flytekit-examples/src/main/java/org/flyte/examples/AddQuestionTask.java index 81adc70c1..9c3803901 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/AddQuestionTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/AddQuestionTask.java @@ -20,7 +20,6 @@ import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkRunnableTask; -import org.flyte.flytekit.SdkTransform; import org.flyte.flytekit.jackson.JacksonSdkType; /** @@ -34,16 +33,6 @@ public AddQuestionTask() { super(JacksonSdkType.of(Input.class), JacksonSdkType.of(Output.class)); } - /** - * Binds input data to this task. - * - * @param greeting the input greeting message - * @return a transformed instance of this class with input data - */ - public static SdkTransform of(SdkBindingData greeting) { - return new AddQuestionTask().withInput("greeting", greeting); - } - /** * Generate an immutable value class that represents {@link AddQuestionTask}'s input, which is a * String. @@ -51,6 +40,10 @@ public static SdkTransform of(SdkBindingData gre @AutoValue public abstract static class Input { public abstract SdkBindingData greeting(); + + public static Input create(SdkBindingData greeting) { + return new AutoValue_AddQuestionTask_Input(greeting); + } } /** diff --git a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java index 29038c835..46849fa7a 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java @@ -16,9 +16,6 @@ */ package org.flyte.examples; -import static java.util.Collections.emptyList; -import static java.util.Collections.emptyMap; - import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import java.time.Duration; @@ -27,7 +24,6 @@ import java.util.Map; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkRunnableTask; -import org.flyte.flytekit.SdkTransform; import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkRunnableTask.class) @@ -38,30 +34,6 @@ public AllInputsTask() { JacksonSdkType.of(AutoAllInputsInput.class), JacksonSdkType.of(AutoAllInputsOutput.class)); } - public static SdkTransform of( - SdkBindingData i, - SdkBindingData f, - SdkBindingData s, - SdkBindingData b, - SdkBindingData t, - SdkBindingData d, - SdkBindingData> l, - SdkBindingData> m, - SdkBindingData> emptyList, - SdkBindingData> emptyMap) { - return new AllInputsTask() - .withInput("i", i) - .withInput("f", f) - .withInput("s", s) - .withInput("b", b) - .withInput("t", t) - .withInput("d", d) - .withInput("l", l) - .withInput("m", m) - .withInput("emptyList", emptyList) - .withInput("emptyMap", emptyMap); - } - @AutoValue public abstract static class AutoAllInputsInput { public abstract SdkBindingData i(); diff --git a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java index 8d9f87b4e..51bd1d53f 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java @@ -27,15 +27,17 @@ import org.flyte.examples.AllInputsTask.AutoAllInputsOutput; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkNode; +import org.flyte.flytekit.SdkTypes; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkWorkflow.class) -public class AllInputsWorkflow extends SdkWorkflow { +public class AllInputsWorkflow + extends SdkWorkflow { public AllInputsWorkflow() { - super(JacksonSdkType.of(AllInputsWorkflow.AllInputsWorkflowOutput.class)); + super(SdkTypes.nulls(), JacksonSdkType.of(AllInputsWorkflow.AllInputsWorkflowOutput.class)); } @Override @@ -46,7 +48,8 @@ public void expand(SdkWorkflowBuilder builder) { SdkNode apply = builder.apply( "all-inputs", - AllInputsTask.of( + new AllInputsTask(), + AllInputsTask.AutoAllInputsInput.create( SdkBindingData.ofInteger(1L), SdkBindingData.ofFloat(2), SdkBindingData.ofString("test"), @@ -57,6 +60,7 @@ public void expand(SdkWorkflowBuilder builder) { SdkBindingData.ofStringMap(Map.of("test", "test")), SdkBindingData.ofStringCollection(Collections.emptyList()), SdkBindingData.ofIntegerMap(Collections.emptyMap()))); + AllInputsTask.AutoAllInputsOutput outputs = apply.getOutputs(); builder.output("i", outputs.i(), "Integer value"); diff --git a/flytekit-examples/src/main/java/org/flyte/examples/BatchLookUpTask.java b/flytekit-examples/src/main/java/org/flyte/examples/BatchLookUpTask.java index 23b0953b7..bd0889a75 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/BatchLookUpTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/BatchLookUpTask.java @@ -54,6 +54,11 @@ public abstract static class Input { public abstract SdkBindingData> keyValues(); public abstract SdkBindingData> searchKeys(); + + public static Input create( + SdkBindingData> keyValues, SdkBindingData> searchKeys) { + return new AutoValue_BatchLookUpTask_Input(keyValues, searchKeys); + } } @AutoValue diff --git a/flytekit-examples/src/main/java/org/flyte/examples/ConditionalGreetingWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/ConditionalGreetingWorkflow.java index 50359c01f..6d08b8d0f 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/ConditionalGreetingWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/ConditionalGreetingWorkflow.java @@ -27,9 +27,9 @@ import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkWorkflow.class) -public class ConditionalGreetingWorkflow extends SdkWorkflow { +public class ConditionalGreetingWorkflow extends SdkWorkflow { public ConditionalGreetingWorkflow() { - super(JacksonSdkType.of(GreetTask.Output.class)); + super(JacksonSdkType.of(GreetTask.Input.class), JacksonSdkType.of(GreetTask.Output.class)); } @Override @@ -40,8 +40,11 @@ public void expand(SdkWorkflowBuilder builder) { .apply( "decide", SdkConditions.when( - "when-empty", eq(name, ofString("")), GreetTask.of(ofString("World"))) - .otherwise("when-not-empty", GreetTask.of(name))) + "when-empty", + eq(name, ofString("")), + new GreetTask(), + GreetTask.Input.create(ofString("World"))) + .otherwise("when-not-empty", new GreetTask(), GreetTask.Input.create(name))) .getOutputs() .greeting(); diff --git a/flytekit-examples/src/main/java/org/flyte/examples/ContainerWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/ContainerWorkflow.java index 85491f768..15885a467 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/ContainerWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/ContainerWorkflow.java @@ -23,10 +23,10 @@ /** Example workflow that takes a name and outputs a welcome message. */ @AutoService(SdkWorkflow.class) -public class ContainerWorkflow extends SdkWorkflow { +public class ContainerWorkflow extends SdkWorkflow { public ContainerWorkflow() { - super(SdkTypes.nulls()); + super(SdkTypes.nulls(), SdkTypes.nulls()); } @Override diff --git a/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflow.java index 6e6088055..c438f08c9 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflow.java @@ -17,15 +17,28 @@ package org.flyte.examples; import com.google.auto.service.AutoService; +import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkWorkflow.class) -public class DynamicFibonacciWorkflow extends SdkWorkflow { +public class DynamicFibonacciWorkflow + extends SdkWorkflow { + @AutoValue + public abstract static class Input { + public abstract SdkBindingData n(); + + public static DynamicFibonacciWorkflow.Input create(SdkBindingData n) { + return new AutoValue_DynamicFibonacciWorkflow_Input(n); + } + } + public DynamicFibonacciWorkflow() { - super(JacksonSdkType.of(DynamicFibonacciWorkflowTask.Output.class)); + super( + JacksonSdkType.of(DynamicFibonacciWorkflow.Input.class), + JacksonSdkType.of(DynamicFibonacciWorkflowTask.Output.class)); } @Override @@ -34,7 +47,10 @@ public void expand(SdkWorkflowBuilder builder) { SdkBindingData fibOutput = builder - .apply("fibonacci", new DynamicFibonacciWorkflowTask().withInput("n", n)) + .apply( + "fibonacci", + new DynamicFibonacciWorkflowTask(), + DynamicFibonacciWorkflowTask.Input.create(n)) .getOutputs() .output(); diff --git a/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflowTask.java b/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflowTask.java index b06f7c238..baf4d22e0 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflowTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflowTask.java @@ -19,6 +19,7 @@ import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import com.google.errorprone.annotations.Var; +import org.flyte.examples.SumTask.SumInput; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkDynamicWorkflowTask; import org.flyte.flytekit.SdkWorkflowBuilder; @@ -35,11 +36,19 @@ public DynamicFibonacciWorkflowTask() { @AutoValue abstract static class Input { public abstract SdkBindingData n(); + + public static DynamicFibonacciWorkflowTask.Input create(SdkBindingData n) { + return new AutoValue_DynamicFibonacciWorkflowTask_Input(n); + } } @AutoValue abstract static class Output { public abstract SdkBindingData output(); + + public static DynamicFibonacciWorkflowTask.Output create(SdkBindingData output) { + return new AutoValue_DynamicFibonacciWorkflowTask_Output(output); + } } @Override @@ -53,7 +62,7 @@ public void run(SdkWorkflowBuilder builder, Input input) { @Var SdkBindingData value = SdkBindingData.ofInteger(1); for (int i = 2; i <= input.n().get(); i++) { SdkBindingData next = - builder.apply("fib-" + i, SumTask.of(value, prev)).getOutputs().c(); + builder.apply("fib-" + i, new SumTask(), SumInput.create(value, prev)).getOutputs().c(); prev = value; value = next; } diff --git a/flytekit-examples/src/main/java/org/flyte/examples/FibonacciWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/FibonacciWorkflow.java index 715bc8e53..4c1e00d03 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/FibonacciWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/FibonacciWorkflow.java @@ -25,10 +25,13 @@ import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkWorkflow.class) -public class FibonacciWorkflow extends SdkWorkflow { +public class FibonacciWorkflow + extends SdkWorkflow { public FibonacciWorkflow() { - super(JacksonSdkType.of(FibonacciWorkflow.Output.class)); + super( + JacksonSdkType.of(FibonacciWorkflow.Input.class), + JacksonSdkType.of(FibonacciWorkflow.Output.class)); } @Override @@ -36,16 +39,32 @@ public void expand(SdkWorkflowBuilder builder) { SdkBindingData fib0 = builder.inputOfInteger("fib0", "Value for Fib0"); SdkBindingData fib1 = builder.inputOfInteger("fib1", "Value for Fib1"); - SdkNode apply = builder.apply("fib-2", SumTask.of(fib0, fib1)); + SdkNode apply = + builder.apply("fib-2", new SumTask(), SumTask.SumInput.create(fib1, fib0)); SumTask.SumOutput outputs = apply.getOutputs(); SdkBindingData fib2 = outputs.c(); - SdkBindingData fib3 = builder.apply("fib-3", SumTask.of(fib1, fib2)).getOutputs().c(); - SdkBindingData fib4 = builder.apply("fib-4", SumTask.of(fib2, fib3)).getOutputs().c(); - SdkBindingData fib5 = builder.apply("fib-5", SumTask.of(fib3, fib4)).getOutputs().c(); + SdkBindingData fib3 = + builder.apply("fib-3", new SumTask(), SumTask.SumInput.create(fib1, fib2)).getOutputs().c(); + SdkBindingData fib4 = + builder.apply("fib-4", new SumTask(), SumTask.SumInput.create(fib2, fib3)).getOutputs().c(); + SdkBindingData fib5 = + builder.apply("fib-5", new SumTask(), SumTask.SumInput.create(fib3, fib4)).getOutputs().c(); builder.output("fib5", fib5, "Value for Fib5"); } + @AutoValue + public abstract static class Input { + public abstract SdkBindingData fib0(); + + public abstract SdkBindingData fib1(); + + public static FibonacciWorkflow.Input create( + SdkBindingData fib0, SdkBindingData fib1) { + return new AutoValue_FibonacciWorkflow_Input(fib0, fib1); + } + } + @AutoValue public abstract static class Output { public abstract SdkBindingData fib5(); diff --git a/flytekit-examples/src/main/java/org/flyte/examples/GreetTask.java b/flytekit-examples/src/main/java/org/flyte/examples/GreetTask.java index 4f34bd052..5ccc5c9d3 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/GreetTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/GreetTask.java @@ -20,7 +20,6 @@ import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkRunnableTask; -import org.flyte.flytekit.SdkTransform; import org.flyte.flytekit.jackson.JacksonSdkType; /** Example Flyte task that takes a name as the input and outputs a simple greeting message. */ @@ -31,22 +30,16 @@ public GreetTask() { super(JacksonSdkType.of(Input.class), JacksonSdkType.of(Output.class)); } - /** - * Binds input data to this task. - * - * @param name the input name - * @return a transformed instance of this class with input data - */ - public static SdkTransform of(SdkBindingData name) { - return new GreetTask().withInput("name", name); - } - /** * Generate an immutable value class that represents {@link GreetTask}'s input, which is a String. */ @AutoValue public abstract static class Input { public abstract SdkBindingData name(); + + public static Input create(SdkBindingData greeting) { + return new AutoValue_GreetTask_Input(greeting); + } } /** diff --git a/flytekit-examples/src/main/java/org/flyte/examples/NodeMetadataExampleWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/NodeMetadataExampleWorkflow.java index c2b0cc02b..d1cf4a6d3 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/NodeMetadataExampleWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/NodeMetadataExampleWorkflow.java @@ -20,12 +20,14 @@ import com.google.auto.value.AutoValue; import java.time.Duration; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkTypes; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkWorkflow.class) -public class NodeMetadataExampleWorkflow extends SdkWorkflow { +public class NodeMetadataExampleWorkflow + extends SdkWorkflow { @AutoValue public abstract static class Output { @@ -43,7 +45,7 @@ public static NodeMetadataExampleWorkflow.Output create(SdkBindingData c } public NodeMetadataExampleWorkflow() { - super(JacksonSdkType.of(NodeMetadataExampleWorkflow.Output.class)); + super(SdkTypes.nulls(), JacksonSdkType.of(NodeMetadataExampleWorkflow.Output.class)); } @Override @@ -55,9 +57,10 @@ public void expand(SdkWorkflowBuilder builder) { builder .apply( "sum-a-b", - SumTask.of(a, b) + new SumTask() .withNameOverride("sum a+b") - .withTimeoutOverride(Duration.ofMinutes(15))) + .withTimeoutOverride(Duration.ofMinutes(15)), + SumTask.SumInput.create(a, b)) .getOutputs() .c(); diff --git a/flytekit-examples/src/main/java/org/flyte/examples/PhoneBookWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/PhoneBookWorkflow.java index 96ed33071..ab8dbe78e 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/PhoneBookWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/PhoneBookWorkflow.java @@ -23,12 +23,13 @@ import java.util.List; import java.util.Map; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkTypes; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkWorkflow.class) -public class PhoneBookWorkflow extends SdkWorkflow { +public class PhoneBookWorkflow extends SdkWorkflow { private static final List NAMES = Arrays.asList("frodo", "bilbo"); private static final Map PHONE_BOOK = new HashMap<>(); @@ -55,7 +56,7 @@ public static PhoneBookWorkflow.Output create(SdkBindingData> phone } public PhoneBookWorkflow() { - super(JacksonSdkType.of(PhoneBookWorkflow.Output.class)); + super(SdkTypes.nulls(), JacksonSdkType.of(PhoneBookWorkflow.Output.class)); } @Override @@ -68,9 +69,8 @@ public void expand(SdkWorkflowBuilder builder) { builder .apply( "search", - new BatchLookUpTask() - .withInput("keyValues", phoneBook) - .withInput("searchKeys", searchKeys)) + new BatchLookUpTask(), + BatchLookUpTask.Input.create(phoneBook, searchKeys)) .getOutputs() .values(); diff --git a/flytekit-examples/src/main/java/org/flyte/examples/PrintMessageTask.java b/flytekit-examples/src/main/java/org/flyte/examples/PrintMessageTask.java index d8e986755..66ca34c48 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/PrintMessageTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/PrintMessageTask.java @@ -20,7 +20,6 @@ import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkRunnableTask; -import org.flyte.flytekit.SdkTransform; import org.flyte.flytekit.SdkTypes; import org.flyte.flytekit.jackson.JacksonSdkType; @@ -32,10 +31,6 @@ public PrintMessageTask() { super(JacksonSdkType.of(Input.class), SdkTypes.nulls()); } - public static SdkTransform of(SdkBindingData message) { - return new PrintMessageTask().withInput("message", message); - } - /** Input for {@link PrintMessageTask}. */ @AutoValue public abstract static class Input { diff --git a/flytekit-examples/src/main/java/org/flyte/examples/RemoteLaunchPlanExample.java b/flytekit-examples/src/main/java/org/flyte/examples/RemoteLaunchPlanExample.java index 627f9bd37..aec8d4004 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/RemoteLaunchPlanExample.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/RemoteLaunchPlanExample.java @@ -17,6 +17,7 @@ package org.flyte.examples; import com.google.auto.value.AutoValue; +import org.flyte.examples.RemoteLaunchPlanExample.Input; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkRemoteLaunchPlan; import org.flyte.flytekit.SdkTypes; @@ -29,17 +30,17 @@ // launchplan to be registered already. // The order that we register objects in jflyte is: task, workflows and launchplans // @AutoService(SdkWorkflow.class) -public class RemoteLaunchPlanExample extends SdkWorkflow { +public class RemoteLaunchPlanExample extends SdkWorkflow { public RemoteLaunchPlanExample() { - super(SdkTypes.nulls()); + super(JacksonSdkType.of(Input.class), SdkTypes.nulls()); } @Override public void expand(SdkWorkflowBuilder builder) { SdkBindingData fib0 = builder.inputOfInteger("fib0"); SdkBindingData fib1 = builder.inputOfInteger("fib1"); - builder.apply("remote-launch-plan", create().withInput("fib0", fib0).withInput("fib1", fib1)); + builder.apply("remote-launch-plan", create(), Input.create(fib0, fib1)); } public static SdkRemoteLaunchPlan create() { diff --git a/flytekit-examples/src/main/java/org/flyte/examples/SimpleStructTask.java b/flytekit-examples/src/main/java/org/flyte/examples/SimpleStructTask.java index 206bc3866..7b1a82d03 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/SimpleStructTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/SimpleStructTask.java @@ -18,9 +18,7 @@ import com.google.auto.value.AutoValue; import org.flyte.api.v1.Struct; -import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkRunnableTask; -import org.flyte.flytekit.SdkTransform; import org.flyte.flytekit.jackson.JacksonSdkType; /** Example Flyte task that takes a name as the input and outputs a simple greeting message. */ @@ -31,10 +29,6 @@ public SimpleStructTask() { super(JacksonSdkType.of(Input.class), JacksonSdkType.of(Output.class)); } - public static SdkTransform of(SdkBindingData struct) { - return new SimpleStructTask().withInput("struct", struct); - } - @AutoValue public abstract static class Input { public abstract Struct in(); diff --git a/flytekit-examples/src/main/java/org/flyte/examples/SubWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/SubWorkflow.java index b03fe7d67..74402ef67 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/SubWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/SubWorkflow.java @@ -24,17 +24,18 @@ import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkWorkflow.class) -public class SubWorkflow extends SdkWorkflow { +public class SubWorkflow extends SdkWorkflow { public SubWorkflow() { - super(JacksonSdkType.of(SubWorkflow.Output.class)); + super(JacksonSdkType.of(SubWorkflow.Input.class), JacksonSdkType.of(SubWorkflow.Output.class)); } @Override public void expand(SdkWorkflowBuilder builder) { SdkBindingData left = builder.inputOfInteger("left"); SdkBindingData right = builder.inputOfInteger("right"); - SdkBindingData result = builder.apply("sum", SumTask.of(left, right)).getOutputs().c(); + SdkBindingData result = + builder.apply("sum", new SumTask(), SumTask.SumInput.create(left, right)).getOutputs().c(); builder.output("result", result); } diff --git a/flytekit-examples/src/main/java/org/flyte/examples/SumTask.java b/flytekit-examples/src/main/java/org/flyte/examples/SumTask.java index d838415b7..55845567e 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/SumTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/SumTask.java @@ -20,7 +20,6 @@ import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkRunnableTask; -import org.flyte.flytekit.SdkTransform; import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkRunnableTask.class) @@ -29,10 +28,6 @@ public SumTask() { super(JacksonSdkType.of(SumInput.class), JacksonSdkType.of(SumOutput.class)); } - public static SdkTransform of(SdkBindingData a, SdkBindingData b) { - return new SumTask().withInput("a", a).withInput("b", b); - } - @AutoValue public abstract static class SumInput { public abstract SdkBindingData a(); diff --git a/flytekit-examples/src/main/java/org/flyte/examples/UberWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/UberWorkflow.java index 68f812b61..ebe04a19d 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/UberWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/UberWorkflow.java @@ -17,16 +17,36 @@ package org.flyte.examples; import com.google.auto.service.AutoService; +import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkWorkflow.class) -public class UberWorkflow extends SdkWorkflow { +public class UberWorkflow extends SdkWorkflow { + + @AutoValue + public abstract static class Input { + public abstract SdkBindingData a(); + + public abstract SdkBindingData b(); + + public abstract SdkBindingData c(); + + public abstract SdkBindingData d(); + + public static UberWorkflow.Input create( + SdkBindingData a, + SdkBindingData b, + SdkBindingData c, + SdkBindingData d) { + return new AutoValue_UberWorkflow_Input(a, b, c, d); + } + } public UberWorkflow() { - super(JacksonSdkType.of(SubWorkflow.Output.class)); + super(JacksonSdkType.of(UberWorkflow.Input.class), JacksonSdkType.of(SubWorkflow.Output.class)); } @Override @@ -37,15 +57,16 @@ public void expand(SdkWorkflowBuilder builder) { SdkBindingData d = builder.inputOfInteger("d"); SdkBindingData ab = builder - .apply("sub-1", new SubWorkflow().withInput("left", a).withInput("right", b)) + .apply("sub-1", new SubWorkflow(), SubWorkflow.Input.create(a, b)) .getOutputs() .result(); SdkBindingData abc = builder - .apply("sub-2", new SubWorkflow().withInput("left", ab).withInput("right", c)) + .apply("sub-2", new SubWorkflow(), SubWorkflow.Input.create(ab, c)) .getOutputs() .result(); - SdkBindingData abcd = builder.apply("post-sum", SumTask.of(abc, d)).getOutputs().c(); + SdkBindingData abcd = + builder.apply("post-sum", new SumTask(), SumTask.SumInput.create(abc, d)).getOutputs().c(); builder.output("result", abcd); } } diff --git a/flytekit-examples/src/main/java/org/flyte/examples/WelcomeWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/WelcomeWorkflow.java index 3762260a5..61174ad60 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/WelcomeWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/WelcomeWorkflow.java @@ -17,6 +17,7 @@ package org.flyte.examples; import com.google.auto.service.AutoService; +import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; @@ -24,10 +25,21 @@ /** Example workflow that takes a name and outputs a welcome message. */ @AutoService(SdkWorkflow.class) -public class WelcomeWorkflow extends SdkWorkflow { +public class WelcomeWorkflow extends SdkWorkflow { + + @AutoValue + public abstract static class Input { + public abstract SdkBindingData name(); + + public static WelcomeWorkflow.Input create(SdkBindingData name) { + return new AutoValue_WelcomeWorkflow_Input(name); + } + } public WelcomeWorkflow() { - super(JacksonSdkType.of(AddQuestionTask.Output.class)); + super( + JacksonSdkType.of(WelcomeWorkflow.Input.class), + JacksonSdkType.of(AddQuestionTask.Output.class)); } @Override @@ -37,11 +49,17 @@ public void expand(SdkWorkflowBuilder builder) { // uses the workflow input as the task input of the GreetTask SdkBindingData greeting = - builder.apply("greet", GreetTask.of(name)).getOutputs().greeting(); + builder + .apply("greet", new GreetTask(), GreetTask.Input.create(name)) + .getOutputs() + .greeting(); // uses the output of the GreetTask as the task input of the AddQuestionTask SdkBindingData greetingWithQuestion = - builder.apply("add-question", AddQuestionTask.of(greeting)).getOutputs().greeting(); + builder + .apply("add-question", new AddQuestionTask(), AddQuestionTask.Input.create(greeting)) + .getOutputs() + .greeting(); // uses the task output of the AddQuestionTask as the output of the workflow builder.output("greeting", greetingWithQuestion, "Welcome message"); diff --git a/flytekit-jackson/pom.xml b/flytekit-jackson/pom.xml index b539fb867..a9204d53c 100644 --- a/flytekit-jackson/pom.xml +++ b/flytekit-jackson/pom.xml @@ -21,7 +21,7 @@ org.flyte flytekit-parent - 0.3.29-SNAPSHOT + 0.4.0-SNAPSHOT flytekit-jackson diff --git a/flytekit-java/pom.xml b/flytekit-java/pom.xml index bed5bfcba..74eb3530d 100644 --- a/flytekit-java/pom.xml +++ b/flytekit-java/pom.xml @@ -21,7 +21,7 @@ org.flyte flytekit-parent - 0.3.29-SNAPSHOT + 0.4.0-SNAPSHOT flytekit-java diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/DefaultSdkWorkflowRegistry.java b/flytekit-java/src/main/java/org/flyte/flytekit/DefaultSdkWorkflowRegistry.java index 10afb9a05..a2dbfda89 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/DefaultSdkWorkflowRegistry.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/DefaultSdkWorkflowRegistry.java @@ -26,8 +26,8 @@ @AutoService(SdkWorkflowRegistry.class) public class DefaultSdkWorkflowRegistry extends SdkWorkflowRegistry { @Override - public List> getWorkflows() { - List> workflows = new ArrayList<>(); + public List> getWorkflows() { + List> workflows = new ArrayList<>(); ServiceLoader.load(SdkWorkflow.class).forEach(workflows::add); return unmodifiableList(workflows); diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkAppliedTransform.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkAppliedTransform.java new file mode 100644 index 000000000..b3faa93c6 --- /dev/null +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkAppliedTransform.java @@ -0,0 +1,62 @@ +/* + * Copyright 2021 Flyte Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekit; + +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; + +/** + * A {@link SdkTransform} with its inputs applied, so converting it from {@code + * SdkTransform} to a {@code SdkTransform}. + */ +class SdkAppliedTransform extends SdkTransform { + private final SdkTransform transform; + private final OriginalInputT appliedInputs; + + SdkAppliedTransform( + SdkTransform transform, @Nullable OriginalInputT appliedInputs) { + transform.checkNullOnlyVoid(appliedInputs); + this.transform = transform; + this.appliedInputs = appliedInputs; + } + + @Override + public SdkType getInputType() { + return SdkTypes.nulls(); + } + + @Override + public SdkType getOutputType() { + return transform.getOutputType(); + } + + @Override + public String getName() { + return transform.getName(); + } + + @Override + public SdkNode apply( + SdkWorkflowBuilder builder, + String nodeId, + List upstreamNodeIds, + @Nullable SdkNodeMetadata metadata, + Map> inputs) { + return transform.apply(builder, nodeId, upstreamNodeIds, metadata, appliedInputs); + } +} diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBranchNode.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBranchNode.java index 6e0b6cc84..16f7654fe 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBranchNode.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBranchNode.java @@ -16,9 +16,7 @@ */ package org.flyte.flytekit; -import static java.util.Collections.emptyList; -import static java.util.Collections.emptyMap; -import static org.flyte.flytekit.MoreCollectors.toUnmodifiableList; +import static java.util.stream.Collectors.toUnmodifiableList; import static org.flyte.flytekit.MoreCollectors.toUnmodifiableMap; import com.google.errorprone.annotations.CanIgnoreReturnValue; @@ -28,6 +26,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import org.flyte.api.v1.Binding; import org.flyte.api.v1.BranchNode; import org.flyte.api.v1.IfElseBlock; @@ -91,10 +90,16 @@ public Node toIdl() { ifElseBlock = ifElseBlock.toBuilder().error(nodeError).build(); } + // inputs in var order for predictability + List inputs = + extraInputs.entrySet().stream() + .sorted(Entry.comparingByKey()) + .map(Entry::getValue) + .collect(toUnmodifiableList()); return Node.builder() .id(nodeId) .branchNode(BranchNode.builder().ifElse(ifElseBlock).build()) - .inputs(List.copyOf(extraInputs.values())) + .inputs(inputs) .upstreamNodeIds(upstreamNodeIds) .build(); } @@ -117,7 +122,7 @@ static class Builder { @CanIgnoreReturnValue Builder addCase(SdkConditionCase case_) { SdkNode sdkNode = - case_.then().apply(builder, case_.name(), emptyList(), /*metadata=*/ null, emptyMap()); + case_.then().apply(builder, case_.name(), List.of(), /*metadata=*/ null, Map.of()); Map> thatOutputs = sdkNode.getOutputBindings(); Map thatOutputTypes = @@ -147,7 +152,7 @@ Builder addCase(SdkConditionCase case_) { } @CanIgnoreReturnValue - Builder addOtherwise(String name, SdkTransform otherwise) { + Builder addOtherwise(String name, SdkTransform otherwise) { if (elseNode != null) { throw new IllegalArgumentException("Duplicate otherwise clause"); } @@ -156,7 +161,7 @@ Builder addOtherwise(String name, SdkTransform otherwise) { throw new IllegalArgumentException(String.format("Duplicate case name [%s]", name)); } - elseNode = otherwise.apply(builder, name, emptyList(), /*metadata=*/ null, emptyMap()); + elseNode = otherwise.apply(builder, name, List.of(), /*metadata=*/ null, Map.of()); caseOutputs.put(name, elseNode.getOutputBindings()); return this; diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkCondition.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkCondition.java index 0f0f469b8..1b9c9d6a2 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkCondition.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkCondition.java @@ -21,32 +21,44 @@ import java.util.Map; import javax.annotation.Nullable; -public class SdkCondition extends SdkTransform { +public class SdkCondition extends SdkTransform { private final SdkType outputType; private final List> cases; private final String otherwiseName; - private final SdkTransform otherwise; + private final SdkTransform otherwise; SdkCondition( List> cases, String otherwiseName, - SdkTransform otherwise) { - this.cases = cases; + SdkTransform otherwise) { + if (cases.isEmpty()) { + throw new IllegalArgumentException("Empty cases on SdkCondition"); + } + this.cases = List.copyOf(cases); this.otherwiseName = otherwiseName; this.otherwise = otherwise; - this.outputType = cases.get(0).then().getOutputType(); + + var firstCase = cases.get(0); + this.outputType = firstCase.then().getOutputType(); } public SdkCondition when( - String name, SdkBooleanExpression condition, SdkTransform then) { - - List> newCases = new ArrayList<>(cases); + String name, SdkBooleanExpression condition, SdkTransform then) { + var newCases = new ArrayList<>(cases); newCases.add(SdkConditionCase.create(name, condition, then)); return new SdkCondition<>(newCases, this.otherwiseName, this.otherwise); } - public SdkCondition otherwise(String name, SdkTransform otherwise) { + public SdkCondition when( + String name, + SdkBooleanExpression condition, + SdkTransform then, + InputT inputs) { + return when(name, condition, new SdkAppliedTransform<>(then, inputs)); + } + + public SdkCondition otherwise(String name, SdkTransform otherwise) { if (this.otherwise != null) { throw new IllegalStateException("Can't set 'otherwise' because it's already set"); } @@ -54,6 +66,16 @@ public SdkCondition otherwise(String name, SdkTransform otherw return new SdkCondition<>(this.cases, name, otherwise); } + public SdkCondition otherwise( + String name, SdkTransform otherwise, InputT inputs) { + return otherwise(name, new SdkAppliedTransform<>(otherwise, inputs)); + } + + @Override + public SdkType getInputType() { + return SdkTypes.nulls(); + } + @Override public SdkType getOutputType() { return outputType; diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkConditionCase.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkConditionCase.java index 34bade759..5459a4a1b 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkConditionCase.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkConditionCase.java @@ -24,10 +24,10 @@ abstract class SdkConditionCase { abstract SdkBooleanExpression condition(); - abstract SdkTransform then(); + abstract SdkTransform then(); static SdkConditionCase create( - String name, SdkBooleanExpression condition, SdkTransform then) { + String name, SdkBooleanExpression condition, SdkTransform then) { return new AutoValue_SdkConditionCase<>(name, condition, then); } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkConditions.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkConditions.java index 6641bb241..bae2f2aa1 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkConditions.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkConditions.java @@ -16,19 +16,27 @@ */ package org.flyte.flytekit; -import static java.util.Collections.singletonList; import static org.flyte.flytekit.SdkBooleanExpression.ofComparison; +import java.util.List; import org.flyte.api.v1.ComparisonExpression; public class SdkConditions { private SdkConditions() {} - public static SdkCondition when( - String name, SdkBooleanExpression condition, SdkTransform then) { - SdkConditionCase case_ = SdkConditionCase.create(name, condition, then); + public static SdkCondition when( + String name, SdkBooleanExpression condition, SdkTransform then) { + SdkConditionCase case_ = SdkConditionCase.create(name, condition, then); - return new SdkCondition<>(singletonList(case_), null, null); + return new SdkCondition<>(List.of(case_), null, null); + } + + public static SdkCondition when( + String name, + SdkBooleanExpression condition, + SdkTransform then, + InputT inputs) { + return when(name, condition, new SdkAppliedTransform<>(then, inputs)); } public static SdkBooleanExpression eq(SdkBindingData left, SdkBindingData right) { diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkContainerTask.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkContainerTask.java index 28e38d464..96d14d240 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkContainerTask.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkContainerTask.java @@ -26,7 +26,7 @@ import org.flyte.api.v1.PartialTaskIdentifier; /** Building block for tasks that execute arbitrary containers. */ -public abstract class SdkContainerTask extends SdkTransform +public abstract class SdkContainerTask extends SdkTransform implements Serializable { private static final long serialVersionUID = 42L; @@ -54,6 +54,7 @@ public String getType() { } /** Specifies task input type. */ + @Override public SdkType getInputType() { return inputType; } @@ -129,7 +130,7 @@ public Map getEnv() { } /** - * Indicates whether the system should attempt to lookup this task's output to avoid duplication + * Indicates whether the system should attempt to look up this task's output to avoid duplication * of work. */ public boolean isCached() { diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkDynamicWorkflowTask.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkDynamicWorkflowTask.java index b7fa66423..1131ee617 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkDynamicWorkflowTask.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkDynamicWorkflowTask.java @@ -21,7 +21,8 @@ import javax.annotation.Nullable; import org.flyte.api.v1.PartialTaskIdentifier; -public abstract class SdkDynamicWorkflowTask extends SdkTransform { +public abstract class SdkDynamicWorkflowTask + extends SdkTransform { private final SdkType inputType; private final SdkType outputType; @@ -36,6 +37,7 @@ public String getType() { return "dynamic"; } + @Override public SdkType getInputType() { return inputType; } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlan.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlan.java index 9b7818fc0..709e52c9f 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlan.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlan.java @@ -76,7 +76,7 @@ public abstract class SdkLaunchPlan { * @param workflow Workflow to be reference by new {@link SdkLaunchPlan}. * @return the created {@link SdkLaunchPlan}. */ - public static SdkLaunchPlan of(SdkWorkflow workflow) { + public static SdkLaunchPlan of(SdkWorkflow workflow) { SdkWorkflowBuilder wfBuilder = new SdkWorkflowBuilder(); workflow.expand(wfBuilder); return builder() diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkPartialTransform.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkMetadataDecoratorTransform.java similarity index 57% rename from flytekit-java/src/main/java/org/flyte/flytekit/SdkPartialTransform.java rename to flytekit-java/src/main/java/org/flyte/flytekit/SdkMetadataDecoratorTransform.java index 814ded2a1..354e6e3d8 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkPartialTransform.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkMetadataDecoratorTransform.java @@ -16,72 +16,43 @@ */ package org.flyte.flytekit; -import static java.util.Collections.emptyList; -import static java.util.Collections.emptyMap; import static java.util.Collections.unmodifiableList; -import static java.util.Collections.unmodifiableMap; import static java.util.Objects.requireNonNull; import java.time.Duration; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.function.Function; import javax.annotation.Nullable; -/** {@link SdkTransform} with partially specified inputs. */ -class SdkPartialTransform extends SdkTransform { - private final SdkTransform transform; - private final Map> fixedInputs; +/** Decorator for {@link SdkTransform} for holding metadata. */ +class SdkMetadataDecoratorTransform extends SdkTransform { + private final SdkTransform transform; private final List extraUpstreamNodeIds; @Nullable private final SdkNodeMetadata metadata; - private SdkPartialTransform( - SdkTransform transform, - Map> fixedInputs, + private SdkMetadataDecoratorTransform( + SdkTransform transform, List extraUpstreamNodeIds, @Nullable SdkNodeMetadata metadata) { this.transform = transform; - this.fixedInputs = unmodifiableMap(new HashMap<>(fixedInputs)); - this.extraUpstreamNodeIds = unmodifiableList(new ArrayList<>(extraUpstreamNodeIds)); + this.extraUpstreamNodeIds = List.copyOf(extraUpstreamNodeIds); this.metadata = metadata; } - static SdkTransform of( - SdkTransform transform, Map> fixedInputs) { - return new SdkPartialTransform<>(transform, fixedInputs, emptyList(), /*metadata=*/ null); + static SdkTransform of( + SdkTransform transform, List extraUpstreamNodeIds) { + return new SdkMetadataDecoratorTransform<>(transform, extraUpstreamNodeIds, /*metadata=*/ null); } - static SdkTransform of(SdkTransform transform, List extraUpstreamNodeIds) { - return new SdkPartialTransform<>( - transform, emptyMap(), extraUpstreamNodeIds, /*metadata=*/ null); - } - - static SdkTransform of(SdkTransform transform, SdkNodeMetadata metadata) { - return new SdkPartialTransform<>(transform, emptyMap(), emptyList(), metadata); + static SdkTransform of( + SdkTransform transform, SdkNodeMetadata metadata) { + return new SdkMetadataDecoratorTransform<>(transform, List.of(), metadata); } @Override - public SdkTransform withInput(String name, SdkBindingData value) { - // isn't necessary to override, but this reduces nesting and gives better error messages - - SdkBindingData existing = fixedInputs.get(name); - if (existing != null) { - String message = - String.format("Duplicate values for input [%s]: [%s], [%s]", name, existing, value); - throw new IllegalArgumentException(message); - } - - Map> newFixedInputs = new HashMap<>(fixedInputs); - newFixedInputs.put(name, value); - - return new SdkPartialTransform<>( - transform, unmodifiableMap(newFixedInputs), extraUpstreamNodeIds, metadata); - } - - @Override - public SdkTransform withUpstreamNode(SdkNode node) { + public SdkTransform withUpstreamNode(SdkNode node) { if (extraUpstreamNodeIds.contains(node.getNodeId())) { throw new IllegalArgumentException( String.format("Duplicate upstream node id [%s]", node.getNodeId())); @@ -90,23 +61,23 @@ public SdkTransform withUpstreamNode(SdkNode node) { List newExtraUpstreamNodeIds = new ArrayList<>(extraUpstreamNodeIds); newExtraUpstreamNodeIds.add(node.getNodeId()); - return new SdkPartialTransform<>( - transform, fixedInputs, unmodifiableList(newExtraUpstreamNodeIds), metadata); + return new SdkMetadataDecoratorTransform<>( + transform, unmodifiableList(newExtraUpstreamNodeIds), metadata); } @Override - public SdkTransform withNameOverride(String name) { + public SdkTransform withNameOverride(String name) { requireNonNull(name, "Name override cannot be null"); SdkNodeMetadata newMetadata = SdkNodeMetadata.builder().name(name).build(); checkForDuplicateMetadata(metadata, newMetadata, SdkNodeMetadata::name, "name"); SdkNodeMetadata mergedMetadata = mergeMetadata(metadata, newMetadata); - return new SdkPartialTransform<>(transform, fixedInputs, extraUpstreamNodeIds, mergedMetadata); + return new SdkMetadataDecoratorTransform<>(transform, extraUpstreamNodeIds, mergedMetadata); } @Override - SdkTransform withNameOverrideIfNotSet(String name) { + SdkTransform withNameOverrideIfNotSet(String name) { if (metadata != null && metadata.name() != null) { return this; } @@ -114,45 +85,38 @@ SdkTransform withNameOverrideIfNotSet(String name) { } @Override - public SdkTransform withTimeoutOverride(Duration timeout) { + public SdkTransform withTimeoutOverride(Duration timeout) { requireNonNull(timeout, "Timeout override cannot be null"); SdkNodeMetadata newMetadata = SdkNodeMetadata.builder().timeout(timeout).build(); checkForDuplicateMetadata(metadata, newMetadata, SdkNodeMetadata::timeout, "timeout"); SdkNodeMetadata mergedMetadata = mergeMetadata(metadata, newMetadata); - return new SdkPartialTransform<>(transform, fixedInputs, extraUpstreamNodeIds, mergedMetadata); + return new SdkMetadataDecoratorTransform<>(transform, extraUpstreamNodeIds, mergedMetadata); } @Override - public SdkType getOutputType() { + public SdkType getOutputType() { return transform.getOutputType(); } + @Override + public SdkType getInputType() { + return transform.getInputType(); + } + @Override public String getName() { return transform.getName(); } @Override - public SdkNode apply( + public SdkNode apply( SdkWorkflowBuilder builder, String nodeId, List upstreamNodeIds, @Nullable SdkNodeMetadata metadata, Map> inputs) { - Map> allInputs = new HashMap<>(fixedInputs); - - inputs.forEach( - (k, v) -> - allInputs.merge( - k, - v, - (v1, v2) -> { - String message = - String.format("Duplicate values for input [%s]: [%s], [%s]", k, v1, v2); - throw new IllegalArgumentException(message); - })); List duplicates = new ArrayList<>(upstreamNodeIds); duplicates.retainAll(extraUpstreamNodeIds); @@ -170,11 +134,7 @@ public SdkNode apply( SdkNodeMetadata mergedMetadata = mergeMetadata(this.metadata, metadata); return transform.apply( - builder, - nodeId, - unmodifiableList(allUpstreamNodeIds), - mergedMetadata, - unmodifiableMap(allInputs)); + builder, nodeId, unmodifiableList(allUpstreamNodeIds), mergedMetadata, inputs); } private static void checkForDuplicateMetadata( diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkNode.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkNode.java index fcd926d5d..ee4ffd003 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkNode.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkNode.java @@ -36,7 +36,6 @@ protected SdkNode(SdkWorkflowBuilder builder) { public SdkBindingData getOutput(String name) { - @SuppressWarnings("unchecked") SdkBindingData output = getOutputBindings().get(name); if (output == null) { @@ -57,13 +56,15 @@ public SdkBindingData getOutput(String name) { public abstract Node toIdl(); - public SdkNode apply(String id, SdkTransform transform) { + // TODO we need a version with no nodeId for consistency with builder + public SdkNode apply( + String nodeId, SdkTransform transform) { // if there are no outputs, explicitly specify dependency to preserve execution order List upstreamNodeIds = getOutputBindings().isEmpty() ? Collections.singletonList(getNodeId()) : Collections.emptyList(); - return builder.applyInternal(id, transform, upstreamNodeIds, getOutputBindings()); + return builder.applyInternal(nodeId, transform, upstreamNodeIds, getOutputs()); } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkNodeNamePolicy.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkNodeNamePolicy.java index 4dee35f66..6200532c9 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkNodeNamePolicy.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkNodeNamePolicy.java @@ -17,7 +17,6 @@ package org.flyte.flytekit; import java.util.Locale; -import java.util.Map; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicInteger; import java.util.regex.Pattern; @@ -25,8 +24,8 @@ /** * Controls the default node id and node name policy when applying {@link SdkTransform} to {@link * SdkWorkflowBuilder}. When using {@link SdkWorkflowBuilder#apply(SdkTransform)} or {@link - * SdkWorkflowBuilder#apply(SdkTransform, Map)} then the node id used would be the one returned by - * {@link #nextNodeId()}. Also, if a node name haven't been set by the user, then {@link + * SdkWorkflowBuilder#apply(SdkTransform, Object)} then the node id used would be the one returned + * by {@link #nextNodeId()}. Also, if a node name haven't been set by the user, then {@link * #toNodeName(String)} would be used. */ class SdkNodeNamePolicy { diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkRemoteLaunchPlan.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkRemoteLaunchPlan.java index 0c893b939..125340460 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkRemoteLaunchPlan.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkRemoteLaunchPlan.java @@ -27,7 +27,7 @@ /** Reference to a LaunchPlan deployed in flyte, a remote LaunchPlan. */ @AutoValue -public abstract class SdkRemoteLaunchPlan extends SdkTransform { +public abstract class SdkRemoteLaunchPlan extends SdkTransform { @Nullable public abstract String domain(); @@ -43,6 +43,12 @@ public abstract class SdkRemoteLaunchPlan extends SdkTransform< public abstract SdkType outputs(); + @Override + public SdkType getInputType() { + // TODO consider break backward compatibility to unify the names and avoid this bridge method + return inputs(); + } + @Override public SdkType getOutputType() { // TODO consider break backward compatibility to unify the names and avoid this bridge method diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkRemoteTask.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkRemoteTask.java index a55b1208f..ef1716649 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkRemoteTask.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkRemoteTask.java @@ -25,7 +25,7 @@ /** Reference to a task deployed in flyte, a remote Task. */ @AutoValue -public abstract class SdkRemoteTask extends SdkTransform { +public abstract class SdkRemoteTask extends SdkTransform { @Nullable public abstract String domain(); @@ -46,6 +46,12 @@ public String version() { public abstract SdkType outputs(); + @Override + public SdkType getInputType() { + // TODO consider break backward compatibility to unify the names and avoid this bridge method + return inputs(); + } + @Override public SdkType getOutputType() { // TODO consider break backward compatibility to unify the names and avoid this bridge method diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkRunnableTask.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkRunnableTask.java index 3fe97188b..76ff57261 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkRunnableTask.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkRunnableTask.java @@ -24,7 +24,7 @@ import org.flyte.api.v1.Variable; /** Building block for tasks that execute Java code. */ -public abstract class SdkRunnableTask extends SdkTransform +public abstract class SdkRunnableTask extends SdkTransform implements Serializable { private static final long serialVersionUID = 42L; @@ -51,6 +51,7 @@ public String getType() { return "java-task"; } + @Override public SdkType getInputType() { return inputType; } @@ -79,7 +80,7 @@ public int getRetries() { } /** - * Indicates whether the system should attempt to lookup this task's output to avoid duplication + * Indicates whether the system should attempt to look up this task's output to avoid duplication * of work. */ public boolean isCached() { diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkTaskNode.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkTaskNode.java index 7d28ce37a..ad4a08f07 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkTaskNode.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkTaskNode.java @@ -83,8 +83,10 @@ public String getNodeId() { public Node toIdl() { TaskNode taskNode = TaskNode.builder().referenceId(taskId).build(); + // inputs in var order for predictability List bindings = inputs.entrySet().stream() + .sorted(Map.Entry.comparingByKey()) .map(x -> toBinding(x.getKey(), x.getValue())) .collect(toUnmodifiableList()); diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkTransform.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkTransform.java index 430c178ce..f60c14f31 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkTransform.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkTransform.java @@ -16,79 +16,73 @@ */ package org.flyte.flytekit; -import static java.util.Collections.singletonList; -import static java.util.Collections.singletonMap; import static java.util.Objects.requireNonNull; import java.time.Duration; -import java.time.Instant; import java.util.List; import java.util.Map; +import java.util.Set; import javax.annotation.Nullable; /** Implementations of {@code SdkTransform} transform {@link SdkNode} into a new one. */ -public abstract class SdkTransform { +public abstract class SdkTransform { - public abstract SdkType getOutputType(); + public abstract SdkType getInputType(); - public abstract SdkNode apply( + public abstract SdkType getOutputType(); + + public final SdkNode apply( SdkWorkflowBuilder builder, String nodeId, List upstreamNodeIds, @Nullable SdkNodeMetadata metadata, - Map> inputs); - - public SdkTransform withInput(String name, String value) { - return withInput(name, SdkBindingData.ofString(value)); - } - - public SdkTransform withInput(String name, long value) { - return withInput(name, SdkBindingData.ofInteger(value)); - } - - public SdkTransform withInput(String name, Instant value) { - return withInput(name, SdkBindingData.ofDatetime(value)); - } - - public SdkTransform withInput(String name, Duration value) { - return withInput(name, SdkBindingData.ofDuration(value)); - } - - public SdkTransform withInput(String name, boolean value) { - return withInput(name, SdkBindingData.ofBoolean(value)); + @Nullable InputT inputs) { + checkNullOnlyVoid(inputs); + var inputsBindings = getInputType().toSdkBindingMap(inputs); + return apply(builder, nodeId, upstreamNodeIds, metadata, inputsBindings); } - public SdkTransform withInput(String name, double value) { - return withInput(name, SdkBindingData.ofFloat(value)); - } - - public SdkTransform withInput(String name, SdkBindingData value) { - return SdkPartialTransform.of(this, singletonMap(name, value)); - } + abstract SdkNode apply( + SdkWorkflowBuilder builder, + String nodeId, + List upstreamNodeIds, + @Nullable SdkNodeMetadata metadata, + Map> inputs); - public SdkTransform withUpstreamNode(SdkNode node) { - return SdkPartialTransform.of(this, singletonList(node.getNodeId())); + public SdkTransform withUpstreamNode(SdkNode node) { + return SdkMetadataDecoratorTransform.of(this, List.of(node.getNodeId())); } - public SdkTransform withNameOverride(String name) { + public SdkTransform withNameOverride(String name) { requireNonNull(name, "Name override cannot be null"); SdkNodeMetadata metadata = SdkNodeMetadata.builder().name(name).build(); - return SdkPartialTransform.of(this, metadata); + return SdkMetadataDecoratorTransform.of(this, metadata); } - SdkTransform withNameOverrideIfNotSet(String name) { + SdkTransform withNameOverrideIfNotSet(String name) { return withNameOverride(name); } - public SdkTransform withTimeoutOverride(Duration timeout) { + public SdkTransform withTimeoutOverride(Duration timeout) { requireNonNull(timeout, "Timeout override cannot be null"); SdkNodeMetadata metadata = SdkNodeMetadata.builder().timeout(timeout).build(); - return SdkPartialTransform.of(this, metadata); + return SdkMetadataDecoratorTransform.of(this, metadata); } public String getName() { return getClass().getName(); } + + void checkNullOnlyVoid(@Nullable InputT inputs) { + Set variableNames = getInputType().variableNames(); + boolean hasProperties = !variableNames.isEmpty(); + if (inputs == null && hasProperties) { + throw new IllegalArgumentException( + "Null supplied as input for a transform with variables: " + variableNames); + } else if (inputs != null && !hasProperties) { + throw new IllegalArgumentException("Null input expected for a transform with no variables"); + } + } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkType.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkType.java index 72e3dd517..929eaee9e 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkType.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkType.java @@ -17,6 +17,7 @@ package org.flyte.flytekit; import java.util.Map; +import java.util.Set; import org.flyte.api.v1.Literal; import org.flyte.api.v1.Variable; @@ -30,5 +31,9 @@ public abstract class SdkType { public abstract Map getVariableMap(); + public Set variableNames() { + return Set.copyOf(getVariableMap().keySet()); + } + public abstract Map> toSdkBindingMap(T value); } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkTypes.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkTypes.java index a2d628ace..eafd299ab 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkTypes.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkTypes.java @@ -16,24 +16,26 @@ */ package org.flyte.flytekit; -import java.util.Collections; import java.util.Map; import org.flyte.api.v1.Literal; import org.flyte.api.v1.Variable; -/** An utility class for creating {@link SdkType} objects for different types. */ +/** A utility class for creating {@link SdkType} objects for different types. */ public class SdkTypes { + + private static final VoidSdkType VOID_SDK_TYPE = new VoidSdkType(); + private SdkTypes() {} public static SdkType nulls() { - return new VoidSdkType(); + return VOID_SDK_TYPE; } private static class VoidSdkType extends SdkType { @Override public Map toLiteralMap(Void value) { - return Collections.emptyMap(); + return Map.of(); } @Override @@ -48,12 +50,12 @@ public Void promiseFor(String nodeId) { @Override public Map getVariableMap() { - return Collections.emptyMap(); + return Map.of(); } @Override public Map> toSdkBindingMap(Void value) { - return Collections.emptyMap(); + return Map.of(); } } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflow.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflow.java index 91b3e3a94..74d27ad7a 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflow.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflow.java @@ -25,10 +25,12 @@ import org.flyte.api.v1.WorkflowNode; import org.flyte.api.v1.WorkflowTemplate; -public abstract class SdkWorkflow extends SdkTransform { +public abstract class SdkWorkflow extends SdkTransform { + private final SdkType inputType; private final SdkType outputType; - protected SdkWorkflow(SdkType outputType) { + protected SdkWorkflow(SdkType inputType, SdkType outputType) { + this.inputType = inputType; this.outputType = outputType; } @@ -73,6 +75,11 @@ public SdkNode apply( builder, nodeId, upstreamNodeIds, metadata, workflowNode, inputs, outputs, promise); } + @Override + public SdkType getInputType() { + return inputType; + } + @Override public SdkType getOutputType() { return outputType; diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowBuilder.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowBuilder.java index 4c78b8915..bd77474ac 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowBuilder.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowBuilder.java @@ -17,7 +17,6 @@ package org.flyte.flytekit; import static java.util.Collections.emptyList; -import static java.util.Collections.emptyMap; import static java.util.Collections.unmodifiableMap; import static org.flyte.api.v1.Node.START_NODE_ID; @@ -28,6 +27,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import javax.annotation.Nullable; import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.SimpleType; import org.flyte.api.v1.WorkflowTemplate; @@ -57,29 +57,30 @@ public SdkWorkflowBuilder() { this.sdkNodeNamePolicy = sdkNodeNamePolicy; } - public SdkNode apply(String nodeId, SdkTransform transform) { - return apply(nodeId, transform, emptyMap()); + public SdkNode apply( + String nodeId, SdkTransform transformWithoutInputs) { + return applyInternal(nodeId, transformWithoutInputs, emptyList(), null); } - public SdkNode apply( - String nodeId, SdkTransform transform, Map> inputs) { + public SdkNode apply( + String nodeId, SdkTransform transform, InputT inputs) { return applyInternal(nodeId, transform, emptyList(), inputs); } - public SdkNode apply(SdkTransform transform) { - return apply(/*nodeId=*/ null, transform, emptyMap()); + public SdkNode apply(SdkTransform transformWithoutInputs) { + return apply(/*nodeId=*/ (String) null, transformWithoutInputs); } - public SdkNode apply(SdkTransform transform, Map> inputs) { - return applyInternal(/*nodeId=*/ null, transform, emptyList(), inputs); + public SdkNode apply( + SdkTransform transform, InputT inputs) { + return apply(/*nodeId=*/ null, transform, inputs); } - protected SdkNode applyInternal( + protected SdkNode applyInternal( String nodeId, - SdkTransform transform, + SdkTransform transform, List upstreamNodeIds, - Map> inputs) { - + @Nullable InputT inputs) { String actualNodeId = Objects.requireNonNullElseGet(nodeId, sdkNodeNamePolicy::nextNodeId); if (nodes.containsKey(actualNodeId)) { @@ -96,7 +97,7 @@ protected SdkNode applyInternal( Objects.requireNonNullElseGet( nodeId, () -> sdkNodeNamePolicy.toNodeName(transform.getName())); - SdkNode sdkNode = + SdkNode sdkNode = transform .withNameOverrideIfNotSet(fallbackNodeName) .apply(this, actualNodeId, upstreamNodeIds, null, inputs); diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowNode.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowNode.java index cb38f1eb9..151f1094c 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowNode.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowNode.java @@ -71,8 +71,10 @@ public String getNodeId() { @Override public Node toIdl() { + // inputs in var order for predictability List inputBindings = this.inputBindings.entrySet().stream() + .sorted(Map.Entry.comparingByKey()) .map(x -> toBinding(x.getKey(), x.getValue())) .collect(toUnmodifiableList()); diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowRegistry.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowRegistry.java index c721d7547..ab2e3c460 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowRegistry.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowRegistry.java @@ -24,14 +24,14 @@ public abstract class SdkWorkflowRegistry { - public abstract List> getWorkflows(); + public abstract List> getWorkflows(); - public static List> loadAll() { + public static List> loadAll() { return loadAll(ServiceLoader.load(SdkWorkflowRegistry.class)); } - static List> loadAll(Iterable loader) { - List> workflows = new ArrayList<>(); + static List> loadAll(Iterable loader) { + List> workflows = new ArrayList<>(); for (SdkWorkflowRegistry registry : loader) { workflows.addAll(registry.getWorkflows()); diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowTemplateRegistrar.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowTemplateRegistrar.java index 072a3f64a..abee364e1 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowTemplateRegistrar.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkWorkflowTemplateRegistrar.java @@ -46,12 +46,12 @@ public Map load( } Map load( - SdkConfig sdkConfig, List> sdkWorkflows) { + SdkConfig sdkConfig, List> sdkWorkflows) { LOG.fine("Discovering SdkWorkflow"); Map workflows = new HashMap<>(); - for (SdkWorkflow sdkWorkflow : sdkWorkflows) { + for (SdkWorkflow sdkWorkflow : sdkWorkflows) { String name = sdkWorkflow.getName(); WorkflowIdentifier workflowId = WorkflowIdentifier.builder() diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SimpleSdkLaunchPlanRegistry.java b/flytekit-java/src/main/java/org/flyte/flytekit/SimpleSdkLaunchPlanRegistry.java index 7a7f65ab8..343b3653c 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SimpleSdkLaunchPlanRegistry.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SimpleSdkLaunchPlanRegistry.java @@ -16,9 +16,6 @@ */ package org.flyte.flytekit; -import static java.util.Collections.unmodifiableList; - -import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -37,14 +34,14 @@ public void registerLaunchPlan(SdkLaunchPlan launchPlan) { } public void registerDefaultLaunchPlans() { - List> workflows = SdkWorkflowRegistry.loadAll(); + List> workflows = SdkWorkflowRegistry.loadAll(); registerDefaultLaunchPlans(workflows); } // Visible for testing - void registerDefaultLaunchPlans(List> workflows) { - for (SdkWorkflow sdkWorkflow : workflows) { + void registerDefaultLaunchPlans(List> workflows) { + for (SdkWorkflow sdkWorkflow : workflows) { SdkLaunchPlan defaultLaunchPlan = SdkLaunchPlan.of(sdkWorkflow); registerLaunchPlan(defaultLaunchPlan); } @@ -52,6 +49,6 @@ void registerDefaultLaunchPlans(List> workflows) { @Override public List getLaunchPlans() { - return unmodifiableList(new ArrayList<>(launchPlans.values())); + return List.copyOf(launchPlans.values()); } } diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanRegistrarTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanRegistrarTest.java index d30217d98..1921f3594 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanRegistrarTest.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanRegistrarTest.java @@ -34,7 +34,6 @@ import java.time.Duration; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import org.flyte.api.v1.CronSchedule; @@ -52,15 +51,8 @@ class SdkLaunchPlanRegistrarTest { - private static final Map ENV; - - static { - HashMap env = new HashMap<>(); - env.put(PROJECT_ENV_VAR, "project"); - env.put(DOMAIN_ENV_VAR, "domain"); - env.put(VERSION_ENV_VAR, "version"); - ENV = Collections.unmodifiableMap(env); - } + private static final Map ENV = + Map.of(PROJECT_ENV_VAR, "project", DOMAIN_ENV_VAR, "domain", VERSION_ENV_VAR, "version"); private final SdkLaunchPlanRegistrar registrar = new SdkLaunchPlanRegistrar(); @@ -257,10 +249,10 @@ public List getLaunchPlans() { } } - public static class TestWorkflow extends SdkWorkflow { + public static class TestWorkflow extends SdkWorkflow { public TestWorkflow() { - super(SdkTypes.nulls()); + super(SdkTypes.nulls(), SdkTypes.nulls()); } @Override @@ -270,10 +262,10 @@ public void expand(SdkWorkflowBuilder builder) { } } - public static class OtherTestWorkflow extends SdkWorkflow { + public static class OtherTestWorkflow extends SdkWorkflow { public OtherTestWorkflow() { - super(SdkTypes.nulls()); + super(SdkTypes.nulls(), SdkTypes.nulls()); } @Override diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanTest.java index 1f1376098..00ee49fa0 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanTest.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanTest.java @@ -27,8 +27,10 @@ import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertThrows; +import com.google.auto.value.AutoValue; import java.time.Duration; import java.time.Instant; +import java.util.Map; import java.util.function.Consumer; import java.util.stream.Stream; import org.flyte.api.v1.Literal; @@ -130,6 +132,8 @@ void shouldAddDefaultInputs() { SdkLaunchPlan plan = SdkLaunchPlan.of(new TestWorkflow()) + // 😔 this is still untyped but the whole point is to be able to partially specify + // inputs .withDefaultInput("integer", 123L) .withDefaultInput("float", 1.23) .withDefaultInput("string", "123") @@ -237,10 +241,10 @@ private Parameter asParameter(Primitive primitive, SimpleType simpleType) { Literal.ofScalar(Scalar.ofPrimitive(primitive))); } - private static class TestWorkflow extends SdkWorkflow { + private static class TestWorkflow extends SdkWorkflow { private TestWorkflow() { - super(SdkTypes.nulls()); + super(new TestWorkflowInput.SdkType(), SdkTypes.nulls()); } @Override @@ -256,10 +260,119 @@ public void expand(SdkWorkflowBuilder builder) { } } - private static class NoInputsTestWorkflow extends SdkWorkflow { + @AutoValue + abstract static class TestWorkflowInput { + abstract SdkBindingData integer(); + + abstract SdkBindingData _float(); + + abstract SdkBindingData string(); + + abstract SdkBindingData _boolean(); + + abstract SdkBindingData datetime(); + + abstract SdkBindingData duration(); + + abstract SdkBindingData a(); + + abstract SdkBindingData b(); + + public static TestWorkflowInput create( + SdkBindingData integer, + SdkBindingData _float, + SdkBindingData string, + SdkBindingData _boolean, + SdkBindingData datetime, + SdkBindingData duration, + SdkBindingData a, + SdkBindingData b) { + return new AutoValue_SdkLaunchPlanTest_TestWorkflowInput( + integer, _float, string, _boolean, datetime, duration, a, b); + } + + public static class SdkType extends org.flyte.flytekit.SdkType { + + private static final String INTEGER = "integer"; + private static final String FLOAT = "_float"; + private static final String STRING = "string"; + private static final String BOOLEAN = "_boolean"; + private static final String DATETIME = "datetime"; + private static final String DURATION = "duration"; + private static final String A = "a"; + private static final String B = "b"; + + @Override + public Map toLiteralMap(TestWorkflowInput value) { + return Map.ofEntries( + Map.entry(INTEGER, Literals.ofInteger(value.integer().get())), + Map.entry(FLOAT, Literals.ofFloat(value._float().get())), + Map.entry(STRING, Literals.ofString(value.string().get())), + Map.entry(BOOLEAN, Literals.ofBoolean(value._boolean().get())), + Map.entry(DATETIME, Literals.ofDatetime(value.datetime().get())), + Map.entry(DURATION, Literals.ofDuration(value.duration().get())), + Map.entry(A, Literals.ofInteger(value.a().get())), + Map.entry(B, Literals.ofInteger(value.b().get()))); + } + + @Override + public TestWorkflowInput fromLiteralMap(Map value) { + return create( + SdkBindingData.ofInteger(value.get(INTEGER).scalar().primitive().integerValue()), + SdkBindingData.ofFloat(value.get(FLOAT).scalar().primitive().floatValue()), + SdkBindingData.ofString(value.get(STRING).scalar().primitive().stringValue()), + SdkBindingData.ofBoolean(value.get(BOOLEAN).scalar().primitive().booleanValue()), + SdkBindingData.ofDatetime(value.get(DATETIME).scalar().primitive().datetime()), + SdkBindingData.ofDuration(value.get(DURATION).scalar().primitive().duration()), + SdkBindingData.ofInteger(value.get(A).scalar().primitive().integerValue()), + SdkBindingData.ofInteger(value.get(B).scalar().primitive().integerValue())); + } + + @Override + public TestWorkflowInput promiseFor(String nodeId) { + return create( + SdkBindingData.ofOutputReference(nodeId, INTEGER, LiteralTypes.INTEGER), + SdkBindingData.ofOutputReference(nodeId, FLOAT, LiteralTypes.FLOAT), + SdkBindingData.ofOutputReference(nodeId, STRING, LiteralTypes.STRING), + SdkBindingData.ofOutputReference(nodeId, BOOLEAN, LiteralTypes.BOOLEAN), + SdkBindingData.ofOutputReference(nodeId, DATETIME, LiteralTypes.DATETIME), + SdkBindingData.ofOutputReference(nodeId, DURATION, LiteralTypes.DURATION), + SdkBindingData.ofOutputReference(nodeId, A, LiteralTypes.INTEGER), + SdkBindingData.ofOutputReference(nodeId, B, LiteralTypes.INTEGER)); + } + + @Override + public Map getVariableMap() { + return Map.ofEntries( + Map.entry(INTEGER, Variable.builder().literalType(LiteralTypes.INTEGER).build()), + Map.entry(FLOAT, Variable.builder().literalType(LiteralTypes.FLOAT).build()), + Map.entry(STRING, Variable.builder().literalType(LiteralTypes.STRING).build()), + Map.entry(BOOLEAN, Variable.builder().literalType(LiteralTypes.BOOLEAN).build()), + Map.entry(DATETIME, Variable.builder().literalType(LiteralTypes.DATETIME).build()), + Map.entry(DURATION, Variable.builder().literalType(LiteralTypes.DURATION).build()), + Map.entry(A, Variable.builder().literalType(LiteralTypes.INTEGER).build()), + Map.entry(B, Variable.builder().literalType(LiteralTypes.INTEGER).build())); + } + + @Override + public Map> toSdkBindingMap(TestWorkflowInput value) { + return Map.ofEntries( + Map.entry(INTEGER, value.integer()), + Map.entry(FLOAT, value._float()), + Map.entry(STRING, value.string()), + Map.entry(BOOLEAN, value._boolean()), + Map.entry(DATETIME, value.datetime()), + Map.entry(DURATION, value.duration()), + Map.entry(A, value.a()), + Map.entry(B, value.b())); + } + } + } + + private static class NoInputsTestWorkflow extends SdkWorkflow { private NoInputsTestWorkflow() { - super(SdkTypes.nulls()); + super(SdkTypes.nulls(), SdkTypes.nulls()); } @Override diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkRemoteLaunchPlanTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkRemoteLaunchPlanTest.java index e60ddf9b6..a9636c733 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SdkRemoteLaunchPlanTest.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkRemoteLaunchPlanTest.java @@ -24,8 +24,6 @@ import static org.mockito.Mockito.mock; import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; import org.flyte.api.v1.Binding; import org.flyte.api.v1.BindingData; import org.flyte.api.v1.Node; @@ -38,9 +36,8 @@ public class SdkRemoteLaunchPlanTest { @Test void applyShouldReturnASdkWorkflowNode() { - Map> inputs = new HashMap<>(); - inputs.put("a", SdkBindingData.ofInteger(1)); - inputs.put("b", SdkBindingData.ofInteger(2)); + var inputs = + TestPairIntegerInput.create(SdkBindingData.ofInteger(1), SdkBindingData.ofInteger(2)); SdkRemoteLaunchPlan remoteLaunchPlan = new TestSdkRemoteLaunchPlan(); diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkRemoteTaskTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkRemoteTaskTest.java index a15a24067..ab8405452 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SdkRemoteTaskTest.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkRemoteTaskTest.java @@ -24,8 +24,6 @@ import static org.mockito.Mockito.mock; import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; import org.flyte.api.v1.Binding; import org.flyte.api.v1.BindingData; import org.flyte.api.v1.Node; @@ -39,9 +37,8 @@ class SdkRemoteTaskTest { @Test void applyShouldReturnASdkTaskNode() { - Map> inputs = new HashMap<>(); - inputs.put("a", SdkBindingData.ofInteger(1)); - inputs.put("b", SdkBindingData.ofInteger(2)); + var inputs = + TestPairIntegerInput.create(SdkBindingData.ofInteger(1), SdkBindingData.ofInteger(2)); SdkRemoteTask remoteTask = new TestSdkRemoteTask(); diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkTransformTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkTransformTest.java new file mode 100644 index 000000000..ed6c1a336 --- /dev/null +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkTransformTest.java @@ -0,0 +1,166 @@ +/* + * Copyright 2021 Flyte Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekit; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.verify; + +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; +import org.flyte.api.v1.Literal; +import org.flyte.api.v1.Variable; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +class SdkTransformTest { + + @Mock private SdkNode mockResponse; + + @Test + void applyShouldPropagateCallToSubClasses() { + var transform = Mockito.spy(new TransformWithInputs()); + var builder = new SdkWorkflowBuilder(); + var nodeId = "node"; + var upstreamNodeIds = List.of("upstream-node"); + var metadata = SdkNodeMetadata.builder().name("fancy-name").build(); + var in = SdkBindingData.ofInteger(1); + var inputs = TestUnaryIntegerInput.create(in); + var inputsBindings = Map.>of("in", in); + + transform.apply(builder, nodeId, upstreamNodeIds, metadata, inputs); + + verify(transform).apply(builder, nodeId, upstreamNodeIds, metadata, inputsBindings); + } + + @Test + void applyShouldRejectCallsForNullInputsForTypesWithVariables() { + var exception = + assertThrows( + IllegalArgumentException.class, + () -> + new TransformWithInputs() + .apply( + new SdkWorkflowBuilder(), + "node", + List.of(), + null, + (TestUnaryIntegerInput) null)); + + assertThat( + exception.getMessage(), + equalTo("Null supplied as input for a transform with variables: [in]")); + } + + @Test + void applyShouldRejectCallsForNonNullInputsForTypesWithoutVariables() { + var exception = + assertThrows( + IllegalArgumentException.class, + () -> + new TransformWithoutInputs() + .apply(new SdkWorkflowBuilder(), "node", List.of(), null, "not a null value")); + + assertThat( + exception.getMessage(), equalTo("Null input expected for a transform with no variables")); + } + + private class TransformWithInputs extends SdkTransform { + + @Override + public SdkType getInputType() { + return new TestUnaryIntegerInput.SdkType(); + } + + @Override + public SdkType getOutputType() { + return SdkTypes.nulls(); + } + + @CanIgnoreReturnValue + @Override + SdkNode apply( + SdkWorkflowBuilder builder, + String nodeId, + List upstreamNodeIds, + @Nullable SdkNodeMetadata metadata, + Map> inputs) { + return mockResponse; + } + } + + private class TransformWithoutInputs extends SdkTransform { + + @Override + public SdkType getInputType() { + return new CustomNoVariableType(); + } + + @Override + public SdkType getOutputType() { + return SdkTypes.nulls(); + } + + @CanIgnoreReturnValue + @Override + SdkNode apply( + SdkWorkflowBuilder builder, + String nodeId, + List upstreamNodeIds, + @Nullable SdkNodeMetadata metadata, + Map> inputs) { + return mockResponse; + } + } + + // No rational user would write a SdkType implementation like this, but it allows us to test + // the corner case that types without variables should accept only null values + private static class CustomNoVariableType extends SdkType { + + @Override + public Map toLiteralMap(Object value) { + return Map.of(); + } + + @Override + public Object fromLiteralMap(Map value) { + return null; + } + + @Override + public Object promiseFor(String nodeId) { + return null; + } + + @Override + public Map getVariableMap() { + return Map.of(); + } + + @Override + public Map> toSdkBindingMap(Object value) { + return Map.of(); + } + } +} diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowBuilderTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowBuilderTest.java index da063ffee..827b67bbd 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowBuilderTest.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowBuilderTest.java @@ -135,7 +135,14 @@ void testTimes4WorkflowIdl() { .nodes(List.of(node0, node1)) .build(); - assertEquals(expected, builder.toIdlTemplate()); + WorkflowTemplate actual = builder.toIdlTemplate(); + assertEquals(expected.interface_(), actual.interface_()); + assertEquals(expected.metadata(), actual.metadata()); + assertEquals(expected.outputs(), actual.outputs()); + assertEquals(expected.nodes().get(0), actual.nodes().get(0)); + assertEquals(expected.nodes().get(1), actual.nodes().get(1)); + assertEquals(expected.nodes(), actual.nodes()); + assertEquals(expected, actual); } @Test @@ -225,15 +232,14 @@ void testDuplicateNodeId() { SdkBindingData a = builder.inputOfInteger("a"); SdkBindingData b = builder.inputOfInteger("b"); + TestPairIntegerInput input = TestPairIntegerInput.create(a, b); - builder.apply("node-1", new MultiplicationTask().withInput("a", a).withInput("b", b)); + builder.apply("node-1", new MultiplicationTask(), input); CompilerException e = assertThrows( CompilerException.class, - () -> - builder.apply( - "node-1", new MultiplicationTask().withInput("a", a).withInput("b", b))); + () -> builder.apply("node-1", new MultiplicationTask(), input)); assertEquals( "Failed to build workflow with errors:\n" @@ -243,14 +249,15 @@ void testDuplicateNodeId() { @ParameterizedTest @MethodSource("createTransform") - void testVariableNameNotFound_output(SdkTransform transform) { + void testVariableNameNotFound_output( + SdkTransform transform) { SdkWorkflowBuilder builder = new SdkWorkflowBuilder(); SdkBindingData a = builder.inputOfInteger("a"); SdkBindingData b = builder.inputOfInteger("b"); + TestPairIntegerInput input = TestPairIntegerInput.create(a, b); - SdkNode node1 = - builder.apply("node-1", transform.withInput("a", a).withInput("b", b)); + SdkNode node1 = builder.apply("node-1", transform, input); CompilerException e = assertThrows(CompilerException.class, () -> node1.getOutput("foo")); @@ -262,79 +269,18 @@ void testVariableNameNotFound_output(SdkTransform transf @ParameterizedTest @MethodSource("createTransform") - void testVariableNameNotFound_input(SdkTransform transform) { - SdkWorkflowBuilder builder = new SdkWorkflowBuilder(); - - SdkBindingData a = builder.inputOfInteger("a"); - SdkBindingData b = builder.inputOfInteger("b"); - SdkBindingData foo = builder.inputOfInteger("foo"); - - CompilerException e = - assertThrows( - CompilerException.class, - () -> - builder.apply( - "node-1", transform.withInput("a", a).withInput("b", b).withInput("foo", foo))); - - assertEquals( - "Failed to build workflow with errors:\n" - + "Error 0: Code: VARIABLE_NAME_NOT_FOUND, Node Id: node-1, Description: Variable [foo] not found on node [node-1].", - e.getMessage()); - } - - @ParameterizedTest - @MethodSource("createTransform") - void testParameterNotBound(SdkTransform transform) { - SdkWorkflowBuilder builder = new SdkWorkflowBuilder(); - - SdkBindingData a = builder.inputOfInteger("a"); - - CompilerException e = - assertThrows( - CompilerException.class, () -> builder.apply("node-1", transform.withInput("a", a))); - - assertEquals( - "Failed to build workflow with errors:\n" - + "Error 0: Code: PARAMETER_NOT_BOUND, Node Id: node-1, Description: Parameter not bound [b].", - e.getMessage()); - } - - @ParameterizedTest - @MethodSource("createTransform") - void tesMismatchingTypes(SdkTransform transform) { + void testUpstreamNode_withUpstreamNode( + SdkTransform transform) { SdkWorkflowBuilder builder = new SdkWorkflowBuilder(); - SdkBindingData a = builder.inputOfString("a"); - SdkBindingData b = builder.inputOfString("b"); + SdkBindingData a = builder.inputOfInteger("el0"); + SdkBindingData b = builder.inputOfInteger("el1"); + TestPairIntegerInput input = TestPairIntegerInput.create(a, b); - CompilerException e = - assertThrows( - CompilerException.class, - () -> builder.apply("node-1", transform.withInput("a", a).withInput("b", b))); - - // TODO need to implement pretty-printer for types, not it isn't super readable - - assertEquals( - "Failed to build workflow with errors:\n" - + "Error 0: Code: MISMATCHING_TYPES, Node Id: node-1, Description: Variable [a] (type [LiteralType{simpleType=STRING}]) doesn't match expected type [LiteralType{simpleType=INTEGER}].\n" - + "Error 1: Code: MISMATCHING_TYPES, Node Id: node-1, Description: Variable [b] (type [LiteralType{simpleType=STRING}]) doesn't match expected type [LiteralType{simpleType=INTEGER}].", - e.getMessage()); - } - - @ParameterizedTest - @MethodSource("createTransform") - void testUpstreamNode_withUpstreamNode(SdkTransform transform) { - SdkWorkflowBuilder builder = new SdkWorkflowBuilder(); - - SdkBindingData el0 = builder.inputOfInteger("el0"); - SdkBindingData el1 = builder.inputOfInteger("el1"); - - SdkNode el2 = - builder.apply("el2", transform.withInput("a", el0).withInput("b", el1)); + SdkNode el2 = builder.apply("el2", transform, input); SdkNode el3 = - builder.apply( - "el3", transform.withUpstreamNode(el2).withInput("a", el0).withInput("b", el1)); + builder.apply("el3", transform.withUpstreamNode(el2), input); assertEquals(singletonList("el2"), el3.toIdl().upstreamNodeIds()); } @@ -381,24 +327,24 @@ void testUpstreamNode_duplicateWithNode() { @ParameterizedTest @MethodSource("createTransform") - void testNodeMetadataOverrides(SdkTransform transform) { + void testNodeMetadataOverrides( + SdkTransform transform) { SdkWorkflowBuilder builder = new SdkWorkflowBuilder(); - SdkBindingData el0 = builder.inputOfInteger("el0"); - SdkBindingData el1 = builder.inputOfInteger("el1"); + SdkBindingData a = builder.inputOfInteger("el0"); + SdkBindingData b = builder.inputOfInteger("el1"); + TestPairIntegerInput input = TestPairIntegerInput.create(a, b); - SdkNode el2 = - builder.apply("el2", transform.withInput("a", el0).withInput("b", el1)); + SdkNode el2 = builder.apply("el2", transform, input); SdkNode el3 = builder.apply( "el3", transform .withUpstreamNode(el2) - .withInput("a", el0) - .withInput("b", el1) .withNameOverride("fancy-el3") - .withTimeoutOverride(Duration.ofMinutes(15))); + .withTimeoutOverride(Duration.ofMinutes(15)), + input); assertThat( el3.toIdl().metadata(), @@ -407,14 +353,15 @@ void testNodeMetadataOverrides(SdkTransform transform) { @ParameterizedTest @MethodSource("createTransform") - void testNodeMetadataOverrides_duplicate(SdkTransform transform) { + void testNodeMetadataOverrides_duplicate( + SdkTransform transform) { SdkWorkflowBuilder builder = new SdkWorkflowBuilder(); - SdkBindingData el0 = builder.inputOfInteger("el0"); - SdkBindingData el1 = builder.inputOfInteger("el1"); + SdkBindingData a = builder.inputOfInteger("el0"); + SdkBindingData b = builder.inputOfInteger("el1"); + TestPairIntegerInput input = TestPairIntegerInput.create(a, b); - SdkNode el2 = - builder.apply("el2", transform.withInput("a", el0).withInput("b", el1)); + SdkNode el2 = builder.apply("el2", transform, input); IllegalArgumentException ex = assertThrows( @@ -424,10 +371,9 @@ void testNodeMetadataOverrides_duplicate(SdkTransform tr "el3", transform .withUpstreamNode(el2) - .withInput("a", el0) - .withInput("b", el1) .withNameOverride("fancy-el3") - .withNameOverride("another-name"))); + .withNameOverride("another-name"), + input)); assertThat(ex.getMessage(), equalTo("Duplicate values for metadata: name")); } @@ -461,8 +407,8 @@ void testInputOf() { builder.inputOfInteger("input6")); } - static List> createTransform() { - return asList(new MultiplicationTask(), new MultiplicationWorkflow()); + static List> createTransform() { + return List.of(new MultiplicationTask(), new MultiplicationWorkflow()); } private TypedInterface expectedInterface() { @@ -490,24 +436,26 @@ private List expectedOutputs(String nodeId) { .build()); } - private static class Times4Workflow extends SdkWorkflow { + private static class Times4Workflow + extends SdkWorkflow { protected Times4Workflow() { - super(new TestUnaryIntegerOutput.SdkType()); + super(new TestUnaryIntegerInput.SdkType(), new TestUnaryIntegerOutput.SdkType()); } @Override public void expand(SdkWorkflowBuilder builder) { SdkBindingData in = builder.inputOfInteger("in", "Enter value to square"); SdkBindingData two = literalOfInteger(2L); + SdkBindingData out1 = builder - .apply(new MultiplicationTask().withInput("a", in).withInput("b", two)) + .apply(new MultiplicationTask(), TestPairIntegerInput.create(in, two)) .getOutputs() .o(); SdkBindingData out2 = builder - .apply(new MultiplicationTask().withInput("a", out1).withInput("b", two)) + .apply(new MultiplicationTask(), TestPairIntegerInput.create(out1, two)) .getOutputs() .o(); @@ -515,10 +463,11 @@ public void expand(SdkWorkflowBuilder builder) { } } - private static class ConditionalWorkflow extends SdkWorkflow { + private static class ConditionalWorkflow + extends SdkWorkflow { private ConditionalWorkflow() { - super(new TestUnaryIntegerOutput.SdkType()); + super(new TestUnaryIntegerInput.SdkType(), new TestUnaryIntegerOutput.SdkType()); } @Override @@ -532,7 +481,8 @@ public void expand(SdkWorkflowBuilder builder) { SdkConditions.when( "neq", SdkConditions.neq(in, two), - new MultiplicationTask().withInput("a", in).withInput("b", two))); + new MultiplicationTask(), + TestPairIntegerInput.create(in, two))); builder.output("o", out.getOutputs().o()); } @@ -552,10 +502,11 @@ public TestUnaryIntegerOutput run(TestPairIntegerInput input) { } } - static class MultiplicationWorkflow extends SdkWorkflow { + static class MultiplicationWorkflow + extends SdkWorkflow { MultiplicationWorkflow() { - super(new TestUnaryIntegerOutput.SdkType()); + super(new TestPairIntegerInput.SdkType(), new TestUnaryIntegerOutput.SdkType()); } @Override @@ -564,7 +515,7 @@ public void expand(SdkWorkflowBuilder builder) { SdkBindingData b = builder.inputOfInteger("b"); SdkNode multiply = - builder.apply("multiply", new MultiplicationTask().withInput("a", a).withInput("b", b)); + builder.apply("multiply", new MultiplicationTask(), TestPairIntegerInput.create(a, b)); builder.output("c", multiply.getOutputs().o()); } diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowRegistryTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowRegistryTest.java index 0b5b6576a..51ac84aba 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowRegistryTest.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowRegistryTest.java @@ -26,11 +26,11 @@ public class SdkWorkflowRegistryTest { @Test public void testLoadAll() { - SdkWorkflow workflow1 = new TestWorkflow(); - SdkWorkflow workflow2 = new TestWorkflow(); - SdkWorkflow workflow3 = new TestWorkflow(); + SdkWorkflow workflow1 = new TestWorkflow(); + SdkWorkflow workflow2 = new TestWorkflow(); + SdkWorkflow workflow3 = new TestWorkflow(); - List> workflows = + List> workflows = SdkWorkflowRegistry.loadAll( List.of( new SimpleSdkWorkflowRegistry(List.of(workflow1)), @@ -40,21 +40,21 @@ public void testLoadAll() { } static class SimpleSdkWorkflowRegistry extends SdkWorkflowRegistry { - private final List> workflows; + private final List> workflows; - public SimpleSdkWorkflowRegistry(List> workflows) { + public SimpleSdkWorkflowRegistry(List> workflows) { this.workflows = workflows; } @Override - public List> getWorkflows() { + public List> getWorkflows() { return workflows; } } - private static class TestWorkflow extends SdkWorkflow { + private static class TestWorkflow extends SdkWorkflow { private TestWorkflow() { - super(SdkTypes.nulls()); + super(SdkTypes.nulls(), SdkTypes.nulls()); } @Override diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowTemplateRegistrarTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowTemplateRegistrarTest.java index 78ccc77be..e51830cd9 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowTemplateRegistrarTest.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowTemplateRegistrarTest.java @@ -33,7 +33,7 @@ public void testLoad() { SdkConfig sdkConfig = SdkConfig.builder().domain("domain").project("project").version("version").build(); - List> sdkWorkflows = + List> sdkWorkflows = Arrays.asList(new TestWorkflow("workflow1"), new TestWorkflow("workflow2")); Map workflows = @@ -58,11 +58,11 @@ public void testLoad() { .build())); } - private static class TestWorkflow extends SdkWorkflow { + private static class TestWorkflow extends SdkWorkflow { private final String name; private TestWorkflow(String name) { - super(SdkTypes.nulls()); + super(SdkTypes.nulls(), SdkTypes.nulls()); this.name = name; } diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowWithSdkRemoteLaunchPlanTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowWithSdkRemoteLaunchPlanTest.java index 476bd3f4e..e54768d1f 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowWithSdkRemoteLaunchPlanTest.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkWorkflowWithSdkRemoteLaunchPlanTest.java @@ -109,9 +109,10 @@ private List expectedOutputs() { .build()); } - public static class WorkflowExample extends SdkWorkflow { + public static class WorkflowExample + extends SdkWorkflow { public WorkflowExample() { - super(new TestUnaryBooleanOutput.SdkType()); + super(new TestPairIntegerInput.SdkType(), new TestUnaryBooleanOutput.SdkType()); } @Override @@ -121,7 +122,7 @@ public void expand(SdkWorkflowBuilder builder) { SdkNode node1 = builder.apply( - "some-node-id", new TestSdkRemoteLaunchPlan().withInput("a", a).withInput("b", b)); + "some-node-id", new TestSdkRemoteLaunchPlan(), TestPairIntegerInput.create(a, b)); builder.output("o", node1.getOutputs().o()); } diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SimpleSdkLaunchPlanTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SimpleSdkLaunchPlanTest.java index 4f28f6390..052392dc4 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SimpleSdkLaunchPlanTest.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SimpleSdkLaunchPlanTest.java @@ -27,7 +27,7 @@ import org.junit.jupiter.api.Test; public class SimpleSdkLaunchPlanTest { - private static final SdkWorkflow WORKFLOW = new TestWorkflow(); + private static final SdkWorkflow WORKFLOW = new TestWorkflow(); private static final SdkLaunchPlan LP = SdkLaunchPlan.of(WORKFLOW).withName("lp"); private static final SdkLaunchPlan LP2 = SdkLaunchPlan.of(WORKFLOW).withName("lp2"); @@ -86,9 +86,9 @@ public TestSimpleSdkLaunchPlanRegistryWithDuplicateNames() { } } - public static class TestWorkflow extends SdkWorkflow { + public static class TestWorkflow extends SdkWorkflow { public TestWorkflow() { - super(SdkTypes.nulls()); + super(SdkTypes.nulls(), SdkTypes.nulls()); } @Override @@ -97,9 +97,9 @@ public void expand(SdkWorkflowBuilder builder) { } } - private static class OtherTestWorkflow extends SdkWorkflow { + private static class OtherTestWorkflow extends SdkWorkflow { private OtherTestWorkflow() { - super(SdkTypes.nulls()); + super(SdkTypes.nulls(), SdkTypes.nulls()); } @Override diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryIntegerInput.java b/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryIntegerInput.java new file mode 100644 index 000000000..db7bdbbe6 --- /dev/null +++ b/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryIntegerInput.java @@ -0,0 +1,63 @@ +/* + * Copyright 2021 Flyte Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekit; + +import com.google.auto.value.AutoValue; +import java.util.Map; +import org.flyte.api.v1.Literal; +import org.flyte.api.v1.LiteralType; +import org.flyte.api.v1.Variable; + +@AutoValue +abstract class TestUnaryIntegerInput { + abstract SdkBindingData in(); + + public static TestUnaryIntegerInput create(SdkBindingData in) { + return new AutoValue_TestUnaryIntegerInput(in); + } + + public static class SdkType extends org.flyte.flytekit.SdkType { + + private static final String VAR = "in"; + private static final LiteralType LITERAL_TYPE = LiteralTypes.INTEGER; + + @Override + public Map toLiteralMap(TestUnaryIntegerInput value) { + return Map.of(VAR, Literals.ofInteger(value.in().get())); + } + + @Override + public TestUnaryIntegerInput fromLiteralMap(Map value) { + return create(SdkBindingData.ofInteger(value.get(VAR).scalar().primitive().integerValue())); + } + + @Override + public TestUnaryIntegerInput promiseFor(String nodeId) { + return create(SdkBindingData.ofOutputReference(nodeId, VAR, LITERAL_TYPE)); + } + + @Override + public Map getVariableMap() { + return Map.of(VAR, Variable.builder().literalType(LITERAL_TYPE).build()); + } + + @Override + public Map> toSdkBindingMap(TestUnaryIntegerInput value) { + return Map.of(VAR, value.in()); + } + } +} diff --git a/flytekit-local-engine/pom.xml b/flytekit-local-engine/pom.xml index 33adb3969..bc46e912c 100644 --- a/flytekit-local-engine/pom.xml +++ b/flytekit-local-engine/pom.xml @@ -21,7 +21,7 @@ org.flyte flytekit-parent - 0.3.29-SNAPSHOT + 0.4.0-SNAPSHOT flytekit-local-engine diff --git a/flytekit-local-engine/src/test/java/org/flyte/localengine/LocalEngineTest.java b/flytekit-local-engine/src/test/java/org/flyte/localengine/LocalEngineTest.java index 2ff089db5..4ce27d35a 100644 --- a/flytekit-local-engine/src/test/java/org/flyte/localengine/LocalEngineTest.java +++ b/flytekit-local-engine/src/test/java/org/flyte/localengine/LocalEngineTest.java @@ -48,7 +48,6 @@ import org.flyte.api.v1.WorkflowTemplateRegistrar; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkRunnableTask; -import org.flyte.flytekit.SdkTransform; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; import org.flyte.flytekit.jackson.JacksonSdkType; @@ -549,10 +548,11 @@ static > Map loadAll( @AutoService(SdkWorkflow.class) public static class TestCaseExhaustivenessWorkflow - extends SdkWorkflow { + extends SdkWorkflow< + TestCaseExhaustivenessWorkflow.NoOpType, TestCaseExhaustivenessWorkflow.NoOpType> { public TestCaseExhaustivenessWorkflow() { - super(JacksonSdkType.of(NoOpType.class)); + super(JacksonSdkType.of(NoOpType.class), JacksonSdkType.of(NoOpType.class)); } @Override @@ -562,8 +562,8 @@ public void expand(SdkWorkflowBuilder builder) { builder .apply( "decide", - when("eq_1", eq(ofInteger(1L), x), NoOp.of(x)) - .when("eq_2", eq(ofInteger(2L), x), NoOp.of(x))) + when("eq_1", eq(ofInteger(1L), x), new NoOp(), NoOpType.create(x)) + .when("eq_2", eq(ofInteger(2L), x), new NoOp(), NoOpType.create(x))) .getOutputs() .x(); @@ -582,10 +582,6 @@ public NoOp() { public NoOpType run(NoOpType input) { return NoOpType.create(input.x()); } - - static SdkTransform of(SdkBindingData x) { - return new NoOp().withInput("x", x); - } } @AutoValue diff --git a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/CollatzConjectureStepWorkflow.java b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/CollatzConjectureStepWorkflow.java index 80859c22c..cd7cff942 100644 --- a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/CollatzConjectureStepWorkflow.java +++ b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/CollatzConjectureStepWorkflow.java @@ -30,27 +30,36 @@ // if x is even, then x/2 else 3x+1 @AutoService(SdkWorkflow.class) -public class CollatzConjectureStepWorkflow extends SdkWorkflow { +public class CollatzConjectureStepWorkflow + extends SdkWorkflow { public CollatzConjectureStepWorkflow() { - super(new TestUnaryIntegerOutput.SdkType()); + super( + JacksonSdkType.of(CollatzConjectureStepWorkflow.Input.class), + JacksonSdkType.of(TestUnaryIntegerOutput.class)); + } + + @AutoValue + abstract static class Input { + abstract SdkBindingData x(); + + public static Input create(SdkBindingData x) { + return new AutoValue_CollatzConjectureStepWorkflow_Input(x); + } } @Override public void expand(SdkWorkflowBuilder builder) { SdkBindingData x = builder.inputOfInteger("x"); SdkBindingData isOdd = - builder.apply("is_odd", new IsEvenTask().withInput("x", x)).getOutputs().res(); + builder.apply("is_odd", new IsEvenTask(), IsEvenTask.Input.create(x)).getOutputs().res(); SdkBindingData nextX = builder .apply( "decide", - when( - "was_even", - isTrue(isOdd), - new Divide().withInput("num", x).withInput("den", ofInteger(2L))) - .otherwise("was_odd", new ThreeXPlusOne().withInput("x", x))) + when("was_even", isTrue(isOdd), new Divide(), Divide.Input.create(x, ofInteger(2L))) + .otherwise("was_odd", new ThreeXPlusOne(), ThreeXPlusOne.Input.create(x))) .getOutputs() .o(); @@ -67,7 +76,7 @@ public IsEvenTask() { @Override public IsEvenTask.Output run(IsEvenTask.Input input) { - return IsEvenTask.Output.create(input.x().get() % 2 == 0); + return IsEvenTask.Output.create(SdkBindingData.ofBoolean(input.x().get() % 2 == 0)); } @AutoValue @@ -75,9 +84,8 @@ public abstract static class Input { public abstract SdkBindingData x(); - public static Input create(Long x) { - return new AutoValue_CollatzConjectureStepWorkflow_IsEvenTask_Input( - SdkBindingData.ofInteger(x)); + public static Input create(SdkBindingData x) { + return new AutoValue_CollatzConjectureStepWorkflow_IsEvenTask_Input(x); } } @@ -86,9 +94,8 @@ public abstract static class Output { public abstract SdkBindingData res(); - public static Output create(boolean res) { - return new AutoValue_CollatzConjectureStepWorkflow_IsEvenTask_Output( - SdkBindingData.ofBoolean(res)); + public static Output create(SdkBindingData res) { + return new AutoValue_CollatzConjectureStepWorkflow_IsEvenTask_Output(res); } } } @@ -98,7 +105,7 @@ public static class Divide extends SdkRunnableTask den(); - public static Input create(long num, long den) { - return new AutoValue_CollatzConjectureStepWorkflow_Divide_Input( - SdkBindingData.ofInteger(num), SdkBindingData.ofInteger(den)); + public static Input create(SdkBindingData num, SdkBindingData den) { + return new AutoValue_CollatzConjectureStepWorkflow_Divide_Input(num, den); } } @@ -124,9 +130,8 @@ public abstract static class Output { public abstract SdkBindingData res(); - public static Output create(long res) { - return new AutoValue_CollatzConjectureStepWorkflow_Divide_Output( - SdkBindingData.ofInteger(res)); + public static Output create(SdkBindingData res) { + return new AutoValue_CollatzConjectureStepWorkflow_Divide_Output(res); } } } @@ -138,7 +143,9 @@ public static class ThreeXPlusOne private static final long serialVersionUID = 932934331328064751L; public ThreeXPlusOne() { - super(JacksonSdkType.of(ThreeXPlusOne.Input.class), new TestUnaryIntegerOutput.SdkType()); + super( + JacksonSdkType.of(ThreeXPlusOne.Input.class), + JacksonSdkType.of(TestUnaryIntegerOutput.class)); } @Override @@ -150,9 +157,8 @@ public TestUnaryIntegerOutput run(ThreeXPlusOne.Input input) { public abstract static class Input { public abstract SdkBindingData x(); - public static Input create(long x) { - return new AutoValue_CollatzConjectureStepWorkflow_ThreeXPlusOne_Input( - SdkBindingData.ofInteger(x)); + public static Input create(SdkBindingData x) { + return new AutoValue_CollatzConjectureStepWorkflow_ThreeXPlusOne_Input(x); } } } diff --git a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/FibonacciWorkflow.java b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/FibonacciWorkflow.java index da9ee444f..556c3d174 100644 --- a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/FibonacciWorkflow.java +++ b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/FibonacciWorkflow.java @@ -24,10 +24,13 @@ import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkWorkflow.class) -public class FibonacciWorkflow extends SdkWorkflow { +public class FibonacciWorkflow + extends SdkWorkflow { public FibonacciWorkflow() { - super(JacksonSdkType.of(FibonacciWorkflow.Output.class)); + super( + JacksonSdkType.of(FibonacciWorkflow.Input.class), + JacksonSdkType.of(FibonacciWorkflow.Output.class)); } @Override @@ -36,33 +39,28 @@ public void expand(SdkWorkflowBuilder builder) { SdkBindingData fib1 = builder.inputOfInteger("fib1"); SdkBindingData fib2 = - builder - .apply("fib-2", new SumTask().withInput("a", fib0).withInput("b", fib1)) - .getOutputs() - .o(); + builder.apply("fib-2", new SumTask(), SumTask.Input.create(fib0, fib1)).getOutputs().o(); SdkBindingData fib3 = - builder - .apply("fib-3", new SumTask().withInput("a", fib1).withInput("b", fib2)) - .getOutputs() - .o(); + builder.apply("fib-3", new SumTask(), SumTask.Input.create(fib1, fib2)).getOutputs().o(); SdkBindingData fib4 = - builder - .apply("fib-4", new SumTask().withInput("a", fib2).withInput("b", fib3)) - .getOutputs() - .o(); + builder.apply("fib-4", new SumTask(), SumTask.Input.create(fib2, fib3)).getOutputs().o(); SdkBindingData fib5 = - builder - .apply("fib-5", new SumTask().withInput("a", fib3).withInput("b", fib4)) - .getOutputs() - .o(); + builder.apply("fib-5", new SumTask(), SumTask.Input.create(fib3, fib4)).getOutputs().o(); builder.output("fib4", fib4); builder.output("fib5", fib5); } + @AutoValue + public abstract static class Input { + public abstract SdkBindingData fib0(); + + public abstract SdkBindingData fib1(); + } + @AutoValue public abstract static class Output { public abstract SdkBindingData fib4(); diff --git a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/InnerSubWorkflow.java b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/InnerSubWorkflow.java index e68ef2fe6..3fc00c19d 100644 --- a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/InnerSubWorkflow.java +++ b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/InnerSubWorkflow.java @@ -23,9 +23,9 @@ import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkWorkflow.class) -public class InnerSubWorkflow extends SdkWorkflow { +public class InnerSubWorkflow extends SdkWorkflow { public InnerSubWorkflow() { - super(JacksonSdkType.of(TestUnaryIntegerOutput.class)); + super(JacksonSdkType.of(SumTask.Input.class), JacksonSdkType.of(TestUnaryIntegerOutput.class)); } @Override @@ -33,10 +33,7 @@ public void expand(SdkWorkflowBuilder builder) { SdkBindingData a = builder.inputOfInteger("a"); SdkBindingData b = builder.inputOfInteger("b"); SdkBindingData c = - builder - .apply("inner-sum-a-b", new SumTask().withInput("a", a).withInput("b", b)) - .getOutputs() - .o(); + builder.apply("inner-sum-a-b", new SumTask(), SumTask.Input.create(a, b)).getOutputs().o(); builder.output("o", c); } diff --git a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/ListTask.java b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/ListTask.java index ab43adfb0..1f1dd0c87 100644 --- a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/ListTask.java +++ b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/ListTask.java @@ -33,15 +33,15 @@ public ListTask() { @Override public Output run(Input input) { - return Output.create(input.list().get()); + return Output.create(SdkBindingData.ofIntegerCollection(input.list().get())); } @AutoValue public abstract static class Input { public abstract SdkBindingData> list(); - public static Input create(List list) { - return new AutoValue_ListTask_Input(SdkBindingData.ofIntegerCollection(list)); + public static Input create(SdkBindingData> list) { + return new AutoValue_ListTask_Input(list); } } @@ -49,8 +49,8 @@ public static Input create(List list) { public abstract static class Output { public abstract SdkBindingData> list(); - public static Output create(List list) { - return new AutoValue_ListTask_Output(SdkBindingData.ofIntegerCollection(list)); + public static Output create(SdkBindingData> list) { + return new AutoValue_ListTask_Output(list); } } } diff --git a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/ListWorkflow.java b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/ListWorkflow.java index 7f29050a9..44fb97f15 100644 --- a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/ListWorkflow.java +++ b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/ListWorkflow.java @@ -16,37 +16,39 @@ */ package org.flyte.localengine.examples; +import static org.flyte.flytekit.SdkBindingData.ofInteger; + import com.google.auto.service.AutoService; import java.util.List; import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.SimpleType; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkNode; +import org.flyte.flytekit.SdkTypes; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; import org.flyte.flytekit.jackson.JacksonSdkType; -import org.flyte.localengine.ImmutableList; @AutoService(SdkWorkflow.class) -public class ListWorkflow extends SdkWorkflow { +public class ListWorkflow extends SdkWorkflow { public ListWorkflow() { - super(JacksonSdkType.of(ListTask.Output.class)); + super(SdkTypes.nulls(), JacksonSdkType.of(ListTask.Output.class)); } @Override public void expand(SdkWorkflowBuilder builder) { SdkNode sum1 = - builder.apply("sum-1", new SumTask().withInput("a", 1).withInput("b", 2)); + builder.apply("sum-1", new SumTask(), SumTask.Input.create(ofInteger(1), ofInteger(2))); SdkNode sum2 = - builder.apply("sum-2", new SumTask().withInput("a", 3).withInput("b", 4)); + builder.apply("sum-2", new SumTask(), SumTask.Input.create(ofInteger(3), ofInteger(4))); SdkBindingData> list = SdkBindingData.ofBindingCollection( LiteralType.ofCollectionType(LiteralType.ofSimpleType(SimpleType.INTEGER)), - ImmutableList.of(sum1.getOutputs().o(), sum2.getOutputs().o())); + List.of(sum1.getOutputs().o(), sum2.getOutputs().o())); SdkNode list1 = - builder.apply("list-1", new ListTask().withInput("list", list)); + builder.apply("list-1", new ListTask(), ListTask.Input.create(list)); builder.output("list", list1.getOutputs().list()); } diff --git a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/MapTask.java b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/MapTask.java index 15f8d4151..47d76801f 100644 --- a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/MapTask.java +++ b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/MapTask.java @@ -40,8 +40,8 @@ public Output run(Input input) { public abstract static class Input { public abstract SdkBindingData> map(); - public static Input create(Map map) { - return new AutoValue_MapTask_Input(SdkBindingData.ofIntegerMap(map)); + public static Input create(SdkBindingData> map) { + return new AutoValue_MapTask_Input(map); } } diff --git a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/MapWorkflow.java b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/MapWorkflow.java index 3aa08d272..7642b80f0 100644 --- a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/MapWorkflow.java +++ b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/MapWorkflow.java @@ -16,6 +16,8 @@ */ package org.flyte.localengine.examples; +import static org.flyte.flytekit.SdkBindingData.ofInteger; + import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import java.util.Map; @@ -23,14 +25,15 @@ import org.flyte.api.v1.SimpleType; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkNode; +import org.flyte.flytekit.SdkTypes; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkWorkflow.class) -public class MapWorkflow extends SdkWorkflow { +public class MapWorkflow extends SdkWorkflow { public MapWorkflow() { - super(JacksonSdkType.of(MapWorkflow.Output.class)); + super(SdkTypes.nulls(), JacksonSdkType.of(MapWorkflow.Output.class)); } @AutoValue @@ -46,17 +49,23 @@ public static MapWorkflow.Output create(Map map) { @Override public void expand(SdkWorkflowBuilder builder) { SdkBindingData sum1 = - builder.apply("sum-1", new SumTask().withInput("a", 1).withInput("b", 2)).getOutputs().o(); + builder + .apply("sum-1", new SumTask(), SumTask.Input.create(ofInteger(1), ofInteger(2))) + .getOutputs() + .o(); SdkBindingData sum2 = - builder.apply("sum-2", new SumTask().withInput("a", 3).withInput("b", 4)).getOutputs().o(); + builder + .apply("sum-2", new SumTask(), SumTask.Input.create(ofInteger(3), ofInteger(4))) + .getOutputs() + .o(); SdkBindingData> map = SdkBindingData.ofBindingMap( LiteralType.ofMapValueType(LiteralType.ofSimpleType(SimpleType.INTEGER)), Map.of("e", sum1, "f", sum2)); - SdkNode map1 = builder.apply("map-1", new MapTask().withInput("map", map)); + SdkNode map1 = builder.apply("map-1", new MapTask(), MapTask.Input.create(map)); builder.output("map", map1.getOutputs().map()); } diff --git a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/NestedSubWorkflow.java b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/NestedSubWorkflow.java index bfc760ee3..d9eae6d0f 100644 --- a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/NestedSubWorkflow.java +++ b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/NestedSubWorkflow.java @@ -23,10 +23,12 @@ import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkWorkflow.class) -public class NestedSubWorkflow extends SdkWorkflow { +public class NestedSubWorkflow extends SdkWorkflow { public NestedSubWorkflow() { - super(JacksonSdkType.of(TestUnaryIntegerOutput.class)); + super( + JacksonSdkType.of(TestTuple3IntegerInput.class), + JacksonSdkType.of(TestUnaryIntegerOutput.class)); } @Override @@ -37,8 +39,7 @@ public void expand(SdkWorkflowBuilder builder) { SdkBindingData result = builder .apply( - "nested-workflow", - new OuterSubWorkflow().withInput("a", a).withInput("b", b).withInput("c", c)) + "nested-workflow", new OuterSubWorkflow(), TestTuple3IntegerInput.create(a, b, c)) .getOutputs() .o(); builder.output("o", result); diff --git a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/OuterSubWorkflow.java b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/OuterSubWorkflow.java index 55722b00f..c5eb4d544 100644 --- a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/OuterSubWorkflow.java +++ b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/OuterSubWorkflow.java @@ -21,12 +21,15 @@ import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; import org.flyte.flytekit.jackson.JacksonSdkType; +import org.flyte.localengine.examples.SumTask.Input; @AutoService(SdkWorkflow.class) -public class OuterSubWorkflow extends SdkWorkflow { +public class OuterSubWorkflow extends SdkWorkflow { public OuterSubWorkflow() { - super(JacksonSdkType.of(TestUnaryIntegerOutput.class)); + super( + JacksonSdkType.of(TestTuple3IntegerInput.class), + JacksonSdkType.of(TestUnaryIntegerOutput.class)); } @Override @@ -35,13 +38,10 @@ public void expand(SdkWorkflowBuilder builder) { SdkBindingData b = builder.inputOfInteger("b"); SdkBindingData c = builder.inputOfInteger("c"); SdkBindingData ab = - builder - .apply("outer-sum-a-b", new SumTask().withInput("a", a).withInput("b", b)) - .getOutputs() - .o(); + builder.apply("outer-sum-a-b", new SumTask(), Input.create(a, b)).getOutputs().o(); SdkBindingData res = builder - .apply("outer-sum-ab-c", new InnerSubWorkflow().withInput("a", ab).withInput("b", c)) + .apply("outer-sum-ab-c", new InnerSubWorkflow(), SumTask.Input.create(ab, c)) .getOutputs() .o(); builder.output("o", res); diff --git a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/RetryableWorkflow.java b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/RetryableWorkflow.java index 27317e5a4..9f1d4583a 100644 --- a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/RetryableWorkflow.java +++ b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/RetryableWorkflow.java @@ -22,9 +22,9 @@ import org.flyte.flytekit.SdkWorkflowBuilder; @AutoService(SdkWorkflow.class) -public class RetryableWorkflow extends SdkWorkflow { +public class RetryableWorkflow extends SdkWorkflow { public RetryableWorkflow() { - super(SdkTypes.nulls()); + super(SdkTypes.nulls(), SdkTypes.nulls()); } @Override diff --git a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/SumTask.java b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/SumTask.java index 54c614c1d..7f37f1409 100644 --- a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/SumTask.java +++ b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/SumTask.java @@ -21,29 +21,29 @@ import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkRunnableTask; import org.flyte.flytekit.jackson.JacksonSdkType; +import org.flyte.localengine.examples.SumTask.Input; @AutoService(SdkRunnableTask.class) -public class SumTask extends SdkRunnableTask { +public class SumTask extends SdkRunnableTask { private static final long serialVersionUID = -7796919693971619417L; public SumTask() { - super(JacksonSdkType.of(SumInput.class), new TestUnaryIntegerOutput.SdkType()); + super(JacksonSdkType.of(Input.class), JacksonSdkType.of(TestUnaryIntegerOutput.class)); } @AutoValue - public abstract static class SumInput { + public abstract static class Input { public abstract SdkBindingData a(); public abstract SdkBindingData b(); - public static SumInput create(long a, long b) { - return new AutoValue_SumTask_SumInput( - SdkBindingData.ofInteger(a), SdkBindingData.ofInteger(b)); + public static Input create(SdkBindingData a, SdkBindingData b) { + return new AutoValue_SumTask_Input(a, b); } } @Override - public TestUnaryIntegerOutput run(SumInput input) { + public TestUnaryIntegerOutput run(Input input) { return TestUnaryIntegerOutput.create( SdkBindingData.ofInteger(input.a().get() + input.b().get())); } diff --git a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/TestTuple3IntegerInput.java b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/TestTuple3IntegerInput.java new file mode 100644 index 000000000..f88a43601 --- /dev/null +++ b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/TestTuple3IntegerInput.java @@ -0,0 +1,34 @@ +/* + * Copyright 2021 Flyte Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.localengine.examples; + +import com.google.auto.value.AutoValue; +import org.flyte.flytekit.SdkBindingData; + +@AutoValue +public abstract class TestTuple3IntegerInput { + public abstract SdkBindingData a(); + + public abstract SdkBindingData b(); + + public abstract SdkBindingData c(); + + public static TestTuple3IntegerInput create( + SdkBindingData a, SdkBindingData b, SdkBindingData c) { + return new AutoValue_TestTuple3IntegerInput(a, b, c); + } +} diff --git a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/TestUnaryIntegerOutput.java b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/TestUnaryIntegerOutput.java index 337ff472e..cd7ed92f4 100644 --- a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/TestUnaryIntegerOutput.java +++ b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/TestUnaryIntegerOutput.java @@ -17,13 +17,6 @@ package org.flyte.localengine.examples; import com.google.auto.value.AutoValue; -import java.util.Map; -import org.flyte.api.v1.Literal; -import org.flyte.api.v1.LiteralType; -import org.flyte.api.v1.Primitive; -import org.flyte.api.v1.Scalar; -import org.flyte.api.v1.SimpleType; -import org.flyte.api.v1.Variable; import org.flyte.flytekit.SdkBindingData; @AutoValue @@ -33,36 +26,4 @@ public abstract class TestUnaryIntegerOutput { public static TestUnaryIntegerOutput create(SdkBindingData o) { return new AutoValue_TestUnaryIntegerOutput(o); } - - public static class SdkType extends org.flyte.flytekit.SdkType { - - private static final String VAR = "o"; - private static final LiteralType LITERAL_TYPE = LiteralType.ofSimpleType(SimpleType.INTEGER); - - @Override - public Map toLiteralMap(TestUnaryIntegerOutput value) { - return Map.of( - VAR, Literal.ofScalar(Scalar.ofPrimitive(Primitive.ofIntegerValue(value.o().get())))); - } - - @Override - public TestUnaryIntegerOutput fromLiteralMap(Map value) { - return create(SdkBindingData.ofInteger(value.get(VAR).scalar().primitive().integerValue())); - } - - @Override - public TestUnaryIntegerOutput promiseFor(String nodeId) { - return create(SdkBindingData.ofOutputReference(nodeId, VAR, LITERAL_TYPE)); - } - - @Override - public Map getVariableMap() { - return Map.of(VAR, Variable.builder().literalType(LITERAL_TYPE).build()); - } - - @Override - public Map> toSdkBindingMap(TestUnaryIntegerOutput value) { - return Map.of(VAR, value.o()); - } - } } diff --git a/flytekit-scala-tests/pom.xml b/flytekit-scala-tests/pom.xml index b2e1e1fdf..efb1acb9c 100644 --- a/flytekit-scala-tests/pom.xml +++ b/flytekit-scala-tests/pom.xml @@ -21,7 +21,7 @@ org.flyte flytekit-parent - 0.3.29-SNAPSHOT + 0.4.0-SNAPSHOT diff --git a/flytekit-scala_2.12/pom.xml b/flytekit-scala_2.12/pom.xml index 1bd9ed537..812d1c6f0 100644 --- a/flytekit-scala_2.12/pom.xml +++ b/flytekit-scala_2.12/pom.xml @@ -21,7 +21,7 @@ org.flyte flytekit-parent - 0.3.29-SNAPSHOT + 0.4.0-SNAPSHOT diff --git a/flytekit-scala_2.13/pom.xml b/flytekit-scala_2.13/pom.xml index f08ae58f8..2af291675 100644 --- a/flytekit-scala_2.13/pom.xml +++ b/flytekit-scala_2.13/pom.xml @@ -21,7 +21,7 @@ org.flyte flytekit-parent - 0.3.29-SNAPSHOT + 0.4.0-SNAPSHOT diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala index 82f459d72..87ad0d867 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala @@ -90,7 +90,7 @@ object SdkScalaType { param.label -> variable }.toMap - new ju.HashMap(mapAsJavaMap(scalaMap)) + ju.Map.copyOf(mapAsJavaMap(scalaMap)) } def toLiteralMap(value: T): ju.Map[String, Literal] = { @@ -98,7 +98,7 @@ object SdkScalaType { param.label -> param.typeclass.toLiteral(param.dereference(value)) }.toMap - new ju.HashMap(mapAsJavaMap(scalaMap)) + ju.Map.copyOf(mapAsJavaMap(scalaMap)) } def fromLiteralMap(literal: ju.Map[String, Literal]): T = { diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaWorkflow.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaWorkflow.scala index d5a33f592..048104f7e 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaWorkflow.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaWorkflow.scala @@ -29,8 +29,10 @@ import org.flyte.flytekit.{ import java.time.{Duration, Instant} import scala.collection.JavaConverters._ -abstract class SdkScalaWorkflow[T](outputType: SdkType[T]) - extends SdkWorkflow[T](outputType) { +abstract class SdkScalaWorkflow[InputT, OutputT]( + inputType: SdkType[InputT], + outputType: SdkType[OutputT] +) extends SdkWorkflow[InputT, OutputT](inputType, outputType) { final override def expand(builder: SdkWorkflowBuilder): Unit = { expand(new SdkScalaWorkflowBuilder(builder)) } @@ -104,26 +106,33 @@ class SdkScalaWorkflowBuilder(builder: SdkWorkflowBuilder) { def getOutputDescription(name: String): String = builder.getOutputDescription(name) - def output(name: String, value: SdkJavaBindingData[_], help: String = "") = + def output( + name: String, + value: SdkJavaBindingData[_], + help: String = "" + ): Unit = builder.output(name, value, help) def toIdlTemplate: WorkflowTemplate = builder.toIdlTemplate - def apply[T](nodeId: String, transform: SdkTransform[T]): SdkNode[T] = - builder.apply(nodeId, transform) + def apply[OutputT]( + nodeId: String, + transform: SdkTransform[Unit, OutputT] + ): SdkNode[OutputT] = + builder.apply(nodeId, transform, ()) - def apply[T]( + def apply[InputT, OutputT]( nodeId: String, - transform: SdkTransform[T], - inputs: Map[String, SdkJavaBindingData[_]] - ): SdkNode[T] = builder.apply(nodeId, transform, inputs.asJava) + transform: SdkTransform[InputT, OutputT], + inputs: InputT + ): SdkNode[OutputT] = builder.apply(nodeId, transform, inputs) - def apply[T](transform: SdkTransform[T]): SdkNode[T] = - builder.apply(transform) + def apply[OutputT](transform: SdkTransform[Unit, OutputT]): SdkNode[OutputT] = + builder.apply(transform, ()) - def apply[T]( - transform: SdkTransform[T], - inputs: Map[String, SdkJavaBindingData[_]] - ): SdkNode[T] = builder.apply(transform, inputs.asJava) + def apply[InputT, OutputT]( + transform: SdkTransform[InputT, OutputT], + inputs: InputT + ): SdkNode[OutputT] = builder.apply(transform, inputs) } diff --git a/flytekit-testing/pom.xml b/flytekit-testing/pom.xml index fb5413c1b..d15b7c151 100644 --- a/flytekit-testing/pom.xml +++ b/flytekit-testing/pom.xml @@ -21,7 +21,7 @@ org.flyte flytekit-parent - 0.3.29-SNAPSHOT + 0.4.0-SNAPSHOT flytekit-testing diff --git a/flytekit-testing/src/main/java/org/flyte/flytekit/testing/SdkTestingExecutor.java b/flytekit-testing/src/main/java/org/flyte/flytekit/testing/SdkTestingExecutor.java index 8c5e11cb1..f9c606f0b 100644 --- a/flytekit-testing/src/main/java/org/flyte/flytekit/testing/SdkTestingExecutor.java +++ b/flytekit-testing/src/main/java/org/flyte/flytekit/testing/SdkTestingExecutor.java @@ -29,6 +29,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.ServiceLoader; import java.util.function.Function; import org.flyte.api.v1.Literal; @@ -59,40 +60,42 @@ public abstract class SdkTestingExecutor { abstract Map> launchPlanTestDoubles(); - abstract SdkWorkflow workflow(); + abstract SdkWorkflow workflow(); abstract Map workflowTemplates(); @SuppressWarnings({"unchecked", "rawtypes"}) - public static SdkTestingExecutor of(SdkWorkflow workflow) { + public static SdkTestingExecutor of(SdkWorkflow workflow) { ServiceLoader> taskLoader = (ServiceLoader) ServiceLoader.load(SdkRunnableTask.class); List> tasks = new ArrayList<>(); taskLoader.iterator().forEachRemaining(tasks::add); - ServiceLoader> workflowLoader = + ServiceLoader> workflowLoader = (ServiceLoader) ServiceLoader.load(SdkWorkflow.class); - List> workflows = new ArrayList<>(); + List> workflows = new ArrayList<>(); workflowLoader.iterator().forEachRemaining(workflows::add); return SdkTestingExecutor.of(workflow, tasks, workflows); } @Deprecated - public static SdkTestingExecutor of( - SdkWorkflow workflow, List> tasks) { + public static SdkTestingExecutor of( + SdkWorkflow workflow, List> tasks) { @SuppressWarnings({"unchecked", "rawtypes"}) - ServiceLoader> workflowLoader = + ServiceLoader> workflowLoader = (ServiceLoader) ServiceLoader.load(SdkWorkflow.class); - List> workflows = new ArrayList<>(); + List> workflows = new ArrayList<>(); workflowLoader.iterator().forEachRemaining(workflows::add); return SdkTestingExecutor.of(workflow, tasks, workflows); } public static SdkTestingExecutor of( - SdkWorkflow workflow, List> tasks, List> workflows) { + SdkWorkflow workflow, + List> tasks, + List> workflows) { Map> fixedTasks = tasks.stream().collect(toMap(SdkRunnableTask::getName, TestingRunnableTask::create)); @@ -356,7 +359,7 @@ public SdkTestingExecutor withTask( } public SdkTestingExecutor withWorkflowOutput( - SdkWorkflow workflow, + SdkWorkflow workflow, SdkType inputType, InputT input, SdkType outputType, @@ -368,7 +371,8 @@ public SdkTestingExecutor withWorkflowOutput( getFixedTaskOrDefault(workflow.getName(), inputType, outputType); // replace workflow - SdkWorkflow mockWorkflow = new TestingWorkflow<>(inputType, outputType, output); + SdkWorkflow mockWorkflow = + new TestingWorkflow<>(inputType, outputType, output); return toBuilder() .putWorkflowTemplate(workflow.getName(), mockWorkflow.toIdlTemplate()) @@ -377,7 +381,9 @@ public SdkTestingExecutor withWorkflowOutput( } private static void verifyInputOutputMatchesWorkflowInterface( - SdkWorkflow workflow, SdkType inputType, SdkType outputType) { + SdkWorkflow workflow, + SdkType inputType, + SdkType outputType) { TypedInterface intf = workflow.toIdlTemplate().interface_(); verifyVariablesMatches("Input", intf.inputs(), inputType.getVariableMap()); @@ -402,11 +408,8 @@ private TestingRunnableTask getFixedTaskOrDef TestingRunnableTask fixedTask = (TestingRunnableTask) taskTestDoubles().get(name); - if (fixedTask == null) { - return TestingRunnableTask.create(name, inputType, outputType); - } else { - return fixedTask; - } + return Objects.requireNonNullElseGet( + fixedTask, () -> TestingRunnableTask.create(name, inputType, outputType)); } private @@ -416,11 +419,8 @@ TestingRunnableLaunchPlan getRunnableLaunchPlanOrDefault( TestingRunnableLaunchPlan launchPlantTestDouble = (TestingRunnableLaunchPlan) launchPlanTestDoubles().get(name); - if (launchPlantTestDouble == null) { - return TestingRunnableLaunchPlan.create(name, inputType, outputType); - } else { - return launchPlantTestDouble; - } + return Objects.requireNonNullElseGet( + launchPlantTestDouble, () -> TestingRunnableLaunchPlan.create(name, inputType, outputType)); } abstract Builder toBuilder(); @@ -437,7 +437,7 @@ abstract static class Builder { abstract Builder taskTestDoubles(Map> taskTestDoubles); - abstract Builder workflow(SdkWorkflow workflow); + abstract Builder workflow(SdkWorkflow workflow); abstract Builder launchPlanTestDoubles( Map> launchPlanTestDoubles); diff --git a/flytekit-testing/src/main/java/org/flyte/flytekit/testing/TestingWorkflow.java b/flytekit-testing/src/main/java/org/flyte/flytekit/testing/TestingWorkflow.java index 6bc3fca76..58809f592 100644 --- a/flytekit-testing/src/main/java/org/flyte/flytekit/testing/TestingWorkflow.java +++ b/flytekit-testing/src/main/java/org/flyte/flytekit/testing/TestingWorkflow.java @@ -26,7 +26,7 @@ import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; -class TestingWorkflow extends SdkWorkflow { +class TestingWorkflow extends SdkWorkflow { private final SdkType inputType; private final SdkType outputType; @@ -34,7 +34,7 @@ class TestingWorkflow extends SdkWorkflow { private final Map outputLiterals; TestingWorkflow(SdkType inputType, SdkType outputType, OutputT output) { - super(outputType); + super(inputType, outputType); this.inputType = inputType; this.outputType = outputType; this.output = output; diff --git a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/FibonacciWorkflowTest.java b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/FibonacciWorkflowTest.java index 695883c03..9ed103492 100644 --- a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/FibonacciWorkflowTest.java +++ b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/FibonacciWorkflowTest.java @@ -16,12 +16,12 @@ */ package org.flyte.flytekit.testing; +import static org.flyte.flytekit.SdkBindingData.ofInteger; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; -import org.flyte.flytekit.SdkNode; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; import org.flyte.flytekit.jackson.JacksonSdkType; @@ -53,8 +53,7 @@ public void testWithFixedInputs() { SdkTestingExecutor.of(new FibonacciWorkflow()) .withFixedInputs( JacksonSdkType.of(FibonacciWorkflowInputs.class), - FibonacciWorkflowInputs.create( - SdkBindingData.ofInteger(1), SdkBindingData.ofInteger(1))) + FibonacciWorkflowInputs.create(ofInteger(1), ofInteger(1))) .execute(); assertThat(result.getIntegerOutput("fib2"), equalTo(2L)); @@ -69,7 +68,10 @@ public void testWithTaskOutput_runnableTask() { SdkTestingExecutor.of(new FibonacciWorkflow()) .withFixedInput("fib0", 1) .withFixedInput("fib1", 1) - .withTaskOutput(new SumTask(), SumInput.create(3L, 5L), SumOutput.create(42L)) + .withTaskOutput( + new SumTask(), + SumInput.create(ofInteger(3L), ofInteger(5L)), + SumOutput.create(ofInteger(42L))) .execute(); assertThat(result.getIntegerOutput("fib2"), equalTo(2L)); @@ -85,14 +87,20 @@ public void testWithTaskOutput_remoteTask() { .withFixedInput("fib0", 1) .withFixedInput("fib1", 1) .withTaskOutput( - RemoteSumTask.create(), RemoteSumInput.create(1L, 1L), RemoteSumOutput.create(5L)) + RemoteSumTask.create(), + RemoteSumInput.create(ofInteger(1L), ofInteger(1L)), + RemoteSumOutput.create(5L)) .withTaskOutput( - RemoteSumTask.create(), RemoteSumInput.create(1L, 5L), RemoteSumOutput.create(10L)) + RemoteSumTask.create(), + RemoteSumInput.create(ofInteger(1L), ofInteger(5L)), + RemoteSumOutput.create(10L)) .withTaskOutput( - RemoteSumTask.create(), RemoteSumInput.create(5L, 10L), RemoteSumOutput.create(20L)) + RemoteSumTask.create(), + RemoteSumInput.create(ofInteger(5L), ofInteger(10L)), + RemoteSumOutput.create(20L)) .withTaskOutput( RemoteSumTask.create(), - RemoteSumInput.create(10L, 20L), + RemoteSumInput.create(ofInteger(10L), ofInteger(20L)), RemoteSumOutput.create(40L)) .execute(); @@ -108,9 +116,14 @@ public void testWithTask() { SdkTestingExecutor.of(new FibonacciWorkflow()) .withFixedInput("fib0", 1) .withFixedInput("fib1", 1) - .withTask(new SumTask(), input -> SumOutput.create(input.a().get() * input.b().get())) + .withTask( + new SumTask(), + input -> SumOutput.create(ofInteger(input.a().get() * input.b().get()))) // can combine withTask and withTaskOutput - .withTaskOutput(new SumTask(), SumInput.create(1, 1), SumOutput.create(2)) + .withTaskOutput( + new SumTask(), + SumInput.create(ofInteger(1), ofInteger(1)), + SumOutput.create(ofInteger(2))) .execute(); assertThat(result.getIntegerOutput("fib2"), equalTo(2L)); @@ -119,9 +132,12 @@ public void testWithTask() { assertThat(result.getIntegerOutput("fib5"), equalTo(8L)); } - public static class FibonacciWorkflow extends SdkWorkflow { + public static class FibonacciWorkflow + extends SdkWorkflow { public FibonacciWorkflow() { - super(JacksonSdkType.of(FibonacciWorkflowOutputs.class)); + super( + JacksonSdkType.of(FibonacciWorkflowInputs.class), + JacksonSdkType.of(FibonacciWorkflowOutputs.class)); } @Override @@ -129,31 +145,22 @@ public void expand(SdkWorkflowBuilder builder) { SdkBindingData fib0 = builder.inputOfInteger("fib0"); SdkBindingData fib1 = builder.inputOfInteger("fib1"); - SdkNode fib2 = - builder.apply("fib-2", new SumTask().withInput("a", fib0).withInput("b", fib1)); - - SdkNode fib3 = - builder.apply( - "fib-3", new SumTask().withInput("a", fib1).withInput("b", fib2.getOutput("c"))); - - SdkNode fib4 = - builder.apply( - "fib-4", - new SumTask() - .withInput("a", fib2.getOutput("c")) - .withInput("b", fib3.getOutput("c"))); - - SdkNode fib5 = - builder.apply( - "fib-5", - new SumTask() - .withInput("a", fib3.getOutput("c")) - .withInput("b", fib4.getOutput("c"))); - - builder.output("fib2", fib2.getOutput("c")); - builder.output("fib3", fib3.getOutput("c")); - builder.output("fib4", fib4.getOutput("c")); - builder.output("fib5", fib5.getOutput("c")); + SdkBindingData fib2 = + builder.apply("fib-2", new SumTask(), SumInput.create(fib0, fib1)).getOutputs().c(); + + SdkBindingData fib3 = + builder.apply("fib-3", new SumTask(), SumInput.create(fib1, fib2)).getOutputs().c(); + + SdkBindingData fib4 = + builder.apply("fib-4", new SumTask(), SumInput.create(fib2, fib3)).getOutputs().c(); + + SdkBindingData fib5 = + builder.apply("fib-5", new SumTask(), SumInput.create(fib3, fib4)).getOutputs().c(); + + builder.output("fib2", fib2); + builder.output("fib3", fib3); + builder.output("fib4", fib4); + builder.output("fib5", fib5); } } @@ -181,9 +188,12 @@ public abstract static class FibonacciWorkflowOutputs { } /** FibonacciWorkflow, but using RemoteSumTask instead. */ - public static class RemoteFibonacciWorkflow extends SdkWorkflow { + public static class RemoteFibonacciWorkflow + extends SdkWorkflow { public RemoteFibonacciWorkflow() { - super(JacksonSdkType.of(FibonacciWorkflowOutputs.class)); + super( + JacksonSdkType.of(FibonacciWorkflowInputs.class), + JacksonSdkType.of(FibonacciWorkflowOutputs.class)); } @Override @@ -191,32 +201,34 @@ public void expand(SdkWorkflowBuilder builder) { SdkBindingData fib0 = builder.inputOfInteger("fib0"); SdkBindingData fib1 = builder.inputOfInteger("fib1"); - SdkNode fib2 = - builder.apply("fib-2", RemoteSumTask.create().withInput("a", fib0).withInput("b", fib1)); - - SdkNode fib3 = - builder.apply( - "fib-3", - RemoteSumTask.create().withInput("a", fib1).withInput("b", fib2.getOutput("c"))); - - SdkNode fib4 = - builder.apply( - "fib-4", - RemoteSumTask.create() - .withInput("a", fib2.getOutput("c")) - .withInput("b", fib3.getOutput("c"))); - - SdkNode fib5 = - builder.apply( - "fib-5", - RemoteSumTask.create() - .withInput("a", fib3.getOutput("c")) - .withInput("b", fib4.getOutput("c"))); - - builder.output("fib2", fib2.getOutput("c")); - builder.output("fib3", fib3.getOutput("c")); - builder.output("fib4", fib4.getOutput("c")); - builder.output("fib5", fib5.getOutput("c")); + SdkBindingData fib2 = + builder + .apply("fib-2", RemoteSumTask.create(), RemoteSumInput.create(fib0, fib1)) + .getOutputs() + .c(); + + SdkBindingData fib3 = + builder + .apply("fib-3", RemoteSumTask.create(), RemoteSumInput.create(fib1, fib2)) + .getOutputs() + .c(); + + SdkBindingData fib4 = + builder + .apply("fib-4", RemoteSumTask.create(), RemoteSumInput.create(fib2, fib3)) + .getOutputs() + .c(); + + SdkBindingData fib5 = + builder + .apply("fib-5", RemoteSumTask.create(), RemoteSumInput.create(fib3, fib4)) + .getOutputs() + .c(); + + builder.output("fib2", fib2); + builder.output("fib3", fib3); + builder.output("fib4", fib4); + builder.output("fib5", fib5); } } } diff --git a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/IfElseWorkflowTest.java b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/IfElseWorkflowTest.java index e95b7bd78..beba7c23e 100644 --- a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/IfElseWorkflowTest.java +++ b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/IfElseWorkflowTest.java @@ -16,6 +16,7 @@ */ package org.flyte.flytekit.testing; +import static org.flyte.flytekit.SdkBindingData.ofString; import static org.flyte.flytekit.SdkConditions.eq; import static org.flyte.flytekit.SdkConditions.gt; import static org.flyte.flytekit.SdkConditions.lt; @@ -28,7 +29,6 @@ import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkCondition; import org.flyte.flytekit.SdkRunnableTask; -import org.flyte.flytekit.SdkTransform; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; import org.flyte.flytekit.jackson.JacksonSdkType; @@ -70,9 +70,12 @@ public static Stream testCases() { Arguments.of(2, 1, 4, 3, "a > b && c > d")); } - static class BranchNodeWorkflow extends SdkWorkflow { + static class BranchNodeWorkflow + extends SdkWorkflow { BranchNodeWorkflow() { - super(JacksonSdkType.of(ConstStringTask.Output.class)); + super( + JacksonSdkType.of(ConstStringTask.Input.class), + JacksonSdkType.of(ConstStringTask.Output.class)); } @Override @@ -86,21 +89,57 @@ public void expand(SdkWorkflowBuilder builder) { when( "a == b", eq(a, b), - when("c == d", eq(c, d), ConstStringTask.of("a == b && c == d")) - .when("c > d", gt(c, d), ConstStringTask.of("a == b && c > d")) - .when("c < d", lt(c, d), ConstStringTask.of("a == b && c < d"))) + when( + "c == d", + eq(c, d), + new ConstStringTask(), + ConstStringTask.Input.create(ofString("a == b && c == d"))) + .when( + "c > d", + gt(c, d), + new ConstStringTask(), + ConstStringTask.Input.create(ofString("a == b && c > d"))) + .when( + "c < d", + lt(c, d), + new ConstStringTask(), + ConstStringTask.Input.create(ofString("a == b && c < d")))) .when( "a < b", lt(a, b), - when("c == d", eq(c, d), ConstStringTask.of("a < b && c == d")) - .when("c > d", gt(c, d), ConstStringTask.of("a < b && c > d")) - .when("c < d", lt(c, d), ConstStringTask.of("a < b && c < d"))) + when( + "c == d", + eq(c, d), + new ConstStringTask(), + ConstStringTask.Input.create(ofString("a < b && c == d"))) + .when( + "c > d", + gt(c, d), + new ConstStringTask(), + ConstStringTask.Input.create(ofString("a < b && c > d"))) + .when( + "c < d", + lt(c, d), + new ConstStringTask(), + ConstStringTask.Input.create(ofString("a < b && c < d")))) .when( "a > b", gt(a, b), - when("c == d", eq(c, d), ConstStringTask.of("a > b && c == d")) - .when("c > d", gt(c, d), ConstStringTask.of("a > b && c > d")) - .when("c < d", lt(c, d), ConstStringTask.of("a > b && c < d"))); + when( + "c == d", + eq(c, d), + new ConstStringTask(), + ConstStringTask.Input.create(ofString("a > b && c == d"))) + .when( + "c > d", + gt(c, d), + new ConstStringTask(), + ConstStringTask.Input.create(ofString("a > b && c > d"))) + .when( + "c < d", + lt(c, d), + new ConstStringTask(), + ConstStringTask.Input.create(ofString("a > b && c < d")))); SdkBindingData value = builder.apply("condition", condition).getOutputs().value(); @@ -115,6 +154,10 @@ static class ConstStringTask @AutoValue abstract static class Input { abstract SdkBindingData value(); + + public static Input create(SdkBindingData value) { + return new AutoValue_IfElseWorkflowTest_ConstStringTask_Input(value); + } } @AutoValue @@ -122,8 +165,7 @@ abstract static class Output { abstract SdkBindingData value(); public static Output create(String value) { - return new AutoValue_IfElseWorkflowTest_ConstStringTask_Output( - SdkBindingData.ofString(value)); + return new AutoValue_IfElseWorkflowTest_ConstStringTask_Output(ofString(value)); } } @@ -131,10 +173,6 @@ public ConstStringTask() { super(JacksonSdkType.of(Input.class), JacksonSdkType.of(Output.class)); } - public static SdkTransform of(String value) { - return new ConstStringTask().withInput("value", SdkBindingData.ofString(value)); - } - @Override public Output run(Input input) { return Output.create(input.value().get()); diff --git a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/RemoteSumTask.java b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/RemoteSumTask.java index 778503326..5479210aa 100644 --- a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/RemoteSumTask.java +++ b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/RemoteSumTask.java @@ -38,9 +38,8 @@ public abstract static class RemoteSumInput { public abstract SdkBindingData b(); - public static RemoteSumInput create(long a, long b) { - return new AutoValue_RemoteSumTask_RemoteSumInput( - SdkBindingData.ofInteger(a), SdkBindingData.ofInteger(b)); + public static RemoteSumInput create(SdkBindingData a, SdkBindingData b) { + return new AutoValue_RemoteSumTask_RemoteSumInput(a, b); } } diff --git a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/SdkTestingExecutorTest.java b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/SdkTestingExecutorTest.java index 598aa3607..0420af2fb 100644 --- a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/SdkTestingExecutorTest.java +++ b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/SdkTestingExecutorTest.java @@ -37,7 +37,7 @@ public class SdkTestingExecutorTest { @AutoValue - public abstract static class TestWorkflowOutput { + public abstract static class TestWorkflowIO { public abstract SdkBindingData b(); @@ -50,32 +50,41 @@ public abstract static class TestWorkflowOutput { public abstract SdkBindingData i(); public abstract SdkBindingData s(); + + public static TestWorkflowIO create( + SdkBindingData b, + SdkBindingData datetime, + SdkBindingData duration, + SdkBindingData f, + SdkBindingData i, + SdkBindingData s) { + return new AutoValue_SdkTestingExecutorTest_TestWorkflowIO(b, datetime, duration, f, i, s); + } } @AutoValue - public abstract static class TestUnaryStringOutput { + public abstract static class TestUnaryStringIO { public abstract SdkBindingData string(); - public static TestUnaryStringOutput create(String string) { - return new AutoValue_SdkTestingExecutorTest_TestUnaryStringOutput( - SdkBindingData.ofString(string)); + public static TestUnaryStringIO create(SdkBindingData string) { + return new AutoValue_SdkTestingExecutorTest_TestUnaryStringIO(string); } } @AutoValue - public abstract static class TestUnaryIntegerOutput { - public abstract SdkBindingData o(); + public abstract static class TestUnaryIntegerIO { + public abstract SdkBindingData integer(); - public static TestUnaryIntegerOutput create(long l) { - return new AutoValue_SdkTestingExecutorTest_TestUnaryIntegerOutput( - SdkBindingData.ofInteger(l)); + public static TestUnaryIntegerIO create(SdkBindingData integer) { + return new AutoValue_SdkTestingExecutorTest_TestUnaryIntegerIO(integer); } } @Test public void testPrimitiveTypes() { - SdkWorkflow workflow = - new SdkWorkflow<>(JacksonSdkType.of(TestWorkflowOutput.class)) { + SdkWorkflow workflow = + new SdkWorkflow<>( + JacksonSdkType.of(TestWorkflowIO.class), JacksonSdkType.of(TestWorkflowIO.class)) { @Override public void expand(SdkWorkflowBuilder builder) { builder.output("b", builder.inputOfBoolean("b")); @@ -107,8 +116,8 @@ public void expand(SdkWorkflowBuilder builder) { @Test public void testGetOutput_doesntExist() { - SdkWorkflow workflow = - new SdkWorkflow<>(SdkTypes.nulls()) { + SdkWorkflow workflow = + new SdkWorkflow<>(JacksonSdkType.of(TestUnaryIntegerIO.class), SdkTypes.nulls()) { @Override public void expand(SdkWorkflowBuilder builder) { builder.output("integer", builder.inputOfInteger("integer")); @@ -126,8 +135,10 @@ public void expand(SdkWorkflowBuilder builder) { @Test public void testGetOutput_illegalType() { - SdkWorkflow workflow = - new SdkWorkflow<>(JacksonSdkType.of(TestUnaryStringOutput.class)) { + SdkWorkflow workflow = + new SdkWorkflow<>( + JacksonSdkType.of(TestUnaryStringIO.class), + JacksonSdkType.of(TestUnaryStringIO.class)) { @Override public void expand(SdkWorkflowBuilder builder) { builder.output("string", builder.inputOfString("string")); @@ -147,8 +158,10 @@ public void expand(SdkWorkflowBuilder builder) { @Test public void testWithFixedInput_missing() { - SdkWorkflow workflow = - new SdkWorkflow<>(JacksonSdkType.of(TestUnaryStringOutput.class)) { + SdkWorkflow workflow = + new SdkWorkflow<>( + JacksonSdkType.of(TestUnaryStringIO.class), + JacksonSdkType.of(TestUnaryStringIO.class)) { @Override public void expand(SdkWorkflowBuilder builder) { builder.output("string", builder.inputOfString("string")); @@ -167,8 +180,10 @@ public void expand(SdkWorkflowBuilder builder) { @Test public void testWithFixedInput_illegalType() { - SdkWorkflow workflow = - new SdkWorkflow<>(JacksonSdkType.of(TestUnaryStringOutput.class)) { + SdkWorkflow workflow = + new SdkWorkflow<>( + JacksonSdkType.of(TestUnaryStringIO.class), + JacksonSdkType.of(TestUnaryStringIO.class)) { @Override public void expand(SdkWorkflowBuilder builder) { builder.output("string", builder.inputOfString("string")); @@ -187,11 +202,14 @@ public void expand(SdkWorkflowBuilder builder) { @Test public void testWithTask_missingRemoteTask() { - SdkWorkflow workflow = - new SdkWorkflow<>(SdkTypes.nulls()) { + SdkWorkflow workflow = + new SdkWorkflow<>(SdkTypes.nulls(), SdkTypes.nulls()) { @Override public void expand(SdkWorkflowBuilder builder) { - builder.apply("sum", RemoteSumTask.create().withInput("a", 1L).withInput("b", 2L)); + builder.apply( + "sum", + RemoteSumTask.create(), + RemoteSumInput.create(SdkBindingData.ofInteger(1L), SdkBindingData.ofInteger(2L))); } }; @@ -208,11 +226,14 @@ public void expand(SdkWorkflowBuilder builder) { @Test public void testWithTask_missingRemoteTaskOutput() { - SdkWorkflow workflow = - new SdkWorkflow<>(SdkTypes.nulls()) { + SdkWorkflow workflow = + new SdkWorkflow<>(SdkTypes.nulls(), SdkTypes.nulls()) { @Override public void expand(SdkWorkflowBuilder builder) { - builder.apply("sum", RemoteSumTask.create().withInput("a", 1L).withInput("b", 2L)); + builder.apply( + "sum", + RemoteSumTask.create(), + RemoteSumInput.create(SdkBindingData.ofInteger(1L), SdkBindingData.ofInteger(2L))); } }; @@ -223,7 +244,8 @@ public void expand(SdkWorkflowBuilder builder) { SdkTestingExecutor.of(workflow) .withTaskOutput( RemoteSumTask.create(), - RemoteSumInput.create(10L, 20L), + RemoteSumInput.create( + SdkBindingData.ofInteger(10L), SdkBindingData.ofInteger(20L)), RemoteSumOutput.create(30L)) .execute()); @@ -235,11 +257,12 @@ public void expand(SdkWorkflowBuilder builder) { @Test public void testWithTask_nullOutput() { - SdkWorkflow workflow = - new SdkWorkflow<>(SdkTypes.nulls()) { + SdkWorkflow workflow = + new SdkWorkflow<>(SdkTypes.nulls(), SdkTypes.nulls()) { @Override public void expand(SdkWorkflowBuilder builder) { - builder.apply("void", RemoteVoidOutputTask.create().withInput("ignore", "")); + builder.apply( + "void", RemoteVoidOutputTask.create(), RemoteVoidOutputTask.Input.create("")); } }; @@ -256,16 +279,16 @@ public void expand(SdkWorkflowBuilder builder) { public void withWorkflowOutput_successfullyMocksWhenTypeMatches() { SdkTestingExecutor.Result result = SdkTestingExecutor.of(new SimpleUberWorkflow()) - .withFixedInput("n", 7) + .withFixedInput("integer", 7) .withWorkflowOutput( new SimpleSubWorkflow(), - JacksonSdkType.of(SimpleSubWorkflowInput.class), - SimpleSubWorkflowInput.create(SdkBindingData.ofInteger(7)), - JacksonSdkType.of(TestUnaryIntegerOutput.class), - TestUnaryIntegerOutput.create(5)) + JacksonSdkType.of(TestUnaryIntegerIO.class), + TestUnaryIntegerIO.create(SdkBindingData.ofInteger(7)), + JacksonSdkType.of(TestUnaryIntegerIO.class), + TestUnaryIntegerIO.create(SdkBindingData.ofInteger(5))) .execute(); - assertThat(result.getIntegerOutput("o"), equalTo(5L)); + assertThat(result.getIntegerOutput("integer"), equalTo(5L)); } @Test @@ -278,17 +301,17 @@ public void testWithLaunchPlanOutput() { JacksonSdkType.of(SumLaunchPlanInput.class), JacksonSdkType.of(SumLaunchPlanOutput.class)); - SdkWorkflow workflow = - new SdkWorkflow<>(JacksonSdkType.of(TestUnaryIntegerOutput.class)) { + SdkWorkflow workflow = + new SdkWorkflow<>( + JacksonSdkType.of(SumLaunchPlanInput.class), + JacksonSdkType.of(TestUnaryIntegerIO.class)) { @Override public void expand(SdkWorkflowBuilder builder) { + SdkBindingData a = builder.inputOfInteger("a"); + SdkBindingData b = builder.inputOfInteger("b"); SdkBindingData c = builder - .apply( - "launchplanref", - launchplanRef - .withInput("a", builder.inputOfInteger("a")) - .withInput("b", builder.inputOfInteger("b"))) + .apply("launchplanref", launchplanRef, SumLaunchPlanInput.create(a, b)) .getOutputs() .c(); @@ -301,7 +324,10 @@ public void expand(SdkWorkflowBuilder builder) { .withFixedInput("a", 3L) .withFixedInput("b", 5L) .withLaunchPlanOutput( - launchplanRef, SumLaunchPlanInput.create(3L, 5L), SumLaunchPlanOutput.create(8L)) + launchplanRef, + SumLaunchPlanInput.create( + SdkBindingData.ofInteger(3L), SdkBindingData.ofInteger(5L)), + SumLaunchPlanOutput.create(SdkBindingData.ofInteger(8L))) .execute(); assertThat(result.getIntegerOutput("o"), equalTo(8L)); @@ -317,21 +343,21 @@ public void testWithLaunchPlanOutput_isMissing() { JacksonSdkType.of(SumLaunchPlanInput.class), JacksonSdkType.of(SumLaunchPlanOutput.class)); - SdkWorkflow workflow = - new SdkWorkflow<>(JacksonSdkType.of(TestUnaryIntegerOutput.class)) { + SdkWorkflow workflow = + new SdkWorkflow<>( + JacksonSdkType.of(SumLaunchPlanInput.class), + JacksonSdkType.of(TestUnaryIntegerIO.class)) { @Override public void expand(SdkWorkflowBuilder builder) { + SdkBindingData a = builder.inputOfInteger("a"); + SdkBindingData b = builder.inputOfInteger("b"); SdkBindingData c = builder - .apply( - "launchplanref", - launchplanRef - .withInput("a", builder.inputOfInteger("a")) - .withInput("b", builder.inputOfInteger("b"))) + .apply("launchplanref", launchplanRef, SumLaunchPlanInput.create(a, b)) .getOutputs() .c(); - builder.output("o", c); + builder.output("integer", c); } }; @@ -345,8 +371,9 @@ public void expand(SdkWorkflowBuilder builder) { .withLaunchPlanOutput( launchplanRef, // The stub values won't be matched, so exception iis throws - SumLaunchPlanInput.create(100000L, 100000L), - SumLaunchPlanOutput.create(8L)) + SumLaunchPlanInput.create( + SdkBindingData.ofInteger(100000L), SdkBindingData.ofInteger(100000L)), + SumLaunchPlanOutput.create(SdkBindingData.ofInteger(8L))) .execute()); assertThat( @@ -365,21 +392,21 @@ public void testWithLaunchPlan() { JacksonSdkType.of(SumLaunchPlanInput.class), JacksonSdkType.of(SumLaunchPlanOutput.class)); - SdkWorkflow workflow = - new SdkWorkflow<>(JacksonSdkType.of(TestUnaryIntegerOutput.class)) { + SdkWorkflow workflow = + new SdkWorkflow<>( + JacksonSdkType.of(SumLaunchPlanInput.class), + JacksonSdkType.of(TestUnaryIntegerIO.class)) { @Override public void expand(SdkWorkflowBuilder builder) { + SdkBindingData a = builder.inputOfInteger("a"); + SdkBindingData b = builder.inputOfInteger("b"); SdkBindingData c = builder - .apply( - "launchplanref", - launchplanRef - .withInput("a", builder.inputOfInteger("a")) - .withInput("b", builder.inputOfInteger("b"))) + .apply("launchplanref", launchplanRef, SumLaunchPlanInput.create(a, b)) .getOutputs() .c(); - builder.output("o", c); + builder.output("integer", c); } }; @@ -388,19 +415,25 @@ public void expand(SdkWorkflowBuilder builder) { .withFixedInput("a", 30L) .withFixedInput("b", 5L) .withLaunchPlan( - launchplanRef, in -> SumLaunchPlanOutput.create(in.a().get() + in.b().get())) + launchplanRef, + in -> + SumLaunchPlanOutput.create( + SdkBindingData.ofInteger(in.a().get() + in.b().get()))) .execute(); - assertThat(result.getIntegerOutput("o"), equalTo(35L)); + assertThat(result.getIntegerOutput("integer"), equalTo(35L)); } @Test public void testWithLaunchPlan_missingRemoteTaskOutput() { - SdkWorkflow workflow = - new SdkWorkflow<>(SdkTypes.nulls()) { + SdkWorkflow workflow = + new SdkWorkflow<>(SdkTypes.nulls(), SdkTypes.nulls()) { @Override public void expand(SdkWorkflowBuilder builder) { - builder.apply("sum", RemoteSumTask.create().withInput("a", 1L).withInput("b", 2L)); + builder.apply( + "sum", + RemoteSumTask.create(), + RemoteSumInput.create(SdkBindingData.ofInteger(1L), SdkBindingData.ofInteger(2L))); } }; @@ -411,7 +444,8 @@ public void expand(SdkWorkflowBuilder builder) { SdkTestingExecutor.of(workflow) .withTaskOutput( RemoteSumTask.create(), - RemoteSumInput.create(10L, 20L), + RemoteSumInput.create( + SdkBindingData.ofInteger(10L), SdkBindingData.ofInteger(20L)), RemoteSumOutput.create(30L)) .execute()); @@ -421,39 +455,37 @@ public void expand(SdkWorkflowBuilder builder) { "Can't find input RemoteSumInput{a=SdkBindingData{idl=BindingData{scalar=Scalar{primitive=Primitive{integerValue=1}}}, type=LiteralType{simpleType=INTEGER}, value=1}, b=SdkBindingData{idl=BindingData{scalar=Scalar{primitive=Primitive{integerValue=2}}}, type=LiteralType{simpleType=INTEGER}, value=2}} for remote task [remote_sum_task] across known task inputs, use SdkTestingExecutor#withTaskOutput or SdkTestingExecutor#withTask to provide a test double")); } - public static class SimpleUberWorkflow extends SdkWorkflow { + public static class SimpleUberWorkflow + extends SdkWorkflow { public SimpleUberWorkflow() { - super(JacksonSdkType.of(TestUnaryIntegerOutput.class)); + super( + JacksonSdkType.of(TestUnaryIntegerIO.class), JacksonSdkType.of(TestUnaryIntegerIO.class)); } @Override public void expand(SdkWorkflowBuilder builder) { - SdkBindingData input = builder.inputOfInteger("n", ""); + SdkBindingData input = builder.inputOfInteger("integer", ""); SdkBindingData output = - builder.apply("void", new SimpleSubWorkflow().withInput("in", input)).getOutputs().o(); - builder.output("o", output); + builder + .apply("void", new SimpleSubWorkflow(), TestUnaryIntegerIO.create(input)) + .getOutputs() + .integer(); + builder.output("integer", output); } } - public static class SimpleSubWorkflow extends SdkWorkflow { + public static class SimpleSubWorkflow + extends SdkWorkflow { public SimpleSubWorkflow() { - super(JacksonSdkType.of(TestUnaryIntegerOutput.class)); + super( + JacksonSdkType.of(TestUnaryIntegerIO.class), JacksonSdkType.of(TestUnaryIntegerIO.class)); } @Override public void expand(SdkWorkflowBuilder builder) { - builder.output("o", builder.inputOfInteger("in")); - } - } - - @AutoValue - abstract static class SimpleSubWorkflowInput { - abstract SdkBindingData in(); - - public static SimpleSubWorkflowInput create(SdkBindingData in) { - return new AutoValue_SdkTestingExecutorTest_SimpleSubWorkflowInput(in); + builder.output("integer", builder.inputOfInteger("integer")); } } @@ -461,9 +493,8 @@ public static SimpleSubWorkflowInput create(SdkBindingData in) { abstract static class SimpleSubWorkflowOutput { abstract SdkBindingData out(); - public static SimpleSubWorkflowOutput create(long out) { - return new AutoValue_SdkTestingExecutorTest_SimpleSubWorkflowOutput( - SdkBindingData.ofInteger(out)); + public static SimpleSubWorkflowOutput create(SdkBindingData out) { + return new AutoValue_SdkTestingExecutorTest_SimpleSubWorkflowOutput(out); } } @@ -473,9 +504,8 @@ abstract static class SumLaunchPlanInput { abstract SdkBindingData b(); - public static SumLaunchPlanInput create(long a, long b) { - return new AutoValue_SdkTestingExecutorTest_SumLaunchPlanInput( - SdkBindingData.ofInteger(a), SdkBindingData.ofInteger(b)); + public static SumLaunchPlanInput create(SdkBindingData a, SdkBindingData b) { + return new AutoValue_SdkTestingExecutorTest_SumLaunchPlanInput(a, b); } } @@ -483,8 +513,8 @@ public static SumLaunchPlanInput create(long a, long b) { abstract static class SumLaunchPlanOutput { abstract SdkBindingData c(); - public static SumLaunchPlanOutput create(long c) { - return new AutoValue_SdkTestingExecutorTest_SumLaunchPlanOutput(SdkBindingData.ofInteger(c)); + public static SumLaunchPlanOutput create(SdkBindingData c) { + return new AutoValue_SdkTestingExecutorTest_SumLaunchPlanOutput(c); } } } diff --git a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/SumTask.java b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/SumTask.java index 8d3638942..f302c65bf 100644 --- a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/SumTask.java +++ b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/SumTask.java @@ -36,9 +36,8 @@ public abstract static class SumInput { public abstract SdkBindingData b(); - public static SumInput create(long a, long b) { - return new AutoValue_SumTask_SumInput( - SdkBindingData.ofInteger(a), SdkBindingData.ofInteger(b)); + public static SumInput create(SdkBindingData a, SdkBindingData b) { + return new AutoValue_SumTask_SumInput(a, b); } } @@ -46,13 +45,13 @@ public static SumInput create(long a, long b) { public abstract static class SumOutput { public abstract SdkBindingData c(); - public static SumOutput create(long c) { - return new AutoValue_SumTask_SumOutput(SdkBindingData.ofInteger(c)); + public static SumOutput create(SdkBindingData c) { + return new AutoValue_SumTask_SumOutput(c); } } @Override public SumOutput run(SumInput input) { - return SumOutput.create(input.a().get() + input.b().get()); + return SumOutput.create(SdkBindingData.ofInteger(input.a().get() + input.b().get())); } } diff --git a/integration-tests/pom.xml b/integration-tests/pom.xml index 023487c91..954c2a1c2 100644 --- a/integration-tests/pom.xml +++ b/integration-tests/pom.xml @@ -21,7 +21,7 @@ org.flyte flytekit-parent - 0.3.29-SNAPSHOT + 0.4.0-SNAPSHOT integration-tests diff --git a/integration-tests/src/main/java/org/flyte/integrationtests/BranchNodeWorkflow.java b/integration-tests/src/main/java/org/flyte/integrationtests/BranchNodeWorkflow.java index bcb3559ee..678758245 100644 --- a/integration-tests/src/main/java/org/flyte/integrationtests/BranchNodeWorkflow.java +++ b/integration-tests/src/main/java/org/flyte/integrationtests/BranchNodeWorkflow.java @@ -22,6 +22,7 @@ import static org.flyte.flytekit.SdkConditions.when; import com.google.auto.service.AutoService; +import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkCondition; import org.flyte.flytekit.SdkWorkflow; @@ -29,9 +30,32 @@ import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkWorkflow.class) -public class BranchNodeWorkflow extends SdkWorkflow { +public class BranchNodeWorkflow + extends SdkWorkflow { + + @AutoValue + abstract static class Input { + abstract SdkBindingData a(); + + abstract SdkBindingData b(); + + abstract SdkBindingData c(); + + abstract SdkBindingData d(); + + public static BranchNodeWorkflow.Input create( + SdkBindingData a, + SdkBindingData b, + SdkBindingData c, + SdkBindingData d) { + return new AutoValue_BranchNodeWorkflow_Input(a, b, c, d); + } + } + public BranchNodeWorkflow() { - super(JacksonSdkType.of(ConstStringTask.Output.class)); + super( + JacksonSdkType.of(BranchNodeWorkflow.Input.class), + JacksonSdkType.of(ConstStringTask.Output.class)); } @Override @@ -45,21 +69,57 @@ public void expand(SdkWorkflowBuilder builder) { when( "a-equal-b", eq(a, b), - when("c-equal-d", eq(c, d), ConstStringTask.of("a == b && c == d")) - .when("c-greater-d", gt(c, d), ConstStringTask.of("a == b && c > d")) - .when("c-less-d", lt(c, d), ConstStringTask.of("a == b && c < d"))) + when( + "c-equal-d", + eq(c, d), + new ConstStringTask(), + ConstStringTask.Input.create(SdkBindingData.ofString("a == b && c == d"))) + .when( + "c-greater-d", + gt(c, d), + new ConstStringTask(), + ConstStringTask.Input.create(SdkBindingData.ofString("a == b && c > d"))) + .when( + "c-less-d", + lt(c, d), + new ConstStringTask(), + ConstStringTask.Input.create(SdkBindingData.ofString("a == b && c < d")))) .when( "a-less-b", lt(a, b), - when("c-equal-d", eq(c, d), ConstStringTask.of("a < b && c == d")) - .when("c-greater-d", gt(c, d), ConstStringTask.of("a < b && c > d")) - .when("c-less-d", lt(c, d), ConstStringTask.of("a < b && c < d"))) + when( + "c-equal-d", + eq(c, d), + new ConstStringTask(), + ConstStringTask.Input.create(SdkBindingData.ofString("a < b && c == d"))) + .when( + "c-greater-d", + gt(c, d), + new ConstStringTask(), + ConstStringTask.Input.create(SdkBindingData.ofString("a < b && c > d"))) + .when( + "c-less-d", + lt(c, d), + new ConstStringTask(), + ConstStringTask.Input.create(SdkBindingData.ofString("a < b && c < d")))) .when( "a-greater-b", gt(a, b), - when("c-equal-d", eq(c, d), ConstStringTask.of("a > b && c == d")) - .when("c-greater-d", gt(c, d), ConstStringTask.of("a > b && c > d")) - .when("c-less-d", lt(c, d), ConstStringTask.of("a > b && c < d"))); + when( + "c-equal-d", + eq(c, d), + new ConstStringTask(), + ConstStringTask.Input.create(SdkBindingData.ofString("a > b && c == d"))) + .when( + "c-greater-d", + gt(c, d), + new ConstStringTask(), + ConstStringTask.Input.create(SdkBindingData.ofString("a > b && c > d"))) + .when( + "c-less-d", + lt(c, d), + new ConstStringTask(), + ConstStringTask.Input.create(SdkBindingData.ofString("a > b && c < d")))); SdkBindingData value = builder.apply("condition", condition).getOutputs().value(); diff --git a/integration-tests/src/main/java/org/flyte/integrationtests/ConstStringTask.java b/integration-tests/src/main/java/org/flyte/integrationtests/ConstStringTask.java index 24714f2a2..915371450 100644 --- a/integration-tests/src/main/java/org/flyte/integrationtests/ConstStringTask.java +++ b/integration-tests/src/main/java/org/flyte/integrationtests/ConstStringTask.java @@ -20,7 +20,6 @@ import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkRunnableTask; -import org.flyte.flytekit.SdkTransform; import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkRunnableTask.class) @@ -31,14 +30,18 @@ public class ConstStringTask @AutoValue abstract static class Input { abstract SdkBindingData value(); + + public static Input create(SdkBindingData value) { + return new AutoValue_ConstStringTask_Input(value); + } } @AutoValue abstract static class Output { abstract SdkBindingData value(); - public static Output create(String value) { - return new AutoValue_ConstStringTask_Output(SdkBindingData.ofString(value)); + public static Output create(SdkBindingData value) { + return new AutoValue_ConstStringTask_Output(value); } } @@ -46,12 +49,8 @@ public ConstStringTask() { super(JacksonSdkType.of(Input.class), JacksonSdkType.of(Output.class)); } - public static SdkTransform of(String value) { - return new ConstStringTask().withInput("value", SdkBindingData.ofString(value)); - } - @Override public Output run(Input input) { - return Output.create(input.value().get()); + return Output.create(input.value()); } } diff --git a/integration-tests/src/main/java/org/flyte/integrationtests/structs/BuildBqReference.java b/integration-tests/src/main/java/org/flyte/integrationtests/structs/BuildBqReference.java index 9f9d098b1..06f92e65c 100644 --- a/integration-tests/src/main/java/org/flyte/integrationtests/structs/BuildBqReference.java +++ b/integration-tests/src/main/java/org/flyte/integrationtests/structs/BuildBqReference.java @@ -47,11 +47,11 @@ public abstract static class Input { abstract SdkBindingData tableName(); - public static Input create(String project, String dataset, String tableName) { - return new AutoValue_BuildBqReference_Input( - SdkBindingData.ofString(project), - SdkBindingData.ofString(dataset), - SdkBindingData.ofString(tableName)); + public static Input create( + SdkBindingData project, + SdkBindingData dataset, + SdkBindingData tableName) { + return new AutoValue_BuildBqReference_Input(project, dataset, tableName); } } diff --git a/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockLookupBqTask.java b/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockLookupBqTask.java index b7f35b658..27c2747a3 100644 --- a/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockLookupBqTask.java +++ b/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockLookupBqTask.java @@ -37,7 +37,8 @@ public abstract static class Input { public abstract SdkBindingData checkIfExists(); - public static Input create(BQReference ref, boolean checkIfExists) { + public static Input create( + SdkBindingData ref, SdkBindingData checkIfExists) { return null; // TODO } } diff --git a/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockPipelineWorkflow.java b/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockPipelineWorkflow.java index 9abef2f62..0eb6f734c 100644 --- a/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockPipelineWorkflow.java +++ b/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockPipelineWorkflow.java @@ -19,17 +19,22 @@ import static org.flyte.flytekit.SdkBindingData.ofBoolean; import static org.flyte.flytekit.SdkBindingData.ofString; -import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; import org.flyte.flytekit.jackson.JacksonSdkType; -@AutoService(SdkWorkflow.class) -public class MockPipelineWorkflow extends SdkWorkflow { +// This workflow relays on SdkBinding that should be serialized +// as Struct. By going to typed inputs and outputs, we have de-scoped the support +// of structs. +// @AutoService(SdkWorkflow.class) +public class MockPipelineWorkflow + extends SdkWorkflow { public MockPipelineWorkflow() { - super(JacksonSdkType.of(MockPipelineWorkflow.Output.class)); + super( + JacksonSdkType.of(MockPipelineWorkflow.Input.class), + JacksonSdkType.of(MockPipelineWorkflow.Output.class)); } @Override @@ -39,30 +44,37 @@ public void expand(SdkWorkflowBuilder builder) { builder .apply( "build-ref", - new BuildBqReference() - .withInput("project", ofString("styx-1265")) - .withInput("dataset", ofString("styx-insights")) - .withInput("tableName", tableName)) + new BuildBqReference(), + BuildBqReference.Input.create( + ofString("styx-1265"), ofString("styx-insights"), tableName)) .getOutputs() .ref(); SdkBindingData exists = builder .apply( "lookup", - new MockLookupBqTask() - .withInput("ref", ref) - .withInput("checkIfExists", ofBoolean(true))) + new MockLookupBqTask(), + MockLookupBqTask.Input.create(ref, ofBoolean(true))) .getOutputs() .exists(); builder.output("exists", exists); } + @AutoValue + public abstract static class Input { + public abstract SdkBindingData tableName(); + + public static Input create(SdkBindingData tableName) { + return new AutoValue_MockPipelineWorkflow_Input(tableName); + } + } + @AutoValue public abstract static class Output { public abstract SdkBindingData exists(); - public static Output create(Boolean exists) { - return new AutoValue_MockPipelineWorkflow_Output(SdkBindingData.ofBoolean(exists)); + public static Output create(SdkBindingData exists) { + return new AutoValue_MockPipelineWorkflow_Output(exists); } } } diff --git a/jflyte-api/pom.xml b/jflyte-api/pom.xml index 5e332b14a..a502febfd 100644 --- a/jflyte-api/pom.xml +++ b/jflyte-api/pom.xml @@ -21,7 +21,7 @@ org.flyte flytekit-parent - 0.3.29-SNAPSHOT + 0.4.0-SNAPSHOT jflyte-api diff --git a/jflyte-aws/pom.xml b/jflyte-aws/pom.xml index 2350fc954..3b55b6f27 100644 --- a/jflyte-aws/pom.xml +++ b/jflyte-aws/pom.xml @@ -21,7 +21,7 @@ org.flyte flytekit-parent - 0.3.29-SNAPSHOT + 0.4.0-SNAPSHOT jflyte-aws diff --git a/jflyte-google-cloud/pom.xml b/jflyte-google-cloud/pom.xml index 3ca6eee72..4acb72867 100644 --- a/jflyte-google-cloud/pom.xml +++ b/jflyte-google-cloud/pom.xml @@ -21,7 +21,7 @@ org.flyte flytekit-parent - 0.3.29-SNAPSHOT + 0.4.0-SNAPSHOT jflyte-google-cloud diff --git a/jflyte/pom.xml b/jflyte/pom.xml index 52d463e1a..190f07ed3 100644 --- a/jflyte/pom.xml +++ b/jflyte/pom.xml @@ -21,7 +21,7 @@ org.flyte flytekit-parent - 0.3.29-SNAPSHOT + 0.4.0-SNAPSHOT jflyte diff --git a/pom.xml b/pom.xml index 80e52e960..4ad45e804 100644 --- a/pom.xml +++ b/pom.xml @@ -20,7 +20,7 @@ org.flyte flytekit-parent - 0.3.29-SNAPSHOT + 0.4.0-SNAPSHOT pom