Skip to content

Commit

Permalink
Fix conditionals and more
Browse files Browse the repository at this point in the history
Signed-off-by: Nelson Arapé <[email protected]>
  • Loading branch information
narape committed Jan 23, 2023
1 parent c73703c commit 9561bcb
Show file tree
Hide file tree
Showing 25 changed files with 228 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import com.google.auto.value.AutoValue;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkRunnableTask;
import org.flyte.flytekit.SdkTransform;
import org.flyte.flytekit.jackson.JacksonSdkType;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
*/
package org.flyte.examples;

import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;

import com.google.auto.service.AutoService;
import com.google.auto.value.AutoValue;
import java.time.Duration;
Expand All @@ -27,7 +24,6 @@
import java.util.Map;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkRunnableTask;
import org.flyte.flytekit.SdkTransform;
import org.flyte.flytekit.jackson.JacksonSdkType;

@AutoService(SdkRunnableTask.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
import org.flyte.flytekit.jackson.JacksonSdkType;

@AutoService(SdkWorkflow.class)
public class AllInputsWorkflow extends SdkWorkflow<Void, AllInputsWorkflow.AllInputsWorkflowOutput> {
public class AllInputsWorkflow
extends SdkWorkflow<Void, AllInputsWorkflow.AllInputsWorkflowOutput> {

public AllInputsWorkflow() {
super(SdkTypes.nulls(), JacksonSdkType.of(AllInputsWorkflow.AllInputsWorkflowOutput.class));
Expand All @@ -47,7 +48,8 @@ public void expand(SdkWorkflowBuilder builder) {
SdkNode<AutoAllInputsOutput> apply =
builder.apply(
"all-inputs",
new AllInputsTask(), AllInputsTask.AutoAllInputsInput.create(
new AllInputsTask(),
AllInputsTask.AutoAllInputsInput.create(
SdkBindingData.ofInteger(1L),
SdkBindingData.ofFloat(2),
SdkBindingData.ofString("test"),
Expand All @@ -57,8 +59,7 @@ public void expand(SdkWorkflowBuilder builder) {
SdkBindingData.ofStringCollection(Arrays.asList("foo", "bar")),
SdkBindingData.ofStringMap(Map.of("test", "test")),
SdkBindingData.ofStringCollection(Collections.emptyList()),
SdkBindingData.ofIntegerMap(Collections.emptyMap()))
);
SdkBindingData.ofIntegerMap(Collections.emptyMap())));

AllInputsTask.AutoAllInputsOutput outputs = apply.getOutputs();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ public abstract static class Input {

public abstract SdkBindingData<List<String>> searchKeys();

public static Input create(SdkBindingData<Map<String, String>> keyValues, SdkBindingData<List<String>> searchKeys) {
public static Input create(
SdkBindingData<Map<String, String>> keyValues, SdkBindingData<List<String>> searchKeys) {
return new AutoValue_BatchLookUpTask_Input(keyValues, searchKeys);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ public void expand(SdkWorkflowBuilder builder) {
.apply(
"decide",
SdkConditions.when(
"when-empty", eq(name, ofString("")), new GreetTask(), GreetTask.Input.create(ofString("World")))
"when-empty",
eq(name, ofString("")),
new GreetTask(),
GreetTask.Input.create(ofString("World")))
.otherwise("when-not-empty", new GreetTask(), GreetTask.Input.create(name)))
.getOutputs()
.greeting();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,14 @@

import com.google.auto.service.AutoService;
import com.google.auto.value.AutoValue;
import java.util.List;
import java.util.Map;
import org.flyte.examples.BatchLookUpTask.Input;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkWorkflow;
import org.flyte.flytekit.SdkWorkflowBuilder;
import org.flyte.flytekit.jackson.JacksonSdkType;

@AutoService(SdkWorkflow.class)
public class DynamicFibonacciWorkflow extends SdkWorkflow<DynamicFibonacciWorkflow.Input, DynamicFibonacciWorkflowTask.Output> {
public class DynamicFibonacciWorkflow
extends SdkWorkflow<DynamicFibonacciWorkflow.Input, DynamicFibonacciWorkflowTask.Output> {
@AutoValue
public abstract static class Input {
public abstract SdkBindingData<Long> n();
Expand All @@ -36,8 +34,11 @@ public static DynamicFibonacciWorkflow.Input create(SdkBindingData<Long> n) {
return new AutoValue_DynamicFibonacciWorkflow_Input(n);
}
}

public DynamicFibonacciWorkflow() {
super(JacksonSdkType.of(DynamicFibonacciWorkflow.Input.class), JacksonSdkType.of(DynamicFibonacciWorkflowTask.Output.class));
super(
JacksonSdkType.of(DynamicFibonacciWorkflow.Input.class),
JacksonSdkType.of(DynamicFibonacciWorkflowTask.Output.class));
}

@Override
Expand All @@ -46,7 +47,10 @@ public void expand(SdkWorkflowBuilder builder) {

SdkBindingData<Long> fibOutput =
builder
.apply("fibonacci", new DynamicFibonacciWorkflowTask(), DynamicFibonacciWorkflowTask.Input.create(n))
.apply(
"fibonacci",
new DynamicFibonacciWorkflowTask(),
DynamicFibonacciWorkflowTask.Input.create(n))
.getOutputs()
.output();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,30 @@
import org.flyte.flytekit.jackson.JacksonSdkType;

@AutoService(SdkWorkflow.class)
public class FibonacciWorkflow extends SdkWorkflow<FibonacciWorkflow.Input, FibonacciWorkflow.Output> {
public class FibonacciWorkflow
extends SdkWorkflow<FibonacciWorkflow.Input, FibonacciWorkflow.Output> {

public FibonacciWorkflow() {
super(JacksonSdkType.of(FibonacciWorkflow.Input.class), JacksonSdkType.of(FibonacciWorkflow.Output.class));
super(
JacksonSdkType.of(FibonacciWorkflow.Input.class),
JacksonSdkType.of(FibonacciWorkflow.Output.class));
}

@Override
public void expand(SdkWorkflowBuilder builder) {
SdkBindingData<Long> fib0 = builder.inputOfInteger("fib0", "Value for Fib0");
SdkBindingData<Long> fib1 = builder.inputOfInteger("fib1", "Value for Fib1");

SdkNode<SumTask.SumOutput> apply = builder.apply("fib-2", new SumTask(), SumTask.SumInput.create(fib1, fib0));
SdkNode<SumTask.SumOutput> apply =
builder.apply("fib-2", new SumTask(), SumTask.SumInput.create(fib1, fib0));
SumTask.SumOutput outputs = apply.getOutputs();
SdkBindingData<Long> fib2 = outputs.c();
SdkBindingData<Long> fib3 = builder.apply("fib-3", new SumTask(), SumTask.SumInput.create(fib1, fib2)).getOutputs().c();
SdkBindingData<Long> fib4 = builder.apply("fib-4", new SumTask(), SumTask.SumInput.create(fib2, fib3)).getOutputs().c();
SdkBindingData<Long> fib5 = builder.apply("fib-5", new SumTask(), SumTask.SumInput.create(fib3, fib4)).getOutputs().c();
SdkBindingData<Long> fib3 =
builder.apply("fib-3", new SumTask(), SumTask.SumInput.create(fib1, fib2)).getOutputs().c();
SdkBindingData<Long> fib4 =
builder.apply("fib-4", new SumTask(), SumTask.SumInput.create(fib2, fib3)).getOutputs().c();
SdkBindingData<Long> fib5 =
builder.apply("fib-5", new SumTask(), SumTask.SumInput.create(fib3, fib4)).getOutputs().c();

builder.output("fib5", fib5, "Value for Fib5");
}
Expand All @@ -52,7 +59,8 @@ public abstract static class Input {

public abstract SdkBindingData<Long> fib1();

public static FibonacciWorkflow.Output create(SdkBindingData<Long> fib0, SdkBindingData<Long> fib1) {
public static FibonacciWorkflow.Output create(
SdkBindingData<Long> fib0, SdkBindingData<Long> fib1) {
return new AutoValue_FibonacciWorkflow_Input(fib0, fib1);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import com.google.auto.value.AutoValue;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkRunnableTask;
import org.flyte.flytekit.SdkTransform;
import org.flyte.flytekit.SdkTypes;
import org.flyte.flytekit.jackson.JacksonSdkType;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import com.google.auto.value.AutoValue;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkRunnableTask;
import org.flyte.flytekit.SdkTransform;
import org.flyte.flytekit.jackson.JacksonSdkType;

@AutoService(SdkRunnableTask.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import com.google.auto.service.AutoService;
import com.google.auto.value.AutoValue;
import org.flyte.examples.SumTask.SumInput;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkWorkflow;
import org.flyte.flytekit.SdkWorkflowBuilder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import com.google.auto.service.AutoService;
import com.google.auto.value.AutoValue;
import org.flyte.examples.SumTask.SumInput;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkWorkflow;
import org.flyte.flytekit.SdkWorkflowBuilder;
Expand All @@ -38,7 +37,9 @@ public static WelcomeWorkflow.Input create(SdkBindingData<String> name) {
}

public WelcomeWorkflow() {
super(JacksonSdkType.of(WelcomeWorkflow.Input.class), JacksonSdkType.of(AddQuestionTask.Output.class));
super(
JacksonSdkType.of(WelcomeWorkflow.Input.class),
JacksonSdkType.of(AddQuestionTask.Output.class));
}

@Override
Expand All @@ -48,11 +49,17 @@ public void expand(SdkWorkflowBuilder builder) {

// uses the workflow input as the task input of the GreetTask
SdkBindingData<String> greeting =
builder.apply("greet", new GreetTask(), GreetTask.Input.create(name)).getOutputs().greeting();
builder
.apply("greet", new GreetTask(), GreetTask.Input.create(name))
.getOutputs()
.greeting();

// uses the output of the GreetTask as the task input of the AddQuestionTask
SdkBindingData<String> greetingWithQuestion =
builder.apply("add-question", new AddQuestionTask(), AddQuestionTask.Input.create(greeting)).getOutputs().greeting();
builder
.apply("add-question", new AddQuestionTask(), AddQuestionTask.Input.create(greeting))
.getOutputs()
.greeting();

// uses the task output of the AddQuestionTask as the output of the workflow
builder.output("greeting", greetingWithQuestion, "Welcome message");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright 2021 Flyte Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.flytekit;

import java.util.List;
import java.util.Set;
import javax.annotation.Nullable;

class SdkAppliedTransform<OriginalInputT, OutputT> extends SdkTransform<Void, OutputT> {
private final SdkTransform<OriginalInputT, OutputT> transform;
private final OriginalInputT appliedInputs;

SdkAppliedTransform(
SdkTransform<OriginalInputT, OutputT> transform, @Nullable OriginalInputT appliedInputs) {
checkNotNull(transform, appliedInputs);
this.transform = transform;
this.appliedInputs = appliedInputs;
}

static <InputT> void checkNotNull(SdkTransform<InputT, ?> transform, @Nullable InputT inputs) {
Set<String> variableNames = transform.getInputType().variableNames();
if (inputs == null && !variableNames.isEmpty()) {
throw new IllegalArgumentException(
String.format(
"Null supplied as input for a transform with %s properties", variableNames));
}
}

@Override
public SdkType<Void> getInputType() {
return SdkTypes.nulls();
}

@Override
public SdkType<OutputT> getOutputType() {
return transform.getOutputType();
}

@Override
public String getName() {
return transform.getName();
}

@Override
public SdkNode<OutputT> apply(
SdkWorkflowBuilder builder,
String nodeId,
List<String> upstreamNodeIds,
@Nullable SdkNodeMetadata metadata,
@Nullable Void inputs) {
return transform.apply(builder, nodeId, upstreamNodeIds, metadata, appliedInputs);
}
}
17 changes: 12 additions & 5 deletions flytekit-java/src/main/java/org/flyte/flytekit/SdkBranchNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
package org.flyte.flytekit;

import static org.flyte.flytekit.MoreCollectors.toUnmodifiableList;
import static java.util.stream.Collectors.toUnmodifiableList;
import static org.flyte.flytekit.MoreCollectors.toUnmodifiableMap;

import com.google.errorprone.annotations.CanIgnoreReturnValue;
Expand All @@ -26,6 +26,7 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.flyte.api.v1.Binding;
import org.flyte.api.v1.BranchNode;
import org.flyte.api.v1.IfElseBlock;
Expand Down Expand Up @@ -89,15 +90,21 @@ public Node toIdl() {
ifElseBlock = ifElseBlock.toBuilder().error(nodeError).build();
}

// inputs in var order for predictability
List<Binding> inputs =
extraInputs.entrySet().stream()
.sorted(Entry.comparingByKey())
.map(Entry::getValue)
.collect(toUnmodifiableList());
return Node.builder()
.id(nodeId)
.branchNode(BranchNode.builder().ifElse(ifElseBlock).build())
.inputs(List.copyOf(extraInputs.values()))
.inputs(inputs)
.upstreamNodeIds(upstreamNodeIds)
.build();
}

static class Builder<InputT, OutputT> {
static class Builder<OutputT> {
private final SdkWorkflowBuilder builder;
private final SdkType<OutputT> outputType;

Expand All @@ -113,7 +120,7 @@ static class Builder<InputT, OutputT> {
}

@CanIgnoreReturnValue
Builder<InputT, OutputT> addCase(SdkConditionCase<InputT, OutputT> case_) {
Builder<OutputT> addCase(SdkConditionCase<OutputT> case_) {
SdkNode<OutputT> sdkNode =
case_
.then()
Expand Down Expand Up @@ -147,7 +154,7 @@ Builder<InputT, OutputT> addCase(SdkConditionCase<InputT, OutputT> case_) {
}

@CanIgnoreReturnValue
Builder<InputT, OutputT> addOtherwise(String name, SdkTransform<InputT, OutputT> otherwise) {
Builder<OutputT> addOtherwise(String name, SdkTransform<Void, OutputT> otherwise) {
if (elseNode != null) {
throw new IllegalArgumentException("Duplicate otherwise clause");
}
Expand Down
Loading

0 comments on commit 9561bcb

Please sign in to comment.