Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add typed transform inputs #175

Merged
merged 16 commits into from
Jan 25, 2023
Merged
2 changes: 1 addition & 1 deletion flyteidl-protos/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
<parent>
<groupId>org.flyte</groupId>
<artifactId>flytekit-parent</artifactId>
<version>0.3.29-SNAPSHOT</version>
<version>0.4.0-SNAPSHOT</version>
</parent>

<artifactId>flyteidl-protos</artifactId>
Expand Down
2 changes: 1 addition & 1 deletion flytekit-api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
<parent>
<groupId>org.flyte</groupId>
<artifactId>flytekit-parent</artifactId>
<version>0.3.29-SNAPSHOT</version>
<version>0.4.0-SNAPSHOT</version>
</parent>

<artifactId>flytekit-api</artifactId>
Expand Down
2 changes: 1 addition & 1 deletion flytekit-examples-scala/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
<parent>
<groupId>org.flyte</groupId>
<artifactId>flytekit-parent</artifactId>
<version>0.3.29-SNAPSHOT</version>
<version>0.4.0-SNAPSHOT</version>
</parent>

<artifactId>flytekit-examples-scala</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,3 @@ class AddQuestionTask
override def run(input: AddQuestionTaskInput): AddQuestionTaskOutput =
AddQuestionTaskOutput(ofString(s"${input.greeting.get} How are you?"))
}

