Skip to content

Commit

Permalink
Auto generate node ID (#162)
Browse files Browse the repository at this point in the history
* Auto generate node ID

Signed-off-by: Hongxin Liang <[email protected]>

* Do not expose

Signed-off-by: Hongxin Liang <[email protected]>

* OK

Signed-off-by: Hongxin Liang <[email protected]>

* Use remote name

Signed-off-by: Hongxin Liang <[email protected]>

* Use node id prefix

Signed-off-by: Hongxin Liang <[email protected]>

* Give it a proper prefix

Signed-off-by: Hongxin Liang <[email protected]>

* Use alphabet for prefix

Signed-off-by: Hongxin Liang <[email protected]>

* Node id related methods in own class

Signed-off-by: Nelson Arapé <[email protected]>

* withNameOverride(String, boolean) -> withNameOverrideIfNotSet

Signed-off-by: Nelson Arapé <[email protected]>

* Rename NamePolicy and add docs

Signed-off-by: Nelson Arapé <[email protected]>

* Minor refactor test

Signed-off-by: Nelson Arapé <[email protected]>

Signed-off-by: Hongxin Liang <[email protected]>
Signed-off-by: Nelson Arapé <[email protected]>
Co-authored-by: Nelson Arapé <[email protected]>
Signed-off-by: Andres Gomez Ferrer <[email protected]>
  • Loading branch information
2 people authored and andresgomezfrr committed Jan 24, 2023
1 parent d0a0b83 commit 5a01149
Show file tree
Hide file tree
Showing 17 changed files with 311 additions and 47 deletions.
5 changes: 5 additions & 0 deletions flytekit-java/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,10 @@
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,6 @@ public SdkBranchNode apply(
List<String> upstreamNodeIds,
@Nullable SdkNodeMetadata metadata,
Map<String, SdkBindingData> inputs) {
if (metadata != null) {
throw new IllegalArgumentException("invariant failed: metadata must be null");
}
if (!inputs.isEmpty()) {
throw new IllegalArgumentException("invariant failed: inputs must be empty");
}

SdkBranchNode.Builder nodeBuilder = new SdkBranchNode.Builder(builder);

for (SdkConditionCase case_ : cases) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,6 @@ public String getType() {
return "raw-container";
}

/** Specifies task name. */
public String getName() {
return getClass().getName();
}

/** Specifies task input type. */
public SdkType<InputT> getInputType() {
return inputType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,6 @@ public String getType() {
return "dynamic";
}

public String getName() {
return getClass().getName();
}

public SdkType<InputT> getInputType() {
return inputType;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,6 @@ public SdkNode apply(String id, SdkTransform transform) {
List<String> upstreamNodeIds =
getOutputs().isEmpty() ? Collections.singletonList(getNodeId()) : Collections.emptyList();

return builder.applyInternal(id, transform, upstreamNodeIds, /*metadata=*/ null, getOutputs());
return builder.applyInternal(id, transform, upstreamNodeIds, getOutputs());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright 2021 Flyte Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.flytekit;

import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Pattern;

/**
* Controls the default node id and node name policy when applying {@link SdkTransform} to {@link
* SdkWorkflowBuilder}. When using {@link SdkWorkflowBuilder#apply(SdkTransform)} or {@link
* SdkWorkflowBuilder#apply(SdkTransform, Map)} then the node id used would be the one returned by
* {@link #nextNodeId()}. Also, if a node name haven't been set by the user, then {@link
* #toNodeName(String)} would be used.
*/
class SdkNodeNamePolicy {
private static final Pattern UPPER_AFTER_LOWER_PATTERN = Pattern.compile("([a-z])([A-Z]+)");
private static final int RND_PREFIX_SIZE = 4;

private final String nodeIdPrefix;
private final AtomicInteger nodeIdSuffix;

SdkNodeNamePolicy() {
this.nodeIdPrefix = randomPrefix();
this.nodeIdSuffix = new AtomicInteger();
}

/**
* Returns a unique node ids in the format {@code <prefix>-n<consecutive-number>}, where prefix is
* a random, but shared among all ids for this object, set of character in the format {@code
* wRRRR} and {@code R} is a random letter in {@code a-z} range.
*
* @return next unique node id for this policy.
*/
String nextNodeId() {
return nodeIdPrefix + "n" + nodeIdSuffix.getAndIncrement();
}

/**
* Returns a node appropriate name for a given transformation name. The transformation done are
*
* <ul>
* <li>Package name is removed
* <li>CamelCase is transformed to kebab-case
* <li>$ is transformed to -
* </ul>
*
* <p>For example {@code com.example.Outer$InnerTask} get translated to {@code outer-inner-task}.
*
* @return node name.
*/
String toNodeName(String name) {
String lastPart = name.substring(name.lastIndexOf('.') + 1);
return UPPER_AFTER_LOWER_PATTERN
.matcher(lastPart)
.replaceAll("$1-$2")
.toLowerCase(Locale.ROOT)
.replaceAll("\\$", "-");
}

// Returns random prefix in the following format "wqjoz-"
private static String randomPrefix() {
return "w"
+ ThreadLocalRandom.current()
.ints(RND_PREFIX_SIZE, 'a', 'z' + 1)
.collect(StringBuilder::new, StringBuilder::appendCodePoint, StringBuilder::append)
.append('-');
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ public SdkTransform withNameOverride(String name) {
return new SdkPartialTransform(transform, fixedInputs, extraUpstreamNodeIds, mergedMetadata);
}

@Override
SdkTransform withNameOverrideIfNotSet(String name) {
if (metadata != null && metadata.name() != null) {
return this;
}
return withNameOverride(name);
}

@Override
public SdkTransform withTimeoutOverride(Duration timeout) {
requireNonNull(timeout, "Timeout override cannot be null");
Expand All @@ -114,6 +122,11 @@ public SdkTransform withTimeoutOverride(Duration timeout) {
return new SdkPartialTransform(transform, fixedInputs, extraUpstreamNodeIds, mergedMetadata);
}

@Override
public String getName() {
return transform.getName();
}

@Override
public SdkNode apply(
SdkWorkflowBuilder builder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ public static <InputT, OutputT> SdkRemoteLaunchPlan<InputT, OutputT> create(
.build();
}

@Override
public String getName() {
return name();
}

@Override
public SdkNode apply(
SdkWorkflowBuilder builder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ public static <InputT, OutputT> SdkRemoteTask<InputT, OutputT> create(
.build();
}

@Override
public String getName() {
return name();
}

@Override
public SdkNode apply(
SdkWorkflowBuilder builder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,6 @@ public String getType() {
return "java-task";
}

public String getName() {
return getClass().getName();
}

public SdkType<InputT> getInputType() {
return inputType;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,18 @@ public SdkTransform withNameOverride(String name) {
return SdkPartialTransform.of(this, metadata);
}

SdkTransform withNameOverrideIfNotSet(String name) {
return withNameOverride(name);
}

public SdkTransform withTimeoutOverride(Duration timeout) {
requireNonNull(timeout, "Timeout override cannot be null");

SdkNodeMetadata metadata = SdkNodeMetadata.builder().timeout(timeout).build();
return SdkPartialTransform.of(this, metadata);
}

public String getName() {
return getClass().getName();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@

public abstract class SdkWorkflow extends SdkTransform {

public String getName() {
return getClass().getName();
}

public abstract void expand(SdkWorkflowBuilder builder);

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,54 +27,80 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import javax.annotation.Nullable;
import org.flyte.api.v1.LiteralType;
import org.flyte.api.v1.SimpleType;
import org.flyte.api.v1.WorkflowTemplate;

public class SdkWorkflowBuilder {

private final Map<String, SdkNode> nodes;
private final Map<String, SdkBindingData> inputs;
private final Map<String, SdkBindingData> outputs;
private final Map<String, String> inputDescriptions;
private final Map<String, String> outputDescriptions;
private final SdkNodeNamePolicy sdkNodeNamePolicy;

public SdkWorkflowBuilder() {
this(new SdkNodeNamePolicy());
}

// VisibleForTesting
SdkWorkflowBuilder(SdkNodeNamePolicy sdkNodeNamePolicy) {
// Using LinkedHashMap to preserve declaration order
this.nodes = new LinkedHashMap<>();
this.inputs = new LinkedHashMap<>();
this.outputs = new LinkedHashMap<>();

this.inputDescriptions = new HashMap<>();
this.outputDescriptions = new HashMap<>();

this.sdkNodeNamePolicy = sdkNodeNamePolicy;
}

public SdkNode apply(String nodeId, SdkTransform transform) {
return apply(nodeId, transform, emptyMap());
}

public SdkNode apply(String nodeId, SdkTransform transform, Map<String, SdkBindingData> inputs) {
return applyInternal(nodeId, transform, emptyList(), /*metadata=*/ null, inputs);
return applyInternal(nodeId, transform, emptyList(), inputs);
}

public SdkNode apply(SdkTransform transform) {
return apply(/*nodeId=*/ null, transform, emptyMap());
}

public SdkNode apply(SdkTransform transform, Map<String, SdkBindingData> inputs) {
return applyInternal(/*nodeId=*/ null, transform, emptyList(), inputs);
}

protected SdkNode applyInternal(
String nodeId,
SdkNode applyInternal(
@Nullable String nodeId,
SdkTransform transform,
List<String> upstreamNodeIds,
@Nullable SdkNodeMetadata metadata,
Map<String, SdkBindingData> inputs) {

if (nodes.containsKey(nodeId)) {
String actualNodeId = Objects.requireNonNullElseGet(nodeId, sdkNodeNamePolicy::nextNodeId);

if (nodes.containsKey(actualNodeId)) {
CompilerError error =
CompilerError.create(
CompilerError.Kind.DUPLICATE_NODE_ID,
nodeId,
actualNodeId,
"Trying to insert two nodes with the same id.");

throw new CompilerException(error);
}

SdkNode sdkNode = transform.apply(this, nodeId, upstreamNodeIds, metadata, inputs);
String fallbackNodeName =
Objects.requireNonNullElseGet(
nodeId, () -> sdkNodeNamePolicy.toNodeName(transform.getName()));

SdkNode sdkNode =
transform
.withNameOverrideIfNotSet(fallbackNodeName)
.apply(this, actualNodeId, upstreamNodeIds, null, inputs);
nodes.put(sdkNode.getNodeId(), sdkNode);

return sdkNode;
Expand Down
Loading

0 comments on commit 5a01149

Please sign in to comment.