Skip to content

Commit

Permalink
Implement Max Parallelism in the LaunchPlan (#300)
Browse files Browse the repository at this point in the history
* Max Parallelism

Signed-off-by: Rafael Raposo <[email protected]>

* FMT

Signed-off-by: Rafael Raposo <[email protected]>

* Add example

Signed-off-by: Rafael Raposo <[email protected]>

* CheckStyle

Signed-off-by: Rafael Raposo <[email protected]>

* Missing piece - add to protoUtil

Signed-off-by: Rafael Raposo <[email protected]>

* Another layer

Signed-off-by: Rafael Raposo <[email protected]>

---------

Signed-off-by: Rafael Raposo <[email protected]>
  • Loading branch information
RRap0so authored Jun 28, 2024
1 parent de6417e commit ccc3964
Show file tree
Hide file tree
Showing 10 changed files with 135 additions and 5 deletions.
8 changes: 8 additions & 0 deletions flytekit-api/src/main/java/org/flyte/api/v1/LaunchPlan.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.google.auto.value.AutoValue;
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import javax.annotation.Nullable;

/** User-provided launch plan definition and configuration values. */
Expand All @@ -40,6 +41,11 @@ public abstract class LaunchPlan {
*/
public abstract Map<String, Parameter> defaultInputs();

/**
* Controls the maximum number of tasknodes that can be run in parallel for the entire workflow.
*/
public abstract Optional<Integer> maxParallelism();

@Nullable
public abstract CronSchedule cronSchedule();

Expand All @@ -64,6 +70,8 @@ public abstract static class Builder {

public abstract Builder cronSchedule(CronSchedule cronSchedule);

public abstract Builder maxParallelism(Optional<Integer> maxParallelism);

public abstract LaunchPlan build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.google.auto.service.AutoService;
import com.google.auto.value.AutoValue;
import java.util.Optional;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkBindingDataFactory;
import org.flyte.flytekit.SdkLaunchPlan;
Expand Down Expand Up @@ -53,10 +54,20 @@ public FibonacciLaunchPlan() {
.withName("FibonacciWorkflowLaunchPlan3")
.withDefaultInput("fib0", 0L)
.withDefaultInput("fib1", 1L));

// Register launch plan with fixed inputs and maxParallelism of 10
registerLaunchPlan(
SdkLaunchPlan.of(new FibonacciWorkflow())
.withName("FibonacciWorkflowLaunchPlan4")
.withFixedInputs(
JacksonSdkType.of(Input.class),
Input.create(SdkBindingDataFactory.of(0), SdkBindingDataFactory.of(1)))
.withMaxParallelism(Optional.of(10)));
}

@AutoValue
abstract static class Input {

abstract SdkBindingData<Long> fib0();

abstract SdkBindingData<Long> fib1();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import javax.annotation.Nullable;
import org.flyte.api.v1.Literal;
Expand Down Expand Up @@ -81,6 +82,10 @@ public abstract class SdkLaunchPlan {
@Nullable
public abstract SdkCronSchedule cronSchedule();

/** Returns the max parallelism of the launch plan. */
@Nullable
public abstract Optional<Integer> maxParallelism();

/**
* Creates a launch plan for specified {@link SdkLaunchPlan} with default naming, no inputs and no
* schedule. The default launch plan name is {@link SdkWorkflow#getName()}. New name, inputs and
Expand Down Expand Up @@ -322,6 +327,16 @@ public <T> SdkLaunchPlan withDefaultInput(SdkType<T> type, T value) {
v -> createParameter(v.getValue(), literalMap.get(v.getKey())))));
}

/**
* @param maxParallelism Optional Integer for the max parallelism (cannot be negative). Default
* Value: Empty, it will default to what's set in the Flyte Platform. 0: It will try to use as
* much as allowed.
* @return the new launch plan
*/
public SdkLaunchPlan withMaxParallelism(Optional<Integer> maxParallelism) {
return withMaxParallelism0(maxParallelism);
}

private SdkLaunchPlan withDefaultInputs0(Map<String, Parameter> newDefaultInputs) {

verifyNonEmptyWorkflowInput(newDefaultInputs, "default");
Expand All @@ -336,6 +351,17 @@ private SdkLaunchPlan withDefaultInputs0(Map<String, Parameter> newDefaultInputs
return toBuilder().defaultInputs(newCompleteDefaultInputs).build();
}

private SdkLaunchPlan withMaxParallelism0(Optional<Integer> maxParallelism) {
if (maxParallelism.isPresent() && maxParallelism.get() < 0) {
String message =
String.format(
"invalid max parallelism %s, expected a positive integer", maxParallelism.get());
throw new IllegalArgumentException(message);
}

return toBuilder().maxParallelism(maxParallelism).build();
}

private <T> Map<String, T> mergeInputs(
Map<String, T> oldInputs, Map<String, T> newInputs, String inputType) {
Map<String, T> newCompleteInputs = new LinkedHashMap<>(oldInputs);
Expand Down Expand Up @@ -388,7 +414,8 @@ static Builder builder() {
return new AutoValue_SdkLaunchPlan.Builder()
.fixedInputs(Collections.emptyMap())
.defaultInputs(Collections.emptyMap())
.workflowInputTypeMap(Collections.emptyMap());
.workflowInputTypeMap(Collections.emptyMap())
.maxParallelism(Optional.empty());
}

abstract Builder toBuilder();
Expand All @@ -414,6 +441,8 @@ abstract static class Builder {

abstract Builder workflowInputTypeMap(Map<String, LiteralType> workflowInputTypeMap);

abstract Builder maxParallelism(Optional<Integer> maxParallelism);

abstract SdkLaunchPlan build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ Map<LaunchPlanIdentifier, LaunchPlan> load(
.name(sdkLaunchPlan.name())
.workflowId(getWorkflowIdentifier(sdkLaunchPlan))
.fixedInputs(sdkLaunchPlan.fixedInputs())
.defaultInputs(sdkLaunchPlan.defaultInputs());
.defaultInputs(sdkLaunchPlan.defaultInputs())
.maxParallelism(sdkLaunchPlan.maxParallelism());

if (sdkLaunchPlan.cronSchedule() != null) {
builder.cronSchedule(getCronSchedule(sdkLaunchPlan.cronSchedule()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.flyte.api.v1.CronSchedule;
import org.flyte.api.v1.LaunchPlan;
import org.flyte.api.v1.LaunchPlanIdentifier;
Expand Down Expand Up @@ -137,6 +138,34 @@ void shouldTestLaunchPlansWithCronSchedule() {
hasEntry(expectedIdentifierWithOffset, planWithOffset)));
}

@Test
void shouldTestLaunchPlansWithMaxParallelism() {
Map<LaunchPlanIdentifier, LaunchPlan> launchPlans =
registrar.load(ENV, singletonList(new TestRegistryWithMaxParallelism()));

LaunchPlanIdentifier expectedIdentifierWithOffset =
LaunchPlanIdentifier.builder()
.project("project")
.domain("domain")
.name("TestPlanScheduleWithMaxParallelism")
.version("version")
.build();

LaunchPlan planWithOffset =
LaunchPlan.builder()
.name("TestPlanScheduleWithMaxParallelism")
.workflowId(
PartialWorkflowIdentifier.builder()
.name("org.flyte.flytekit.SdkLaunchPlanRegistrarTest$TestWorkflow")
.build())
.fixedInputs(Collections.emptyMap())
.defaultInputs(Collections.emptyMap())
.maxParallelism(Optional.of(10))
.build();

assertThat(launchPlans, allOf(hasEntry(expectedIdentifierWithOffset, planWithOffset)));
}

@Test
void shouldRejectLoadingLaunchPlanDuplicatesInSameRegistry() {
IllegalArgumentException exception =
Expand Down Expand Up @@ -208,6 +237,17 @@ public List<SdkLaunchPlan> getLaunchPlans() {
}
}

public static class TestRegistryWithMaxParallelism implements SdkLaunchPlanRegistry {

@Override
public List<SdkLaunchPlan> getLaunchPlans() {
return Arrays.asList(
SdkLaunchPlan.of(new TestWorkflow())
.withName("TestPlanScheduleWithMaxParallelism")
.withMaxParallelism(Optional.of(10)));
}
}

public static class TestWorkflow extends SdkWorkflow<Void, Void> {

public TestWorkflow() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.stream.Stream;
import org.flyte.api.v1.Literal;
Expand Down Expand Up @@ -91,6 +92,14 @@ void shouldCreateLaunchPlanWithCronSchedule() {
assertThat(plan.cronSchedule().offset(), equalTo(Duration.ofHours(1)));
}

@Test
void shouldCreateLaunchPlanWithMaxParallelism() {
SdkLaunchPlan plan = SdkLaunchPlan.of(new TestWorkflow()).withMaxParallelism(Optional.of(123));

assertThat(plan.maxParallelism(), notNullValue());
assertThat(plan.maxParallelism().get(), equalTo(123));
}

@Test
void shouldAddFixedInputs() {
Instant now = Instant.now();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ LaunchPlan apply(LaunchPlan launchPlan) {
.defaultInputs(launchPlan.defaultInputs())
.workflowId(apply(launchPlan.workflowId()))
.cronSchedule(launchPlan.cronSchedule())
.maxParallelism(launchPlan.maxParallelism())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
/** Utility to serialize between flytekit-api and flyteidl proto. */
@SuppressWarnings("PreferJavaTimeOverload")
public class ProtoUtil {

public static final String RUNTIME_FLAVOR = "java";
public static final String RUNTIME_VERSION = "0.0.1";

Expand Down Expand Up @@ -717,6 +718,8 @@ static LaunchPlanOuterClass.LaunchPlanSpec serialize(LaunchPlan launchPlan) {
.setFixedInputs(ProtoUtil.serialize(launchPlan.fixedInputs()))
.setDefaultInputs(ProtoUtil.serializeParameters(launchPlan.defaultInputs()));

launchPlan.maxParallelism().ifPresent(specBuilder::setMaxParallelism);

if (launchPlan.cronSchedule() != null) {
ScheduleOuterClass.Schedule schedule = ProtoUtil.serialize(launchPlan.cronSchedule());
specBuilder.setEntityMetadata(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import java.util.Optional;
import org.flyte.api.v1.Binding;
import org.flyte.api.v1.BindingData;
import org.flyte.api.v1.CronSchedule;
Expand Down Expand Up @@ -219,6 +220,7 @@ public void shouldPropagateLaunchPlanToStub() {
LaunchPlan launchPlan =
LaunchPlan.builder()
.workflowId(wfIdentifier)
.maxParallelism(Optional.of(20))
.name(LP_NAME)
.fixedInputs(
Collections.singletonMap(
Expand Down Expand Up @@ -249,6 +251,7 @@ public void shouldPropagateLaunchPlanToStub() {
.setSpec(
LaunchPlanOuterClass.LaunchPlanSpec.newBuilder()
.setWorkflowId(newIdentifier(ResourceType.WORKFLOW, WF_NAME, WF_VERSION))
.setMaxParallelism(20)
.setFixedInputs(
Literals.LiteralMap.newBuilder()
.putLiterals(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import static java.util.Collections.emptyMap;
import static java.util.Collections.singletonList;
import static java.util.Collections.singletonMap;
import static org.flyte.jflyte.utils.ProtoUtil.serialize;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
Expand All @@ -38,6 +39,7 @@
import com.google.protobuf.ListValue;
import com.google.protobuf.NullValue;
import com.google.protobuf.Value;
import flyteidl.admin.LaunchPlanOuterClass.LaunchPlanSpec;
import flyteidl.admin.ScheduleOuterClass;
import flyteidl.core.Condition;
import flyteidl.core.DynamicJob;
Expand All @@ -56,6 +58,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;
import org.flyte.api.v1.Binary;
import org.flyte.api.v1.Binding;
Expand All @@ -75,6 +78,7 @@
import org.flyte.api.v1.IfBlock;
import org.flyte.api.v1.IfElseBlock;
import org.flyte.api.v1.KeyValuePair;
import org.flyte.api.v1.LaunchPlan;
import org.flyte.api.v1.LaunchPlanIdentifier;
import org.flyte.api.v1.Literal;
import org.flyte.api.v1.LiteralType;
Expand Down Expand Up @@ -989,6 +993,27 @@ public void shouldSerializeCronSchedule() {
.build()));
}

@Test
public void shouldSerializeLaunchPlanMaxParallelism() {
Optional<Integer> maxParallelism = Optional.of(10);
LaunchPlan launchPlan =
LaunchPlan.builder()
.name("name")
.workflowId(
PartialWorkflowIdentifier.builder()
.project("test-project")
.domain("test-domain")
.version("a-version")
.name("name")
.build())
.maxParallelism(maxParallelism)
.build();

LaunchPlanSpec res = serialize(launchPlan);

assertThat(res.getMaxParallelism(), equalTo(10));
}

@Test
public void shouldSerializeCronScheduleNoOffset() {
CronSchedule cronSchedule = CronSchedule.builder().schedule("* * */5 * *").build();
Expand Down Expand Up @@ -1252,7 +1277,7 @@ void shouldSerializeContainerWithResources() {
Resources.ResourceName.CPU, "8", Resources.ResourceName.MEMORY, "32G"))
.build());

Tasks.Container actual = ProtoUtil.serialize(container);
Tasks.Container actual = serialize(container);

assertThat(
actual,
Expand Down Expand Up @@ -1296,7 +1321,7 @@ void shouldAcceptResourcesWithValidQuantities(String quantity) {
.limits(ImmutableMap.of(Resources.ResourceName.CPU, quantity))
.build());

Tasks.Container actual = ProtoUtil.serialize(container);
Tasks.Container actual = serialize(container);

assertThat(
actual,
Expand Down Expand Up @@ -1326,7 +1351,7 @@ void shouldRejectResourcesWithInvalidQuantities(String quantity) {
.build());

IllegalArgumentException exception =
assertThrows(IllegalArgumentException.class, () -> ProtoUtil.serialize(container));
assertThrows(IllegalArgumentException.class, () -> serialize(container));

assertEquals(
"Resource requests [CPU] has invalid quantity: " + quantity, exception.getMessage());
Expand Down

0 comments on commit ccc3964

Please sign in to comment.