object AddQuestionTask {

/** Binds input data to this task
*
* @param greeting
* the input greeting message
* @return
* a transformed instance of this class with input data
*/
def apply(
andresgomezfrr marked this conversation as resolved.
Show resolved Hide resolved
greeting: SdkBindingData[String]
): SdkTransform[AddQuestionTaskOutput] =
new AddQuestionTask().withInput("greeting", greeting)
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,14 @@ import org.flyte.flytekitscala.{
SdkScalaWorkflowBuilder
}

case class DynamicFibonacciWorkflowInput(n: SdkBindingData[Long])
case class DynamicFibonacciWorkflowOutput(output: SdkBindingData[Long])
class DynamicFibonacciWorkflow
extends SdkScalaWorkflow[DynamicFibonacciWorkflowOutput](
extends SdkScalaWorkflow[
DynamicFibonacciWorkflowInput,
DynamicFibonacciWorkflowOutput
](
SdkScalaType[DynamicFibonacciWorkflowInput],
SdkScalaType[DynamicFibonacciWorkflowOutput]
) {

Expand All @@ -34,7 +39,8 @@ class DynamicFibonacciWorkflow

val fibonacci = builder.apply(
"fibonacci",
new DynamicFibonacciWorkflowTask().withInput("n", n)
new DynamicFibonacciWorkflowTask(),
DynamicFibonacciWorkflowTaskInput(n)
)

builder.output("output", fibonacci.getOutputs.output)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ class DynamicFibonacciWorkflowTask
else
fib(
n + 1,
builder(s"fib-${n + 1}", SumTask(value, prev)).getOutputs.c,
builder(
s"fib-${n + 1}",
new SumTask(),
SumTaskInput(value, prev)
).getOutputs.c,
value
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,45 @@
*/
package org.flyte.examples.flytekitscala

import org.flyte.flytekit.{SdkBindingData, SdkWorkflow, SdkWorkflowBuilder}
import org.flyte.flytekit.SdkBindingData
import org.flyte.flytekitscala.{
SdkScalaWorkflowBuilder,
SdkScalaType,
SdkScalaWorkflow
}

case class FibonacciWorkflowInput(
fib0: SdkBindingData[Long],
fib1: SdkBindingData[Long]
)
case class FibonacciWorkflowOutput(fib5: SdkBindingData[Long])

class FibonacciWorkflow
extends SdkScalaWorkflow[FibonacciWorkflowOutput](
extends SdkScalaWorkflow[FibonacciWorkflowInput, FibonacciWorkflowOutput](
SdkScalaType[FibonacciWorkflowInput],
SdkScalaType[FibonacciWorkflowOutput]
) {

override def expand(builder: SdkScalaWorkflowBuilder): Unit = {
val fib0 = builder.inputOfInteger("fib0", "Value for Fib0")
val fib1 = builder.inputOfInteger("fib1", "Value for Fib1")

val fib2 = builder.apply("fib-2", SumTask(fib0, fib1)).getOutputs.c
val fib3 = builder.apply("fib-3", SumTask(fib1, fib2)).getOutputs.c
val fib4 = builder.apply("fib-4", SumTask(fib2, fib3)).getOutputs.c
val fib5 = builder.apply("fib-5", SumTask(fib3, fib4)).getOutputs.c
val fib2 = builder
.apply("fib-2", new SumTask(), SumTaskInput(fib0, fib1))
.getOutputs
.c
val fib3 = builder
.apply("fib-3", new SumTask(), SumTaskInput(fib1, fib2))
.getOutputs
.c
val fib4 = builder
.apply("fib-4", new SumTask(), SumTaskInput(fib2, fib3))
.getOutputs
.c
val fib5 = builder
.apply("fib-5", new SumTask(), SumTaskInput(fib3, fib4))
.getOutputs
.c

builder.output("fib5", fib5, "Value for Fib5")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,3 @@ class GreetTask
override def run(input: GreetTaskInput): GreetTaskOutput =
GreetTaskOutput(ofString(s"Welcome, ${input.name.get()}!"))
}

object GreetTask {

/** Binds input data to this task
*
* @param name
* the input name
* @return
* a transformed instance of this class with input data
*/
def apply(name: SdkBindingData[String]): SdkTransform[GreetTaskOutput] =
new GreetTask().withInput("name", name)
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,3 @@ class SumTask

override def isCacheSerializable: Boolean = true
}

object SumTask {
def apply(
a: SdkBindingData[Long],
b: SdkBindingData[Long]
): SdkTransform[SumTaskOutput] =
new SumTask().withInput("a", a).withInput("b", b)
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
package org.flyte.examples.flytekitscala

import org.flyte.flytekit.{SdkBindingData, SdkWorkflow, SdkWorkflowBuilder}
import org.flyte.flytekit.SdkBindingData
import org.flyte.flytekitscala.{
SdkScalaType,
SdkScalaWorkflow,
Expand Down Expand Up @@ -49,10 +49,12 @@ import org.flyte.flytekitscala.{
* | output: greeting(string) |
*/

case class WelcomeWorkflowInput(name: SdkBindingData[String])
case class WelcomeWorkflowOutput(greeting: SdkBindingData[String])

class WelcomeWorkflow
extends SdkScalaWorkflow[WelcomeWorkflowOutput](
extends SdkScalaWorkflow[WelcomeWorkflowInput, WelcomeWorkflowOutput](
SdkScalaType[WelcomeWorkflowInput],
SdkScalaType[WelcomeWorkflowOutput]
) {

Expand All @@ -61,11 +63,18 @@ class WelcomeWorkflow
val name = builder.inputOfString("name", "The name for the welcome message")

// uses the workflow input as the task input of the GreetTask
val greeting = builder.apply("greet", GreetTask(name)).getOutputs.greeting
val greeting = builder
.apply("greet", new GreetTask(), GreetTaskInput(name))
.getOutputs
.greeting

// uses the output of the GreetTask as the task input of the AddQuestionTask
val greetingWithQuestion = builder
.apply("add-question", AddQuestionTask(greeting))
.apply(
"add-question",
new AddQuestionTask(),
AddQuestionTaskInput(greeting)
)
.getOutputs
.greeting

Expand Down
2 changes: 1 addition & 1 deletion flytekit-examples/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
<parent>
<groupId>org.flyte</groupId>
<artifactId>flytekit-parent</artifactId>
<version>0.3.29-SNAPSHOT</version>
<version>0.4.0-SNAPSHOT</version>
</parent>

<artifactId>flytekit-examples</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import com.google.auto.value.AutoValue;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkRunnableTask;
import org.flyte.flytekit.SdkTransform;
import org.flyte.flytekit.jackson.JacksonSdkType;

/**
Expand All @@ -34,23 +33,17 @@ public AddQuestionTask() {
super(JacksonSdkType.of(Input.class), JacksonSdkType.of(Output.class));
}

/**
* Binds input data to this task.
*
* @param greeting the input greeting message
* @return a transformed instance of this class with input data
*/
public static SdkTransform<AddQuestionTask.Output> of(SdkBindingData<String> greeting) {
return new AddQuestionTask().withInput("greeting", greeting);
}

/**
* Generate an immutable value class that represents {@link AddQuestionTask}'s input, which is a
* String.
*/
@AutoValue
public abstract static class Input {
public abstract SdkBindingData<String> greeting();

public static Input create(SdkBindingData<String> greeting) {
Copy link
Contributor Author

@narape narape Jan 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to reviewers

Before we were only using the AutoValue to define the type of the task, but now we actually create the auto values to pass it in the apply

return new AutoValue_AddQuestionTask_Input(greeting);
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
*/
package org.flyte.examples;

import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;

import com.google.auto.service.AutoService;
import com.google.auto.value.AutoValue;
import java.time.Duration;
Expand All @@ -27,7 +24,6 @@
import java.util.Map;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkRunnableTask;
import org.flyte.flytekit.SdkTransform;
import org.flyte.flytekit.jackson.JacksonSdkType;

@AutoService(SdkRunnableTask.class)
Expand All @@ -38,30 +34,6 @@ public AllInputsTask() {
JacksonSdkType.of(AutoAllInputsInput.class), JacksonSdkType.of(AutoAllInputsOutput.class));
}

public static SdkTransform<AllInputsTask.AutoAllInputsOutput> of(
SdkBindingData<Long> i,
SdkBindingData<Double> f,
SdkBindingData<String> s,
SdkBindingData<Boolean> b,
SdkBindingData<Instant> t,
SdkBindingData<Duration> d,
SdkBindingData<List<String>> l,
SdkBindingData<Map<String, String>> m,
SdkBindingData<List<String>> emptyList,
SdkBindingData<Map<String, Long>> emptyMap) {
return new AllInputsTask()
.withInput("i", i)
.withInput("f", f)
.withInput("s", s)
.withInput("b", b)
.withInput("t", t)
.withInput("d", d)
.withInput("l", l)
.withInput("m", m)
.withInput("emptyList", emptyList)
.withInput("emptyMap", emptyMap);
}

@AutoValue
public abstract static class AutoAllInputsInput {
public abstract SdkBindingData<Long> i();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,17 @@
import org.flyte.examples.AllInputsTask.AutoAllInputsOutput;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkNode;
import org.flyte.flytekit.SdkTypes;
import org.flyte.flytekit.SdkWorkflow;
import org.flyte.flytekit.SdkWorkflowBuilder;
import org.flyte.flytekit.jackson.JacksonSdkType;

@AutoService(SdkWorkflow.class)
public class AllInputsWorkflow extends SdkWorkflow<AllInputsWorkflow.AllInputsWorkflowOutput> {
public class AllInputsWorkflow
extends SdkWorkflow<Void, AllInputsWorkflow.AllInputsWorkflowOutput> {

public AllInputsWorkflow() {
super(JacksonSdkType.of(AllInputsWorkflow.AllInputsWorkflowOutput.class));
super(SdkTypes.nulls(), JacksonSdkType.of(AllInputsWorkflow.AllInputsWorkflowOutput.class));
}

@Override
Expand All @@ -46,7 +48,8 @@ public void expand(SdkWorkflowBuilder builder) {
SdkNode<AutoAllInputsOutput> apply =
builder.apply(
"all-inputs",
AllInputsTask.of(
new AllInputsTask(),
AllInputsTask.AutoAllInputsInput.create(
SdkBindingData.ofInteger(1L),
SdkBindingData.ofFloat(2),
SdkBindingData.ofString("test"),
Expand All @@ -57,6 +60,7 @@ public void expand(SdkWorkflowBuilder builder) {
SdkBindingData.ofStringMap(Map.of("test", "test")),
SdkBindingData.ofStringCollection(Collections.emptyList()),
SdkBindingData.ofIntegerMap(Collections.emptyMap())));

AllInputsTask.AutoAllInputsOutput outputs = apply.getOutputs();

builder.output("i", outputs.i(), "Integer value");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ public abstract static class Input {
public abstract SdkBindingData<Map<String, String>> keyValues();

public abstract SdkBindingData<List<String>> searchKeys();

public static Input create(
SdkBindingData<Map<String, String>> keyValues, SdkBindingData<List<String>> searchKeys) {
return new AutoValue_BatchLookUpTask_Input(keyValues, searchKeys);
}
}

@AutoValue
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
import org.flyte.flytekit.jackson.JacksonSdkType;

@AutoService(SdkWorkflow.class)
public class ConditionalGreetingWorkflow extends SdkWorkflow<GreetTask.Output> {
public class ConditionalGreetingWorkflow extends SdkWorkflow<GreetTask.Input, GreetTask.Output> {
public ConditionalGreetingWorkflow() {
super(JacksonSdkType.of(GreetTask.Output.class));
super(JacksonSdkType.of(GreetTask.Input.class), JacksonSdkType.of(GreetTask.Output.class));
}

@Override
Expand All @@ -40,8 +40,11 @@ public void expand(SdkWorkflowBuilder builder) {
.apply(
"decide",
SdkConditions.when(
"when-empty", eq(name, ofString("")), GreetTask.of(ofString("World")))
.otherwise("when-not-empty", GreetTask.of(name)))
"when-empty",
eq(name, ofString("")),
new GreetTask(),
GreetTask.Input.create(ofString("World")))
.otherwise("when-not-empty", new GreetTask(), GreetTask.Input.create(name)))
.getOutputs()
.greeting();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@

/** Example workflow that takes a name and outputs a welcome message. */
@AutoService(SdkWorkflow.class)
public class ContainerWorkflow extends SdkWorkflow<Void> {
public class ContainerWorkflow extends SdkWorkflow<Void, Void> {

public ContainerWorkflow() {
super(SdkTypes.nulls());
super(SdkTypes.nulls(), SdkTypes.nulls());
}

@Override
Expand Down
Loading