Skip to content

Commit

Permalink
Add typed transform inputs (#175)
Browse files Browse the repository at this point in the history
* Add InputT to SdkTransform (1)

... flytekit-java compiles but tests fails

Signed-off-by: Nelson Arapé <[email protected]>

* Seed SdkType implementations

Signed-off-by: Nelson Arapé <[email protected]>

* Fix examples

Signed-off-by: Andres Gomez Ferrer <[email protected]>

* Seed SdkType implementations

Signed-off-by: Nelson Arapé <[email protected]>

* Fix conditionals and more

Signed-off-by: Nelson Arapé <[email protected]>

* Local engine classes

Signed-off-by: Nelson Arapé <[email protected]>

* Examples in scala

Signed-off-by: Nelson Arapé <[email protected]>

* testing

Signed-off-by: Nelson Arapé <[email protected]>

* Compile works

Signed-off-by: Andres Gomez Ferrer <[email protected]>

* Fix spotless

Signed-off-by: Andres Gomez Ferrer <[email protected]>

* Fix integration test

Signed-off-by: Nelson Arapé <[email protected]>

* Only convert InputT to binding in SdkTransform

Signed-off-by: Nelson Arapé <[email protected]>

* Move check to SdkTransform

Signed-off-by: Nelson Arapé <[email protected]>

* Address feedback

Signed-off-by: Nelson Arapé <[email protected]>

* Bump minor version to denote breaking change

Signed-off-by: Nelson Arapé <[email protected]>

Signed-off-by: Nelson Arapé <[email protected]>
Signed-off-by: Andres Gomez Ferrer <[email protected]>
Co-authored-by: Andres Gomez Ferrer <[email protected]>
  • Loading branch information
narape and andresgomezfrr authored Jan 25, 2023
1 parent 57ad454 commit 2a40a16
Show file tree
Hide file tree
Showing 107 changed files with 1,479 additions and 911 deletions.
2 changes: 1 addition & 1 deletion flyteidl-protos/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
<parent>
<groupId>org.flyte</groupId>
<artifactId>flytekit-parent</artifactId>
<version>0.3.29-SNAPSHOT</version>
<version>0.4.0-SNAPSHOT</version>
</parent>

<artifactId>flyteidl-protos</artifactId>
Expand Down
2 changes: 1 addition & 1 deletion flytekit-api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
<parent>
<groupId>org.flyte</groupId>
<artifactId>flytekit-parent</artifactId>
<version>0.3.29-SNAPSHOT</version>
<version>0.4.0-SNAPSHOT</version>
</parent>

<artifactId>flytekit-api</artifactId>
Expand Down
2 changes: 1 addition & 1 deletion flytekit-examples-scala/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
<parent>
<groupId>org.flyte</groupId>
<artifactId>flytekit-parent</artifactId>
<version>0.3.29-SNAPSHOT</version>
<version>0.4.0-SNAPSHOT</version>
</parent>

<artifactId>flytekit-examples-scala</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,3 @@ class AddQuestionTask
override def run(input: AddQuestionTaskInput): AddQuestionTaskOutput =
AddQuestionTaskOutput(ofString(s"${input.greeting.get} How are you?"))
}

