From 57ad4540348bc2ef487cfd4be8d78beb5d9503fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20G=C3=B3mez?= Date: Tue, 24 Jan 2023 09:27:13 +0100 Subject: [PATCH] Add toSdkBindingDataMap() to SdkType (#177) * Add toSdkBindingDataMap() to SdkType class Signed-off-by: Andres Gomez Ferrer * Spotless:apply, fix compile and test Signed-off-by: Andres Gomez Ferrer * Review fixes Signed-off-by: Andres Gomez Ferrer Signed-off-by: Andres Gomez Ferrer Co-authored-by: Andres Gomez Ferrer --- .../flytekit/jackson/JacksonSdkType.java | 25 ++++++- .../flytekit/jackson/RootFormatVisitor.java | 9 +++ .../flytekit/jackson/VariableMapVisitor.java | 7 ++ .../flytekit/jackson/JacksonSdkTypeTest.java | 70 +++++++++++++++++++ .../main/java/org/flyte/flytekit/SdkType.java | 2 + .../java/org/flyte/flytekit/SdkTypes.java | 5 ++ .../flyte/flytekit/TestPairIntegerInput.java | 24 ++++--- .../flytekit/TestUnaryBooleanOutput.java | 5 ++ .../flytekit/TestUnaryIntegerOutput.java | 5 ++ .../examples/TestUnaryIntegerOutput.java | 5 ++ .../flytekitscala/SdkScalaTypeTest.scala | 44 ++++++++++++ .../flyte/flytekitscala/SdkScalaType.scala | 37 +++++++++- 12 files changed, 225 insertions(+), 13 deletions(-) diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkType.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkType.java index e87c436be..b85841c08 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkType.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkType.java @@ -26,12 +26,14 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.SerializerProvider; import com.fasterxml.jackson.databind.deser.DefaultDeserializationContext; +import com.fasterxml.jackson.databind.introspect.AnnotatedMember; import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; import com.fasterxml.jackson.module.paramnames.ParameterNamesModule; import java.io.IOException; import java.io.UncheckedIOException; import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; import org.flyte.api.v1.Literal; import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Variable; @@ -46,10 +48,13 @@ public class JacksonSdkType extends SdkType { private final Class clazz; private final Map variableMap; + private final Map membersMap; - private JacksonSdkType(Class clazz, Map variableMap) { + private JacksonSdkType( + Class clazz, Map variableMap, Map membersMap) { this.clazz = Objects.requireNonNull(clazz); this.variableMap = Objects.requireNonNull(variableMap); + this.membersMap = Objects.requireNonNull(membersMap); } public static JacksonSdkType of(Class clazz) { @@ -73,7 +78,7 @@ public static JacksonSdkType of(Class clazz) { serializer.acceptJsonFormatVisitor( visitor, OBJECT_MAPPER.getTypeFactory().constructType(clazz)); - return new JacksonSdkType<>(clazz, visitor.getVariableMap()); + return new JacksonSdkType<>(clazz, visitor.getVariableMap(), visitor.getMembersMap()); } catch (JsonMappingException e) { throw new IllegalArgumentException( String.format("Failed to find serializer for [%s]", clazz.getName()), e); @@ -120,6 +125,10 @@ public Map getVariableMap() { return variableMap; } + private Map getMembersMap() { + return membersMap; + } + @Override public T fromLiteralMap(Map value) { try { @@ -168,6 +177,18 @@ public T promiseFor(String nodeId) { } } + @Override + public Map> toSdkBindingMap(T value) { + return getMembersMap().entrySet().stream() + .map( + entry -> { + String attrName = entry.getKey(); + AnnotatedMember member = entry.getValue(); + return Map.entry(attrName, (SdkBindingData) member.getValue(value)); + }) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + private static ObjectMapper createObjectMapper(SdkTypeModule bindingMap) { return new ObjectMapper() .registerModule(bindingMap) diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/RootFormatVisitor.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/RootFormatVisitor.java index 3f3bbc8d7..b4142b5fa 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/RootFormatVisitor.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/RootFormatVisitor.java @@ -18,6 +18,7 @@ import com.fasterxml.jackson.databind.JavaType; import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.introspect.AnnotatedMember; import com.fasterxml.jackson.databind.jsonFormatVisitors.JsonFormatVisitorWrapper; import com.fasterxml.jackson.databind.jsonFormatVisitors.JsonObjectFormatVisitor; import java.util.Map; @@ -44,4 +45,12 @@ public Map getVariableMap() { return builder.getVariableMap(); } + + public Map getMembersMap() { + if (builder == null) { + throw new IllegalStateException("invariant failed: membersMap not set"); + } + + return builder.getMembersMap(); + } } diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java index 2d8915c7e..8da81d00f 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java @@ -24,6 +24,7 @@ import com.fasterxml.jackson.databind.JsonMappingException; import com.fasterxml.jackson.databind.JsonSerializer; import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.introspect.AnnotatedMember; import com.fasterxml.jackson.databind.introspect.BeanPropertyDefinition; import com.fasterxml.jackson.databind.jsonFormatVisitors.JsonObjectFormatVisitor; import java.time.Duration; @@ -62,6 +63,7 @@ class VariableMapVisitor extends JsonObjectFormatVisitor.Base { } private final Map builder = new LinkedHashMap<>(); + private final Map builderMembers = new LinkedHashMap<>(); @Override public void property(BeanProperty prop) { @@ -74,6 +76,7 @@ public void property(BeanProperty prop) { prop.getMember().getMember().getDeclaringClass().getName()); Variable variable = Variable.builder().description("").literalType(literalType).build(); + builderMembers.put(prop.getName(), prop.getMember()); builder.put(prop.getName(), variable); } @@ -107,6 +110,10 @@ public Map getVariableMap() { return unmodifiableMap(new HashMap<>(builder)); } + public Map getMembersMap() { + return unmodifiableMap(new HashMap<>(builderMembers)); + } + @SuppressWarnings("AlreadyChecked") private LiteralType toLiteralType( JavaType javaType, boolean rootLevel, String propName, String declaringClassName) { diff --git a/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java b/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java index acd26b410..2dde685c0 100644 --- a/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java +++ b/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java @@ -21,9 +21,12 @@ import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasEntry; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.fail; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.fasterxml.jackson.databind.util.StdConverter; @@ -283,6 +286,73 @@ void testToLiteralMap() { ))); } + @Test + public void testToSdkBindingDataMap() { + AutoValueInput input = + createAutoValueInput( + /* i= */ 42L, + /* f= */ 42.0d, + /* s= */ "42", + /* b= */ false, + /* t= */ Instant.ofEpochSecond(42, 1), + /* d= */ Duration.ofSeconds(1, 42), + /// * blob= */ blob, + /* l= */ List.of("foo"), + /* m= */ Map.of("marco", "polo"), + /* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")), + /* lm= */ List.of(Map.of("A", "a", "B", "b"), Map.of("a", "A", "b", "B")), + /* ml= */ Map.of("frodo", List.of("baggins", "bolson")), + /* mm= */ Map.of( + "math", Map.of("pi", "3.14", "e", "2.72"), "pokemon", Map.of("ash", "pikachu"))); + + Map> sdkBindingDataMap = + JacksonSdkType.of(AutoValueInput.class).toSdkBindingMap(input); + + Map> expected = new HashMap<>(); + expected.put("i", input.i()); + expected.put("f", input.f()); + expected.put("s", input.s()); + expected.put("b", input.b()); + expected.put("t", input.t()); + expected.put("d", input.d()); + expected.put("l", input.l()); + expected.put("m", input.m()); + expected.put("ll", input.ll()); + expected.put("lm", input.lm()); + expected.put("ml", input.ml()); + expected.put("mm", input.mm()); + + assertEquals(expected, sdkBindingDataMap); + } + + @Test + public void testToSdkBindingDataMapJsonProperties() { + + JsonPropertyClassInput input = + new JsonPropertyClassInput( + SdkBindingData.ofString("test"), SdkBindingData.ofString("name")); + + Map> sdkBindingDataMap = + JacksonSdkType.of(JsonPropertyClassInput.class).toSdkBindingMap(input); + + var expected = Map.of("test", input.test, "name", input.otherTest); + + assertEquals(expected, sdkBindingDataMap); + } + + public static class JsonPropertyClassInput { + @JsonProperty SdkBindingData test; + + @JsonProperty("name") + SdkBindingData otherTest; + + @JsonCreator + public JsonPropertyClassInput(SdkBindingData test, SdkBindingData otherTest) { + this.test = test; + this.otherTest = otherTest; + } + } + @Test public void testPojoToLiteralMap() { PojoInput input = new PojoInput(); diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkType.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkType.java index 70cd21efa..72e3dd517 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkType.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkType.java @@ -29,4 +29,6 @@ public abstract class SdkType { public abstract T promiseFor(String nodeId); public abstract Map getVariableMap(); + + public abstract Map> toSdkBindingMap(T value); } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkTypes.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkTypes.java index 548033bf2..a2d628ace 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkTypes.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkTypes.java @@ -50,5 +50,10 @@ public Void promiseFor(String nodeId) { public Map getVariableMap() { return Collections.emptyMap(); } + + @Override + public Map> toSdkBindingMap(Void value) { + return Collections.emptyMap(); + } } } diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/TestPairIntegerInput.java b/flytekit-java/src/test/java/org/flyte/flytekit/TestPairIntegerInput.java index 930729193..e71850a8e 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/TestPairIntegerInput.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/TestPairIntegerInput.java @@ -33,32 +33,40 @@ public static TestPairIntegerInput create(SdkBindingData a, SdkBindingData public static class SdkType extends org.flyte.flytekit.SdkType { + private static final String A = "a"; + private static final String B = "b"; + @Override public Map toLiteralMap(TestPairIntegerInput value) { return Map.of( - "a", Literals.ofInteger(value.a().get()), - "b", Literals.ofInteger(value.b().get())); + A, Literals.ofInteger(value.a().get()), + B, Literals.ofInteger(value.b().get())); } @Override public TestPairIntegerInput fromLiteralMap(Map value) { return create( - SdkBindingData.ofInteger(value.get("a").scalar().primitive().integerValue()), - SdkBindingData.ofInteger(value.get("b").scalar().primitive().integerValue())); + SdkBindingData.ofInteger(value.get(A).scalar().primitive().integerValue()), + SdkBindingData.ofInteger(value.get(B).scalar().primitive().integerValue())); } @Override public TestPairIntegerInput promiseFor(String nodeId) { return create( - SdkBindingData.ofOutputReference(nodeId, "a", LiteralTypes.INTEGER), - SdkBindingData.ofOutputReference(nodeId, "b", LiteralTypes.INTEGER)); + SdkBindingData.ofOutputReference(nodeId, A, LiteralTypes.INTEGER), + SdkBindingData.ofOutputReference(nodeId, B, LiteralTypes.INTEGER)); } @Override public Map getVariableMap() { return Map.of( - "a", Variable.builder().literalType(LiteralTypes.INTEGER).build(), - "b", Variable.builder().literalType(LiteralTypes.INTEGER).build()); + A, Variable.builder().literalType(LiteralTypes.INTEGER).build(), + B, Variable.builder().literalType(LiteralTypes.INTEGER).build()); + } + + @Override + public Map> toSdkBindingMap(TestPairIntegerInput value) { + return Map.of(A, value.a(), B, value.b()); } } } diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryBooleanOutput.java b/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryBooleanOutput.java index 1b48d9e3e..0fc661ea7 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryBooleanOutput.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryBooleanOutput.java @@ -54,5 +54,10 @@ public TestUnaryBooleanOutput promiseFor(String nodeId) { public Map getVariableMap() { return Map.of(VAR, Variable.builder().literalType(LITERAL_TYPE).build()); } + + @Override + public Map> toSdkBindingMap(TestUnaryBooleanOutput value) { + return Map.of(VAR, value.o()); + } } } diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryIntegerOutput.java b/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryIntegerOutput.java index 4f040b1be..0ae4b3028 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryIntegerOutput.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/TestUnaryIntegerOutput.java @@ -54,5 +54,10 @@ public TestUnaryIntegerOutput promiseFor(String nodeId) { public Map getVariableMap() { return Map.of(VAR, Variable.builder().literalType(LITERAL_TYPE).build()); } + + @Override + public Map> toSdkBindingMap(TestUnaryIntegerOutput value) { + return Map.of(VAR, value.o()); + } } } diff --git a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/TestUnaryIntegerOutput.java b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/TestUnaryIntegerOutput.java index 3d0c36067..337ff472e 100644 --- a/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/TestUnaryIntegerOutput.java +++ b/flytekit-local-engine/src/test/java/org/flyte/localengine/examples/TestUnaryIntegerOutput.java @@ -59,5 +59,10 @@ public TestUnaryIntegerOutput promiseFor(String nodeId) { public Map getVariableMap() { return Map.of(VAR, Variable.builder().literalType(LITERAL_TYPE).build()); } + + @Override + public Map> toSdkBindingMap(TestUnaryIntegerOutput value) { + return Map.of(VAR, value.o()); + } } } diff --git a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala index ae9f1082b..6c6e2549e 100644 --- a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala +++ b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala @@ -32,6 +32,7 @@ import org.flyte.flytekitscala.SdkBindingData._ import org.junit.Assert.assertEquals import org.junit.Test import org.flyte.examples.AllInputsTask.AutoAllInputsInput +import org.junit.jupiter.api.Assertions.assertThrows class SdkScalaTypeTest { @@ -177,6 +178,49 @@ class SdkScalaTypeTest { assertEquals(expected, output) } + @Test + def testToSdkBindingMap(): Unit = { + val input = ScalarInput( + string = ofString("string"), + integer = ofInteger(1337L), + float = ofFloat(42.0), + boolean = ofBoolean(true), + datetime = ofDateTime(Instant.ofEpochMilli(123456L)), + duration = ofDuration(Duration.ofSeconds(123, 456)) + ) + + val output = SdkScalaType[ScalarInput].toSdkBindingMap(input) + + val expected = Map( + "string" -> ofString("string"), + "integer" -> ofInteger(1337L), + "float" -> ofFloat(42.0), + "boolean" -> ofBoolean(true), + "datetime" -> ofDateTime(Instant.ofEpochMilli(123456L)), + "duration" -> ofDuration(Duration.ofSeconds(123, 456)) + ).asJava + + assertEquals(expected, output) + } + + case class InputWithoutSdkBinding(notSdkBinding: String) + @Test + def testCaseClassWithoutSdkBindingData(): Unit = { + val exception = assertThrows( + classOf[IllegalStateException], + () => { + SdkScalaType[InputWithoutSdkBinding].toSdkBindingMap( + InputWithoutSdkBinding("test") + ) + } + ) + + assertEquals( + "All the fields of the case class InputWithoutSdkBinding must be SdkBindingData[_]", + exception.getMessage + ) + } + private def createCollectionVar(simpleType: SimpleType) = { Variable .builder() diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala index e7a25ef13..82f459d72 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala @@ -20,7 +20,7 @@ import java.time.{Duration, Instant} import java.{util => ju} import magnolia.{CaseClass, Magnolia, Param, SealedTrait} import org.flyte.api.v1._ -import org.flyte.flytekit.{SdkBindingData => SdkJavaBindinigData, SdkType} +import org.flyte.flytekit.{SdkType, SdkBindingData => SdkJavaBindinigData} import scala.annotation.implicitNotFound import scala.collection.JavaConverters._ @@ -130,6 +130,31 @@ object SdkScalaType { ) }) } + + override def toSdkBindingMap( + value: T + ): ju.Map[String, SdkJavaBindinigData[_]] = { + value match { + case product: Product => + value.getClass.getDeclaredFields + .map(_.getName) + .zip(product.productIterator.toSeq) + .toMap + .mapValues { + case value: SdkJavaBindinigData[_] => value + case _ => + throw new IllegalStateException( + s"All the fields of the case class ${value.getClass.getSimpleName} must be SdkBindingData[_]" + ) + } + .toMap + .asJava + case _ => + throw new IllegalStateException( + s"The class ${value.getClass.getSimpleName} must be a case class" + ) + } + } } } @@ -300,12 +325,18 @@ object SdkScalaType { } private object SdkUnitType extends SdkScalaProductType[Unit] { - def getVariableMap: ju.Map[String, Variable] = ju.Collections.emptyMap() + def getVariableMap: ju.Map[String, Variable] = + Map.empty[String, Variable].asJava def toLiteralMap(value: Unit): ju.Map[String, Literal] = - ju.Collections.emptyMap() + Map.empty[String, Literal].asJava def fromLiteralMap(literal: ju.Map[String, Literal]): Unit = () def promiseFor(nodeId: String): Unit = () + + override def toSdkBindingMap( + value: Unit + ): ju.Map[String, SdkJavaBindinigData[_]] = + Map.empty[String, SdkJavaBindinigData[_]].asJava }