diff --git a/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkLaunchPlanRegistry b/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkLaunchPlanRegistry
index 7cc5459dd..acd3fc633 100644
--- a/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkLaunchPlanRegistry
+++ b/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkLaunchPlanRegistry
@@ -1 +1 @@
-org.flyte.examples.flytekitscala.FibonacciLaunchPlan
+org.flyte.examples.flytekitscala.LaunchPlanRegistry
diff --git a/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkRunnableTask b/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkRunnableTask
index 0fc19c133..508e6cb51 100644
--- a/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkRunnableTask
+++ b/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkRunnableTask
@@ -3,3 +3,4 @@ org.flyte.examples.flytekitscala.SumTask
org.flyte.examples.flytekitscala.GreetTask
org.flyte.examples.flytekitscala.AddQuestionTask
org.flyte.examples.flytekitscala.NoInputsTask
+org.flyte.examples.flytekitscala.NestedIOTask
diff --git a/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkWorkflow b/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkWorkflow
index 9b9ca9038..844fdc040 100644
--- a/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkWorkflow
+++ b/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkWorkflow
@@ -1,2 +1,3 @@
org.flyte.examples.flytekitscala.FibonacciWorkflow
org.flyte.examples.flytekitscala.WelcomeWorkflow
+org.flyte.examples.flytekitscala.NestedIOWorkflow
diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/FibonacciLaunchPlan.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/LaunchPlanRegistry.scala
similarity index 67%
rename from flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/FibonacciLaunchPlan.scala
rename to flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/LaunchPlanRegistry.scala
index 4fd16f3b0..41bf0e7cd 100644
--- a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/FibonacciLaunchPlan.scala
+++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/LaunchPlanRegistry.scala
@@ -20,8 +20,9 @@ import org.flyte.flytekit.{SdkLaunchPlan, SimpleSdkLaunchPlanRegistry}
import org.flyte.flytekitscala.SdkScalaType
case class FibonacciLaunchPlanInput(fib0: Long, fib1: Long)
+case class NestedIOLaunchPlanInput(name: String, generic: Nested)
-class FibonacciLaunchPlan extends SimpleSdkLaunchPlanRegistry {
+class LaunchPlanRegistry extends SimpleSdkLaunchPlanRegistry {
// Register default launch plans for all workflows
registerDefaultLaunchPlans()
@@ -53,4 +54,33 @@ class FibonacciLaunchPlan extends SimpleSdkLaunchPlanRegistry {
.withDefaultInput("fib0", 0L)
.withDefaultInput("fib1", 1L)
)
+
+ registerLaunchPlan(
+ SdkLaunchPlan
+ .of(new NestedIOWorkflow)
+ .withName("NestedIOWorkflowLaunchPlan")
+ .withDefaultInput(
+ SdkScalaType[NestedIOLaunchPlanInput],
+ NestedIOLaunchPlanInput(
+ "yo",
+ Nested(
+ boolean = true,
+ 1.toByte,
+ 2.toShort,
+ 3,
+ 4L,
+ 5.toFloat,
+ 6.toDouble,
+ "hello",
+ List("1", "2"),
+ Map("1" -> "1", "2" -> "2"),
+ Some(false),
+ None,
+ Some(List("3", "4")),
+ Some(Map("3" -> "3", "4" -> "4")),
+ NestedNested(7.toDouble, NestedNestedNested("world"))
+ )
+ )
+ )
+ )
}
diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOTask.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOTask.scala
new file mode 100644
index 000000000..3a488adbc
--- /dev/null
+++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOTask.scala
@@ -0,0 +1,83 @@
+/*
+ * Copyright 2023 Flyte Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.flyte.examples.flytekitscala
+
+import org.flyte.flytekit.{SdkBindingData, SdkRunnableTask, SdkTransform}
+import org.flyte.flytekitscala.{
+ Description,
+ SdkBindingDataFactory,
+ SdkScalaType
+}
+
+case class NestedNestedNested(string: String)
+case class NestedNested(double: Double, nested: NestedNestedNested)
+case class Nested(
+ boolean: Boolean,
+ byte: Byte,
+ short: Short,
+ int: Int,
+ long: Long,
+ float: Float,
+ double: Double,
+ string: String,
+ list: List[String],
+ map: Map[String, String],
+ optBoolean: Option[Boolean],
+ optByte: Option[Byte],
+ optList: Option[List[String]],
+ optMap: Option[Map[String, String]],
+ nested: NestedNested
+)
+case class NestedIOTaskInput(
+ @Description("the name of the person to be greeted")
+ name: SdkBindingData[String],
+ @Description("a nested input")
+ generic: SdkBindingData[Nested]
+)
+case class NestedIOTaskOutput(
+ @Description("the name of the person to be greeted")
+ name: SdkBindingData[String],
+ @Description("a nested input")
+ generic: SdkBindingData[Nested]
+)
+
+/** Example Flyte task that takes a name as the input and outputs a simple
+ * greeting message.
+ */
+class NestedIOTask
+ extends SdkRunnableTask[
+ NestedIOTaskInput,
+ NestedIOTaskOutput
+ ](
+ SdkScalaType[NestedIOTaskInput],
+ SdkScalaType[NestedIOTaskOutput]
+ ) {
+
+ /** Defines task behavior. This task takes a name as the input, wraps it in a
+ * welcome message, and outputs the message.
+ *
+ * @param input
+ * the name of the person to be greeted
+ * @return
+ * the welcome message
+ */
+ override def run(input: NestedIOTaskInput): NestedIOTaskOutput =
+ NestedIOTaskOutput(
+ input.name,
+ input.generic
+ )
+}
diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOWorkflow.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOWorkflow.scala
new file mode 100644
index 000000000..dfe996650
--- /dev/null
+++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOWorkflow.scala
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2020-2023 Flyte Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.flyte.examples.flytekitscala
+
+import org.flyte.flytekitscala.{
+ SdkScalaType,
+ SdkScalaWorkflow,
+ SdkScalaWorkflowBuilder
+}
+
+class NestedIOWorkflow
+ extends SdkScalaWorkflow[NestedIOTaskInput, Unit](
+ SdkScalaType[NestedIOTaskInput],
+ SdkScalaType.unit
+ ) {
+
+ override def expand(
+ builder: SdkScalaWorkflowBuilder,
+ input: NestedIOTaskInput
+ ): Unit = {
+ builder.apply(new NestedIOTask(), input)
+ }
+}
diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala
index 1fc5060e5..6d53bceae 100644
--- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala
+++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala
@@ -24,7 +24,18 @@ import org.flyte.flytekit.{
import java.time.{Duration, Instant}
import scala.collection.JavaConverters._
-import scala.reflect.runtime.universe.{TypeTag, typeOf}
+import scala.reflect.api.{Mirror, TypeCreator, Universe}
+import scala.reflect.runtime.universe
+import scala.reflect.{ClassTag, classTag}
+import scala.reflect.runtime.universe.{
+ NoPrefix,
+ Symbol,
+ Type,
+ TypeTag,
+ runtimeMirror,
+ termNames,
+ typeOf
+}
object SdkLiteralTypes {
@@ -202,6 +213,185 @@ object SdkLiteralTypes {
*/
def durations(): SdkLiteralType[Duration] = SdkJavaLiteralTypes.durations()
+ /** Returns a [[SdkLiteralType]] for products.
+ * @return
+ * the [[SdkLiteralType]]
+ */
+ def generics[T <: Product: TypeTag: ClassTag](): SdkLiteralType[T] = {
+ ScalaLiteralType[T](
+ LiteralType.ofSimpleType(SimpleType.STRUCT),
+ (value: T) => Literal.ofScalar(Scalar.ofGeneric(toStruct(value))),
+ (x: Literal) => toProduct(x.scalar().generic()),
+ (v: T) => BindingData.ofScalar(Scalar.ofGeneric(toStruct(v))),
+ "generics"
+ )
+ }
+
+ private def toStruct(product: Product): Struct = {
+ def productToMap(product: Product): Map[String, Any] = {
+ // by spec getDeclaredFields is not ordered but in practice it works fine
+ // it's a lot better since Scala 2.13 because productElementNames was introduced
+ // (product.productElementNames zip product.productIterator).toMap
+ product.getClass.getDeclaredFields
+ .map(_.getName)
+ .zip(product.productIterator.toList)
+ .toMap
+ }
+
+ def mapToStruct(map: Map[String, Any]): Struct = {
+ val fields = map.map({ case (key, value) =>
+ (key, anyToStructValue(value))
+ })
+ Struct.of(fields.asJava)
+ }
+
+ def anyToStructValue(value: Any): Struct.Value = {
+ def anyToStructureValue0(value: Any): Struct.Value = {
+ value match {
+ case s: String => Struct.Value.ofStringValue(s)
+ case n @ (_: Byte | _: Short | _: Int | _: Long | _: Float |
+ _: Double) =>
+ Struct.Value.ofNumberValue(n.toString.toDouble)
+ case b: Boolean => Struct.Value.ofBoolValue(b)
+ case l: List[Any] =>
+ Struct.Value.ofListValue(l.map(anyToStructValue).asJava)
+ case m: Map[_, _] =>
+ Struct.Value.ofStructValue(
+ mapToStruct(m.asInstanceOf[Map[String, Any]])
+ )
+ case null => Struct.Value.ofNullValue()
+ case p: Product =>
+ Struct.Value.ofStructValue(mapToStruct(productToMap(p)))
+ case _ =>
+ throw new IllegalArgumentException(
+ s"Unsupported type: ${value.getClass}"
+ )
+ }
+ }
+
+ value match {
+ case Some(v) => anyToStructureValue0(v)
+ case None => Struct.Value.ofNullValue()
+ case _ => anyToStructureValue0(value)
+ }
+ }
+
+ mapToStruct(productToMap(product))
+ }
+
+ private def toProduct[T <: Product: TypeTag: ClassTag](
+ struct: Struct
+ ): T = {
+ def structToMap(struct: Struct): Map[String, Any] = {
+ struct
+ .fields()
+ .asScala
+ .map({ case (key, value) =>
+ (key, structValueToAny(value))
+ })
+ .toMap
+ }
+
+ def mapToProduct[S <: Product: TypeTag: ClassTag](
+ map: Map[String, Any]
+ ): S = {
+ val mirror = runtimeMirror(classTag[S].runtimeClass.getClassLoader)
+
+ def valueToParamValue(value: Any, param: Symbol): Any = {
+ def valueToParamValue0(value: Any, param: Symbol): Any = {
+ if (param.typeSignature =:= typeOf[Byte]) {
+ value.asInstanceOf[Double].toByte
+ } else if (param.typeSignature =:= typeOf[Short]) {
+ value.asInstanceOf[Double].toShort
+ } else if (param.typeSignature =:= typeOf[Int]) {
+ value.asInstanceOf[Double].toInt
+ } else if (param.typeSignature =:= typeOf[Long]) {
+ value.asInstanceOf[Double].toLong
+ } else if (param.typeSignature =:= typeOf[Float]) {
+ value.asInstanceOf[Double].toFloat
+ } else if (param.typeSignature <:< typeOf[Product]) {
+ val typeTag = createTypeTag(param.typeSignature)
+ val classTag = ClassTag(
+ typeTag.mirror.runtimeClass(param.typeSignature)
+ )
+ mapToProduct(value.asInstanceOf[Map[String, Any]])(
+ typeTag,
+ classTag
+ )
+ } else {
+ value
+ }
+ }
+
+ if (param.typeSignature <:< typeOf[Option[Any]]) {
+ Some(
+ valueToParamValue0(
+ value,
+ param.typeSignature.dealias.typeArgs.head.typeSymbol
+ )
+ )
+ } else {
+ valueToParamValue0(value, param)
+ }
+ }
+
+ def createTypeTag[U <: Product](tpe: Type): TypeTag[U] = {
+ val typSym = mirror.staticClass(tpe.typeSymbol.fullName)
+ // note: this uses internal API, otherwise we will need to depend on scala-compiler at runtime
+ val typeRef =
+ universe.internal.typeRef(NoPrefix, typSym, List.empty)
+
+ TypeTag(
+ mirror,
+ new TypeCreator {
+ override def apply[V <: Universe with Singleton](
+ m: Mirror[V]
+ ): V#Type = {
+ assert(
+ m == mirror,
+ s"TypeTag[$typeRef] defined in $mirror cannot be migrated to $m."
+ )
+ typeRef.asInstanceOf[V#Type]
+ }
+ }
+ )
+ }
+
+ val clazz = typeOf[S].typeSymbol.asClass
+ val classMirror = mirror.reflectClass(clazz)
+ val constructor = typeOf[S].decl(termNames.CONSTRUCTOR).asMethod
+ val constructorMirror = classMirror.reflectConstructor(constructor)
+
+ val constructorArgs =
+ constructor.paramLists.flatten.map((param: Symbol) => {
+ val paramName = param.name.toString
+ val value = map.getOrElse(
+ paramName,
+ throw new IllegalArgumentException(
+ s"Map is missing required parameter named $paramName"
+ )
+ )
+ valueToParamValue(value, param)
+ })
+
+ constructorMirror(constructorArgs: _*).asInstanceOf[S]
+ }
+
+ def structValueToAny(value: Struct.Value): Any = {
+ value.kind() match {
+ case Struct.Value.Kind.STRING_VALUE => value.stringValue()
+ case Struct.Value.Kind.NUMBER_VALUE => value.numberValue()
+ case Struct.Value.Kind.BOOL_VALUE => value.boolValue()
+ case Struct.Value.Kind.LIST_VALUE =>
+ value.listValue().asScala.map(structValueToAny).toList
+ case Struct.Value.Kind.STRUCT_VALUE => structToMap(value.structValue())
+ case Struct.Value.Kind.NULL_VALUE => None
+ }
+ }
+
+ mapToProduct[T](structToMap(struct))
+ }
+
/** Returns a [[SdkLiteralType]] for blob.
*
* @return
diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala
index 25f03122a..b64e0ae71 100644
--- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala
+++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala
@@ -30,6 +30,8 @@ import org.flyte.flytekit.{
import scala.annotation.implicitNotFound
import scala.collection.JavaConverters._
+import scala.reflect.{ClassTag, classTag}
+import scala.reflect.runtime.universe.{TypeTag, typeOf}
/** Type class to map between Flyte `Variable` and `Literal` and Scala case
* classes.
@@ -231,6 +233,12 @@ object SdkScalaType {
implicit def durationLiteralType: SdkScalaLiteralType[Duration] =
DelegateLiteralType(SdkLiteralTypes.durations())
+ // fixme: using Product is just an approximation for case class because Product
+ // is also super class of, for example, Option and Tuple
+ implicit def productLiteralType[T <: Product: TypeTag: ClassTag]
+ : SdkScalaLiteralType[T] =
+ DelegateLiteralType(SdkLiteralTypes.generics())
+
// fixme: create blob type from annotation, or rethink how we could offer the offloaded data feature
// https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.BlobType.html#flytekit-blobtype
implicit def blobLiteralType: SdkScalaLiteralType[Blob] =
diff --git a/integration-tests/src/test/java/org/flyte/AdditionalIT.java b/integration-tests/src/test/java/org/flyte/AdditionalIT.java
index 00e50c27a..5355ddb71 100644
--- a/integration-tests/src/test/java/org/flyte/AdditionalIT.java
+++ b/integration-tests/src/test/java/org/flyte/AdditionalIT.java
@@ -16,25 +16,21 @@
*/
package org.flyte;
-import static org.flyte.FlyteContainer.CLIENT;
+import static org.flyte.examples.FlyteEnvironment.STAGING_DOMAIN;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import flyteidl.core.Literals;
+import flyteidl.core.Literals.LiteralMap;
import org.flyte.utils.Literal;
-import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.testcontainers.shaded.com.google.common.collect.ImmutableMap;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
-public class AdditionalIT {
- @BeforeAll
- public static void beforeAll() {
- CLIENT.registerWorkflows("integration-tests/target/lib");
- }
-
+class AdditionalIT extends Fixtures {
@ParameterizedTest
@CsvSource({
"0,0,0,0,a == b && c == d",
@@ -47,7 +43,7 @@ public static void beforeAll() {
"0,1,0,1,a < b && c < d",
"1,0,0,1,a > b && c < d",
})
- public void testBranchNodeWorkflow(long a, long b, long c, long d, String expected) {
+ void testBranchNodeWorkflow(long a, long b, long c, long d, String expected) {
Literals.LiteralMap output =
CLIENT.createExecution(
"org.flyte.integrationtests.BranchNodeWorkflow",
@@ -66,7 +62,7 @@ public void testBranchNodeWorkflow(long a, long b, long c, long d, String expect
"table-exists,true",
"non-existent,false",
})
- public void testStructs(String name, boolean expected) {
+ void testStructs(String name, boolean expected) {
Literals.LiteralMap output =
CLIENT.createExecution(
"org.flyte.integrationtests.structs.MockPipelineWorkflow",
@@ -74,4 +70,12 @@ public void testStructs(String name, boolean expected) {
assertThat(output, equalTo(Literal.ofBooleanMap(ImmutableMap.of("exists", expected))));
}
+
+ @Test
+ void testStructsScala() {
+ Literals.LiteralMap output =
+ CLIENT.createExecution("NestedIOWorkflowLaunchPlan", STAGING_DOMAIN);
+
+ assertThat(output, equalTo(LiteralMap.getDefaultInstance()));
+ }
}
diff --git a/integration-tests/src/test/java/org/flyte/FlyteContainer.java b/integration-tests/src/test/java/org/flyte/Fixtures.java
similarity index 60%
rename from integration-tests/src/test/java/org/flyte/FlyteContainer.java
rename to integration-tests/src/test/java/org/flyte/Fixtures.java
index a95fe7d76..d80bec77d 100644
--- a/integration-tests/src/test/java/org/flyte/FlyteContainer.java
+++ b/integration-tests/src/test/java/org/flyte/Fixtures.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2020-2022 Flyte Authors.
+ * Copyright 2023 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,8 +16,16 @@
*/
package org.flyte;
+import static org.flyte.examples.FlyteEnvironment.STAGING_DOMAIN;
+
import org.flyte.utils.FlyteSandboxClient;
-public class FlyteContainer {
- static final FlyteSandboxClient CLIENT = FlyteSandboxClient.create();
+class Fixtures {
+ protected static final FlyteSandboxClient CLIENT = FlyteSandboxClient.create();
+
+ static {
+ CLIENT.registerWorkflows("integration-tests/target/lib");
+ CLIENT.registerWorkflows("flytekit-examples/target/lib");
+ CLIENT.registerWorkflows("flytekit-examples-scala/target/lib", STAGING_DOMAIN);
+ }
}
diff --git a/integration-tests/src/test/java/org/flyte/JavaExamplesIT.java b/integration-tests/src/test/java/org/flyte/JavaExamplesIT.java
index cbce85dee..c4ec8151e 100644
--- a/integration-tests/src/test/java/org/flyte/JavaExamplesIT.java
+++ b/integration-tests/src/test/java/org/flyte/JavaExamplesIT.java
@@ -16,29 +16,17 @@
*/
package org.flyte;
-import static org.flyte.FlyteContainer.CLIENT;
-import static org.flyte.examples.FlyteEnvironment.STAGING_DOMAIN;
import static org.flyte.utils.Literal.ofIntegerMap;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import flyteidl.core.Literals;
-import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.testcontainers.shaded.com.google.common.collect.ImmutableMap;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
-public class JavaExamplesIT {
- private static final String CLASSPATH_EXAMPLES = "flytekit-examples/target/lib";
- private static final String CLASSPATH_EXAMPLES_SCALA = "flytekit-examples-scala/target/lib";
-
- @BeforeAll
- public static void beforeAll() {
- CLIENT.registerWorkflows(CLASSPATH_EXAMPLES);
- CLIENT.registerWorkflows(CLASSPATH_EXAMPLES_SCALA, STAGING_DOMAIN);
- }
-
+class JavaExamplesIT extends Fixtures {
@Test
public void testSumTask() {
Literals.LiteralMap output =
@@ -53,7 +41,7 @@ public void testSumTask() {
}
@Test
- public void testFibonacciWorkflow() {
+ void testFibonacciWorkflow() {
Literals.LiteralMap output =
CLIENT.createExecution(
"org.flyte.examples.FibonacciWorkflow",
@@ -66,7 +54,7 @@ public void testFibonacciWorkflow() {
}
@Test
- public void testDynamicFibonacciWorkflow() {
+ void testDynamicFibonacciWorkflow() {
Literals.LiteralMap output =
CLIENT.createExecution(
"org.flyte.examples.DynamicFibonacciWorkflow", ofIntegerMap(ImmutableMap.of("n", 2L)));
diff --git a/integration-tests/src/test/java/org/flyte/SerializeJavaIT.java b/integration-tests/src/test/java/org/flyte/SerializeJavaIT.java
index 9d9caac47..888416893 100644
--- a/integration-tests/src/test/java/org/flyte/SerializeJavaIT.java
+++ b/integration-tests/src/test/java/org/flyte/SerializeJavaIT.java
@@ -16,7 +16,6 @@
*/
package org.flyte;
-import static org.flyte.FlyteContainer.CLIENT;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
@@ -29,13 +28,13 @@
import org.junit.jupiter.api.io.TempDir;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
-public class SerializeJavaIT {
+class SerializeJavaIT extends Fixtures {
private static final String CLASSPATH = "flytekit-examples/target/lib";
@TempDir Path managed;
@Test
- public void testSerializeWorkflows() {
+ void testSerializeWorkflows() {
try {
File current = new File("target/protos");
File tempDir = managed.resolve(current.getAbsolutePath()).toFile();
diff --git a/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxClient.java b/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxClient.java
index f614ef596..a2c04d119 100644
--- a/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxClient.java
+++ b/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxClient.java
@@ -22,7 +22,9 @@
import flyteidl.admin.ExecutionOuterClass;
import flyteidl.core.Execution;
import flyteidl.core.IdentifierOuterClass;
+import flyteidl.core.IdentifierOuterClass.ResourceType;
import flyteidl.core.Literals;
+import flyteidl.core.Literals.LiteralMap;
import flyteidl.service.AdminServiceGrpc;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
@@ -55,6 +57,19 @@ public static FlyteSandboxClient create() {
return new FlyteSandboxClient(version, stub);
}
+ public Literals.LiteralMap createExecution(String name, String domain) {
+ return createExecution(
+ IdentifierOuterClass.Identifier.newBuilder()
+ .setResourceType(ResourceType.LAUNCH_PLAN)
+ .setDomain(domain)
+ .setProject(PROJECT)
+ .setName(name)
+ .setVersion(version)
+ .build(),
+ LiteralMap.getDefaultInstance(),
+ domain);
+ }
+
public Literals.LiteralMap createTaskExecution(String name, Literals.LiteralMap inputs) {
return createExecution(
IdentifierOuterClass.Identifier.newBuilder()
@@ -64,7 +79,8 @@ public Literals.LiteralMap createTaskExecution(String name, Literals.LiteralMap
.setName(name)
.setVersion(version)
.build(),
- inputs);
+ inputs,
+ DEVELOPMENT_DOMAIN);
}
public Literals.LiteralMap createExecution(String name, Literals.LiteralMap inputs) {
@@ -76,15 +92,16 @@ public Literals.LiteralMap createExecution(String name, Literals.LiteralMap inpu
.setName(name)
.setVersion(version)
.build(),
- inputs);
+ inputs,
+ DEVELOPMENT_DOMAIN);
}
private Literals.LiteralMap createExecution(
- IdentifierOuterClass.Identifier id, Literals.LiteralMap inputs) {
+ IdentifierOuterClass.Identifier id, Literals.LiteralMap inputs, String domain) {
ExecutionOuterClass.ExecutionCreateResponse response =
stub.createExecution(
ExecutionOuterClass.ExecutionCreateRequest.newBuilder()
- .setDomain(DEVELOPMENT_DOMAIN)
+ .setDomain(domain)
.setProject(PROJECT)
.setInputs(inputs)
.setSpec(ExecutionOuterClass.ExecutionSpec.newBuilder().setLaunchPlan(id).build())
diff --git a/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxContainer.java b/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxContainer.java
index c64c948bb..85561ed69 100644
--- a/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxContainer.java
+++ b/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxContainer.java
@@ -58,6 +58,8 @@ private static void startContainer() {
IOUtils.copy(imageInputStream, outputStream);
}
+ Thread.sleep(1000);
+
ExecResult execResult =
INSTANCE.execInContainer(
"docker", "load", "-i", "integration-tests/target/jflyte.tar.gz");
diff --git a/pom.xml b/pom.xml
index 0b6208c8e..160e4fca9 100644
--- a/pom.xml
+++ b/pom.xml
@@ -374,7 +374,7 @@
net.java.dev.jna
jna
- 5.8.0
+ 5.9.0
com.fasterxml.jackson