object AddQuestionTask {

/** Binds input data to this task
*
* @param greeting
* the input greeting message
* @return
* a transformed instance of this class with input data
*/
def apply(
greeting: SdkBindingData[String]
): SdkTransform[AddQuestionTaskOutput] =
new AddQuestionTask().withInput("greeting", greeting)
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,14 @@ import org.flyte.flytekitscala.{
SdkScalaWorkflowBuilder
}

case class DynamicFibonacciWorkflowInput(n: SdkBindingData[Long])
case class DynamicFibonacciWorkflowOutput(output: SdkBindingData[Long])
class DynamicFibonacciWorkflow
extends SdkScalaWorkflow[DynamicFibonacciWorkflowOutput](
extends SdkScalaWorkflow[
DynamicFibonacciWorkflowInput,
DynamicFibonacciWorkflowOutput
](
SdkScalaType[DynamicFibonacciWorkflowInput],
SdkScalaType[DynamicFibonacciWorkflowOutput]
) {

Expand All @@ -34,7 +39,8 @@ class DynamicFibonacciWorkflow

val fibonacci = builder.apply(
"fibonacci",
new DynamicFibonacciWorkflowTask().withInput("n", n)
new DynamicFibonacciWorkflowTask(),
DynamicFibonacciWorkflowTaskInput(n)
)

builder.output("output", fibonacci.getOutputs.output)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ class DynamicFibonacciWorkflowTask
else
fib(
n + 1,
builder(s"fib-${n + 1}", SumTask(value, prev)).getOutputs.c,
builder(
s"fib-${n + 1}",
new SumTask(),
SumTaskInput(value, prev)
).getOutputs.c,
value
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,45 @@
*/
package org.flyte.examples.flytekitscala

import org.flyte.flytekit.{SdkBindingData, SdkWorkflow, SdkWorkflowBuilder}
import org.flyte.flytekit.SdkBindingData
import org.flyte.flytekitscala.{
SdkScalaWorkflowBuilder,
SdkScalaType,
SdkScalaWorkflow
}

case class FibonacciWorkflowInput(
fib0: SdkBindingData[Long],
fib1: SdkBindingData[Long]
)
case class FibonacciWorkflowOutput(fib5: SdkBindingData[Long])

class FibonacciWorkflow
extends SdkScalaWorkflow[FibonacciWorkflowOutput](
extends SdkScalaWorkflow[FibonacciWorkflowInput, FibonacciWorkflowOutput](
SdkScalaType[FibonacciWorkflowInput],
SdkScalaType[FibonacciWorkflowOutput]
) {

override def expand(builder: SdkScalaWorkflowBuilder): Unit = {
val fib0 = builder.inputOfInteger("fib0", "Value for Fib0")
val fib1 = builder.inputOfInteger("fib1", "Value for Fib1")

val fib2 = builder.apply("fib-2", SumTask(fib0, fib1)).getOutputs.c
val fib3 = builder.apply("fib-3", SumTask(fib1, fib2)).getOutputs.c
val fib4 = builder.apply("fib-4", SumTask(fib2, fib3)).getOutputs.c
val fib5 = builder.apply("fib-5", SumTask(fib3, fib4)).getOutputs.c
val fib2 = builder
.apply("fib-2", new SumTask(), SumTaskInput(fib0, fib1))
.getOutputs
.c
val fib3 = builder
.apply("fib-3", new SumTask(), SumTaskInput(fib1, fib2))
.getOutputs
.c
val fib4 = builder
.apply("fib-4", new SumTask(), SumTaskInput(fib2, fib3))
.getOutputs
.c
val fib5 = builder
.apply("fib-5", new SumTask(), SumTaskInput(fib3, fib4))
.getOutputs
.c

builder.output("fib5", fib5, "Value for Fib5")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,3 @@ class GreetTask
override def run(input: GreetTaskInput): GreetTaskOutput =
GreetTaskOutput(ofString(s"Welcome, ${input.name.get()}!"))
}

object GreetTask {

/** Binds input data to this task
*
* @param name
* the input name
* @return
* a transformed instance of this class with input data
*/
def apply(name: SdkBindingData[String]): SdkTransform[GreetTaskOutput] =
new GreetTask().withInput("name", name)
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,3 @@ class SumTask

override def isCacheSerializable: Boolean = true
}

object SumTask {
def apply(
a: SdkBindingData[Long],
b: SdkBindingData[Long]
): SdkTransform[SumTaskOutput] =
new SumTask().withInput("a", a).withInput("b", b)
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
package org.flyte.examples.flytekitscala

import org.flyte.flytekit.{SdkBindingData, SdkWorkflow, SdkWorkflowBuilder}
import org.flyte.flytekit.SdkBindingData
import org.flyte.flytekitscala.{
SdkScalaType,
SdkScalaWorkflow,
Expand Down Expand Up @@ -49,10 +49,12 @@ import org.flyte.flytekitscala.{
* | output: greeting(string) |
*/

case class WelcomeWorkflowInput(name: SdkBindingData[String])
case class WelcomeWorkflowOutput(greeting: SdkBindingData[String])

class WelcomeWorkflow
extends SdkScalaWorkflow[WelcomeWorkflowOutput](
extends SdkScalaWorkflow[WelcomeWorkflowInput, WelcomeWorkflowOutput](
SdkScalaType[WelcomeWorkflowInput],
SdkScalaType[WelcomeWorkflowOutput]
) {

Expand All @@ -61,11 +63,18 @@ class WelcomeWorkflow
val name = builder.inputOfString("name", "The name for the welcome message")

// uses the workflow input as the task input of the GreetTask
val greeting = builder.apply("greet", GreetTask(name)).getOutputs.greeting
val greeting = builder
.apply("greet", new GreetTask(), GreetTaskInput(name))
.getOutputs
.greeting

// uses the output of the GreetTask as the task input of the AddQuestionTask
val greetingWithQuestion = builder
.apply("add-question", AddQuestionTask(greeting))
.apply(
"add-question",
new AddQuestionTask(),
AddQuestionTaskInput(greeting)
)
.getOutputs
.greeting

Expand Down
2 changes: 1 addition & 1 deletion flytekit-examples/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
<parent>
<groupId>org.flyte</groupId>
<artifactId>flytekit-parent</artifactId>
<version>0.3.29-SNAPSHOT</version>
<version>0.4.0-SNAPSHOT</version>
</parent>

<artifactId>flytekit-examples</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import com.google.auto.value.AutoValue;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkRunnableTask;
import org.flyte.flytekit.SdkTransform;
import org.flyte.flytekit.jackson.JacksonSdkType;

/**
Expand All @@ -34,23 +33,17 @@ public AddQuestionTask() {
super(JacksonSdkType.of(Input.class), JacksonSdkType.of(Output.class));
}

/**
* Binds input data to this task.
*
* @param greeting the input greeting message
* @return a transformed instance of this class with input data
*/
public static SdkTransform<AddQuestionTask.Output> of(SdkBindingData<String> greeting) {
return new AddQuestionTask().withInput("greeting", greeting);
}

/**
* Generate an immutable value class that represents {@link AddQuestionTask}'s input, which is a
* String.
*/
@AutoValue
public abstract static class Input {
public abstract SdkBindingData<String> greeting();

public static Input create(SdkBindingData<String> greeting) {
return new AutoValue_AddQuestionTask_Input(greeting);
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
*/
package org.flyte.examples;

import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;

import com.google.auto.service.AutoService;
import com.google.auto.value.AutoValue;
import java.time.Duration;
Expand All @@ -27,7 +24,6 @@
import java.util.Map;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkRunnableTask;
import org.flyte.flytekit.SdkTransform;
import org.flyte.flytekit.jackson.JacksonSdkType;

@AutoService(SdkRunnableTask.class)
Expand All @@ -38,30 +34,6 @@ public AllInputsTask() {
JacksonSdkType.of(AutoAllInputsInput.class), JacksonSdkType.of(AutoAllInputsOutput.class));
}

public static SdkTransform<AllInputsTask.AutoAllInputsOutput> of(
SdkBindingData<Long> i,
SdkBindingData<Double> f,
SdkBindingData<String> s,
SdkBindingData<Boolean> b,
SdkBindingData<Instant> t,
SdkBindingData<Duration> d,
SdkBindingData<List<String>> l,
SdkBindingData<Map<String, String>> m,
SdkBindingData<List<String>> emptyList,
SdkBindingData<Map<String, Long>> emptyMap) {
return new AllInputsTask()
.withInput("i", i)
.withInput("f", f)
.withInput("s", s)
.withInput("b", b)
.withInput("t", t)
.withInput("d", d)
.withInput("l", l)
.withInput("m", m)
.withInput("emptyList", emptyList)
.withInput("emptyMap", emptyMap);
}

@AutoValue
public abstract static class AutoAllInputsInput {
public abstract SdkBindingData<Long> i();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,17 @@
import org.flyte.examples.AllInputsTask.AutoAllInputsOutput;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkNode;
import org.flyte.flytekit.SdkTypes;
import org.flyte.flytekit.SdkWorkflow;
import org.flyte.flytekit.SdkWorkflowBuilder;
import org.flyte.flytekit.jackson.JacksonSdkType;

@AutoService(SdkWorkflow.class)
public class AllInputsWorkflow extends SdkWorkflow<AllInputsWorkflow.AllInputsWorkflowOutput> {
public class AllInputsWorkflow
extends SdkWorkflow<Void, AllInputsWorkflow.AllInputsWorkflowOutput> {

public AllInputsWorkflow() {
super(JacksonSdkType.of(AllInputsWorkflow.AllInputsWorkflowOutput.class));
super(SdkTypes.nulls(), JacksonSdkType.of(AllInputsWorkflow.AllInputsWorkflowOutput.class));
}

@Override
Expand All @@ -46,7 +48,8 @@ public void expand(SdkWorkflowBuilder builder) {
SdkNode<AutoAllInputsOutput> apply =
builder.apply(
"all-inputs",
AllInputsTask.of(
new AllInputsTask(),
AllInputsTask.AutoAllInputsInput.create(
SdkBindingData.ofInteger(1L),
SdkBindingData.ofFloat(2),
SdkBindingData.ofString("test"),
Expand All @@ -57,6 +60,7 @@ public void expand(SdkWorkflowBuilder builder) {
SdkBindingData.ofStringMap(Map.of("test", "test")),
SdkBindingData.ofStringCollection(Collections.emptyList()),
SdkBindingData.ofIntegerMap(Collections.emptyMap())));

AllInputsTask.AutoAllInputsOutput outputs = apply.getOutputs();

builder.output("i", outputs.i(), "Integer value");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ public abstract static class Input {
public abstract SdkBindingData<Map<String, String>> keyValues();

public abstract SdkBindingData<List<String>> searchKeys();

public static Input create(
SdkBindingData<Map<String, String>> keyValues, SdkBindingData<List<String>> searchKeys) {
return new AutoValue_BatchLookUpTask_Input(keyValues, searchKeys);
}
}

@AutoValue
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
import org.flyte.flytekit.jackson.JacksonSdkType;

@AutoService(SdkWorkflow.class)
public class ConditionalGreetingWorkflow extends SdkWorkflow<GreetTask.Output> {
public class ConditionalGreetingWorkflow extends SdkWorkflow<GreetTask.Input, GreetTask.Output> {
public ConditionalGreetingWorkflow() {
super(JacksonSdkType.of(GreetTask.Output.class));
super(JacksonSdkType.of(GreetTask.Input.class), JacksonSdkType.of(GreetTask.Output.class));
}

@Override
Expand All @@ -40,8 +40,11 @@ public void expand(SdkWorkflowBuilder builder) {
.apply(
"decide",
SdkConditions.when(
"when-empty", eq(name, ofString("")), GreetTask.of(ofString("World")))
.otherwise("when-not-empty", GreetTask.of(name)))
"when-empty",
eq(name, ofString("")),
new GreetTask(),
GreetTask.Input.create(ofString("World")))
.otherwise("when-not-empty", new GreetTask(), GreetTask.Input.create(name)))
.getOutputs()
.greeting();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@

/** Example workflow that takes a name and outputs a welcome message. */
@AutoService(SdkWorkflow.class)
public class ContainerWorkflow extends SdkWorkflow<Void> {
public class ContainerWorkflow extends SdkWorkflow<Void, Void> {

public ContainerWorkflow() {
super(SdkTypes.nulls());
super(SdkTypes.nulls(), SdkTypes.nulls());
}

@Override
Expand Down
Loading

0 comments on commit 2a40a16

Please sign in to comment.