Skip to content

Commit

Permalink
Add toSdkBindingDataMap() to SdkType (#177)
Browse files Browse the repository at this point in the history
* Add toSdkBindingDataMap() to SdkType class

Signed-off-by: Andres Gomez Ferrer <[email protected]>

* Spotless:apply, fix compile and test

Signed-off-by: Andres Gomez Ferrer <[email protected]>

* Review fixes

Signed-off-by: Andres Gomez Ferrer <[email protected]>

Signed-off-by: Andres Gomez Ferrer <[email protected]>
Co-authored-by: Andres Gomez Ferrer <[email protected]>
  • Loading branch information
andresgomezfrr and andresgomezfrr authored Jan 24, 2023
1 parent f38be80 commit 57ad454
Show file tree
Hide file tree
Showing 12 changed files with 225 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.deser.DefaultDeserializationContext;
import com.fasterxml.jackson.databind.introspect.AnnotatedMember;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import com.fasterxml.jackson.module.paramnames.ParameterNamesModule;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.flyte.api.v1.Literal;
import org.flyte.api.v1.LiteralType;
import org.flyte.api.v1.Variable;
Expand All @@ -46,10 +48,13 @@ public class JacksonSdkType<T> extends SdkType<T> {

private final Class<T> clazz;
private final Map<String, Variable> variableMap;
private final Map<String, AnnotatedMember> membersMap;

private JacksonSdkType(Class<T> clazz, Map<String, Variable> variableMap) {
private JacksonSdkType(
Class<T> clazz, Map<String, Variable> variableMap, Map<String, AnnotatedMember> membersMap) {
this.clazz = Objects.requireNonNull(clazz);
this.variableMap = Objects.requireNonNull(variableMap);
this.membersMap = Objects.requireNonNull(membersMap);
}

public static <T> JacksonSdkType<T> of(Class<T> clazz) {
Expand All @@ -73,7 +78,7 @@ public static <T> JacksonSdkType<T> of(Class<T> clazz) {
serializer.acceptJsonFormatVisitor(
visitor, OBJECT_MAPPER.getTypeFactory().constructType(clazz));

return new JacksonSdkType<>(clazz, visitor.getVariableMap());
return new JacksonSdkType<>(clazz, visitor.getVariableMap(), visitor.getMembersMap());
} catch (JsonMappingException e) {
throw new IllegalArgumentException(
String.format("Failed to find serializer for [%s]", clazz.getName()), e);
Expand Down Expand Up @@ -120,6 +125,10 @@ public Map<String, Variable> getVariableMap() {
return variableMap;
}

private Map<String, AnnotatedMember> getMembersMap() {
return membersMap;
}

@Override
public T fromLiteralMap(Map<String, Literal> value) {
try {
Expand Down Expand Up @@ -168,6 +177,18 @@ public T promiseFor(String nodeId) {
}
}

@Override
public Map<String, SdkBindingData<?>> toSdkBindingMap(T value) {
return getMembersMap().entrySet().stream()
.map(
entry -> {
String attrName = entry.getKey();
AnnotatedMember member = entry.getValue();
return Map.entry(attrName, (SdkBindingData<?>) member.getValue(value));
})
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}

private static ObjectMapper createObjectMapper(SdkTypeModule bindingMap) {
return new ObjectMapper()
.registerModule(bindingMap)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.introspect.AnnotatedMember;
import com.fasterxml.jackson.databind.jsonFormatVisitors.JsonFormatVisitorWrapper;
import com.fasterxml.jackson.databind.jsonFormatVisitors.JsonObjectFormatVisitor;
import java.util.Map;
Expand All @@ -44,4 +45,12 @@ public Map<String, Variable> getVariableMap() {

return builder.getVariableMap();
}

public Map<String, AnnotatedMember> getMembersMap() {
if (builder == null) {
throw new IllegalStateException("invariant failed: membersMap not set");
}

return builder.getMembersMap();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.introspect.AnnotatedMember;
import com.fasterxml.jackson.databind.introspect.BeanPropertyDefinition;
import com.fasterxml.jackson.databind.jsonFormatVisitors.JsonObjectFormatVisitor;
import java.time.Duration;
Expand Down Expand Up @@ -62,6 +63,7 @@ class VariableMapVisitor extends JsonObjectFormatVisitor.Base {
}

private final Map<String, Variable> builder = new LinkedHashMap<>();
private final Map<String, AnnotatedMember> builderMembers = new LinkedHashMap<>();

@Override
public void property(BeanProperty prop) {
Expand All @@ -74,6 +76,7 @@ public void property(BeanProperty prop) {
prop.getMember().getMember().getDeclaringClass().getName());
Variable variable = Variable.builder().description("").literalType(literalType).build();

builderMembers.put(prop.getName(), prop.getMember());
builder.put(prop.getName(), variable);
}

Expand Down Expand Up @@ -107,6 +110,10 @@ public Map<String, Variable> getVariableMap() {
return unmodifiableMap(new HashMap<>(builder));
}

public Map<String, AnnotatedMember> getMembersMap() {
return unmodifiableMap(new HashMap<>(builderMembers));
}

@SuppressWarnings("AlreadyChecked")
private LiteralType toLiteralType(
JavaType javaType, boolean rootLevel, String propName, String declaringClassName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasEntry;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.fail;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.fasterxml.jackson.databind.util.StdConverter;
Expand Down Expand Up @@ -283,6 +286,73 @@ void testToLiteralMap() {
)));
}

@Test
public void testToSdkBindingDataMap() {
AutoValueInput input =
createAutoValueInput(
/* i= */ 42L,
/* f= */ 42.0d,
/* s= */ "42",
/* b= */ false,
/* t= */ Instant.ofEpochSecond(42, 1),
/* d= */ Duration.ofSeconds(1, 42),
/// * blob= */ blob,
/* l= */ List.of("foo"),
/* m= */ Map.of("marco", "polo"),
/* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")),
/* lm= */ List.of(Map.of("A", "a", "B", "b"), Map.of("a", "A", "b", "B")),
/* ml= */ Map.of("frodo", List.of("baggins", "bolson")),
/* mm= */ Map.of(
"math", Map.of("pi", "3.14", "e", "2.72"), "pokemon", Map.of("ash", "pikachu")));

Map<String, SdkBindingData<?>> sdkBindingDataMap =
JacksonSdkType.of(AutoValueInput.class).toSdkBindingMap(input);

Map<String, SdkBindingData<?>> expected = new HashMap<>();
expected.put("i", input.i());
expected.put("f", input.f());
expected.put("s", input.s());
expected.put("b", input.b());
expected.put("t", input.t());
expected.put("d", input.d());
expected.put("l", input.l());
expected.put("m", input.m());
expected.put("ll", input.ll());
expected.put("lm", input.lm());
expected.put("ml", input.ml());
expected.put("mm", input.mm());

assertEquals(expected, sdkBindingDataMap);
}

@Test
public void testToSdkBindingDataMapJsonProperties() {

JsonPropertyClassInput input =
new JsonPropertyClassInput(
SdkBindingData.ofString("test"), SdkBindingData.ofString("name"));

Map<String, SdkBindingData<?>> sdkBindingDataMap =
JacksonSdkType.of(JsonPropertyClassInput.class).toSdkBindingMap(input);

var expected = Map.of("test", input.test, "name", input.otherTest);

assertEquals(expected, sdkBindingDataMap);
}

public static class JsonPropertyClassInput {
@JsonProperty SdkBindingData<String> test;

@JsonProperty("name")
SdkBindingData<String> otherTest;

@JsonCreator
public JsonPropertyClassInput(SdkBindingData<String> test, SdkBindingData<String> otherTest) {
this.test = test;
this.otherTest = otherTest;
}
}

@Test
public void testPojoToLiteralMap() {
PojoInput input = new PojoInput();
Expand Down
2 changes: 2 additions & 0 deletions flytekit-java/src/main/java/org/flyte/flytekit/SdkType.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,6 @@ public abstract class SdkType<T> {
public abstract T promiseFor(String nodeId);

public abstract Map<String, Variable> getVariableMap();

public abstract Map<String, SdkBindingData<?>> toSdkBindingMap(T value);
}
5 changes: 5 additions & 0 deletions flytekit-java/src/main/java/org/flyte/flytekit/SdkTypes.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,10 @@ public Void promiseFor(String nodeId) {
public Map<String, Variable> getVariableMap() {
return Collections.emptyMap();
}

@Override
public Map<String, SdkBindingData<?>> toSdkBindingMap(Void value) {
return Collections.emptyMap();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,32 +33,40 @@ public static TestPairIntegerInput create(SdkBindingData<Long> a, SdkBindingData

public static class SdkType extends org.flyte.flytekit.SdkType<TestPairIntegerInput> {

private static final String A = "a";
private static final String B = "b";

@Override
public Map<String, Literal> toLiteralMap(TestPairIntegerInput value) {
return Map.of(
"a", Literals.ofInteger(value.a().get()),
"b", Literals.ofInteger(value.b().get()));
A, Literals.ofInteger(value.a().get()),
B, Literals.ofInteger(value.b().get()));
}

@Override
public TestPairIntegerInput fromLiteralMap(Map<String, Literal> value) {
return create(
SdkBindingData.ofInteger(value.get("a").scalar().primitive().integerValue()),
SdkBindingData.ofInteger(value.get("b").scalar().primitive().integerValue()));
SdkBindingData.ofInteger(value.get(A).scalar().primitive().integerValue()),
SdkBindingData.ofInteger(value.get(B).scalar().primitive().integerValue()));
}

@Override
public TestPairIntegerInput promiseFor(String nodeId) {
return create(
SdkBindingData.ofOutputReference(nodeId, "a", LiteralTypes.INTEGER),
SdkBindingData.ofOutputReference(nodeId, "b", LiteralTypes.INTEGER));
SdkBindingData.ofOutputReference(nodeId, A, LiteralTypes.INTEGER),
SdkBindingData.ofOutputReference(nodeId, B, LiteralTypes.INTEGER));
}

@Override
public Map<String, Variable> getVariableMap() {
return Map.of(
"a", Variable.builder().literalType(LiteralTypes.INTEGER).build(),
"b", Variable.builder().literalType(LiteralTypes.INTEGER).build());
A, Variable.builder().literalType(LiteralTypes.INTEGER).build(),
B, Variable.builder().literalType(LiteralTypes.INTEGER).build());
}

@Override
public Map<String, SdkBindingData<?>> toSdkBindingMap(TestPairIntegerInput value) {
return Map.of(A, value.a(), B, value.b());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,10 @@ public TestUnaryBooleanOutput promiseFor(String nodeId) {
public Map<String, Variable> getVariableMap() {
return Map.of(VAR, Variable.builder().literalType(LITERAL_TYPE).build());
}

@Override
public Map<String, SdkBindingData<?>> toSdkBindingMap(TestUnaryBooleanOutput value) {
return Map.of(VAR, value.o());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,10 @@ public TestUnaryIntegerOutput promiseFor(String nodeId) {
public Map<String, Variable> getVariableMap() {
return Map.of(VAR, Variable.builder().literalType(LITERAL_TYPE).build());
}

@Override
public Map<String, SdkBindingData<?>> toSdkBindingMap(TestUnaryIntegerOutput value) {
return Map.of(VAR, value.o());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,10 @@ public TestUnaryIntegerOutput promiseFor(String nodeId) {
public Map<String, Variable> getVariableMap() {
return Map.of(VAR, Variable.builder().literalType(LITERAL_TYPE).build());
}

@Override
public Map<String, SdkBindingData<?>> toSdkBindingMap(TestUnaryIntegerOutput value) {
return Map.of(VAR, value.o());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import org.flyte.flytekitscala.SdkBindingData._
import org.junit.Assert.assertEquals
import org.junit.Test
import org.flyte.examples.AllInputsTask.AutoAllInputsInput
import org.junit.jupiter.api.Assertions.assertThrows

class SdkScalaTypeTest {

Expand Down Expand Up @@ -177,6 +178,49 @@ class SdkScalaTypeTest {
assertEquals(expected, output)
}

@Test
def testToSdkBindingMap(): Unit = {
val input = ScalarInput(
string = ofString("string"),
integer = ofInteger(1337L),
float = ofFloat(42.0),
boolean = ofBoolean(true),
datetime = ofDateTime(Instant.ofEpochMilli(123456L)),
duration = ofDuration(Duration.ofSeconds(123, 456))
)

val output = SdkScalaType[ScalarInput].toSdkBindingMap(input)

val expected = Map(
"string" -> ofString("string"),
"integer" -> ofInteger(1337L),
"float" -> ofFloat(42.0),
"boolean" -> ofBoolean(true),
"datetime" -> ofDateTime(Instant.ofEpochMilli(123456L)),
"duration" -> ofDuration(Duration.ofSeconds(123, 456))
).asJava

assertEquals(expected, output)
}

case class InputWithoutSdkBinding(notSdkBinding: String)
@Test
def testCaseClassWithoutSdkBindingData(): Unit = {
val exception = assertThrows(
classOf[IllegalStateException],
() => {
SdkScalaType[InputWithoutSdkBinding].toSdkBindingMap(
InputWithoutSdkBinding("test")
)
}
)

assertEquals(
"All the fields of the case class InputWithoutSdkBinding must be SdkBindingData[_]",
exception.getMessage
)
}

private def createCollectionVar(simpleType: SimpleType) = {
Variable
.builder()
Expand Down
Loading

0 comments on commit 57ad454

Please sign in to comment.