diff --git a/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkLaunchPlanRegistry b/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkLaunchPlanRegistry index 7cc5459dd..acd3fc633 100644 --- a/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkLaunchPlanRegistry +++ b/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkLaunchPlanRegistry @@ -1 +1 @@ -org.flyte.examples.flytekitscala.FibonacciLaunchPlan +org.flyte.examples.flytekitscala.LaunchPlanRegistry diff --git a/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkRunnableTask b/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkRunnableTask index 0fc19c133..508e6cb51 100644 --- a/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkRunnableTask +++ b/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkRunnableTask @@ -3,3 +3,4 @@ org.flyte.examples.flytekitscala.SumTask org.flyte.examples.flytekitscala.GreetTask org.flyte.examples.flytekitscala.AddQuestionTask org.flyte.examples.flytekitscala.NoInputsTask +org.flyte.examples.flytekitscala.NestedIOTask diff --git a/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkWorkflow b/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkWorkflow index 9b9ca9038..844fdc040 100644 --- a/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkWorkflow +++ b/flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkWorkflow @@ -1,2 +1,3 @@ org.flyte.examples.flytekitscala.FibonacciWorkflow org.flyte.examples.flytekitscala.WelcomeWorkflow +org.flyte.examples.flytekitscala.NestedIOWorkflow diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/FibonacciLaunchPlan.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/LaunchPlanRegistry.scala similarity index 67% rename from flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/FibonacciLaunchPlan.scala rename to flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/LaunchPlanRegistry.scala index 4fd16f3b0..41bf0e7cd 100644 --- a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/FibonacciLaunchPlan.scala +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/LaunchPlanRegistry.scala @@ -20,8 +20,9 @@ import org.flyte.flytekit.{SdkLaunchPlan, SimpleSdkLaunchPlanRegistry} import org.flyte.flytekitscala.SdkScalaType case class FibonacciLaunchPlanInput(fib0: Long, fib1: Long) +case class NestedIOLaunchPlanInput(name: String, generic: Nested) -class FibonacciLaunchPlan extends SimpleSdkLaunchPlanRegistry { +class LaunchPlanRegistry extends SimpleSdkLaunchPlanRegistry { // Register default launch plans for all workflows registerDefaultLaunchPlans() @@ -53,4 +54,33 @@ class FibonacciLaunchPlan extends SimpleSdkLaunchPlanRegistry { .withDefaultInput("fib0", 0L) .withDefaultInput("fib1", 1L) ) + + registerLaunchPlan( + SdkLaunchPlan + .of(new NestedIOWorkflow) + .withName("NestedIOWorkflowLaunchPlan") + .withDefaultInput( + SdkScalaType[NestedIOLaunchPlanInput], + NestedIOLaunchPlanInput( + "yo", + Nested( + boolean = true, + 1.toByte, + 2.toShort, + 3, + 4L, + 5.toFloat, + 6.toDouble, + "hello", + List("1", "2"), + Map("1" -> "1", "2" -> "2"), + Some(false), + None, + Some(List("3", "4")), + Some(Map("3" -> "3", "4" -> "4")), + NestedNested(7.toDouble, NestedNestedNested("world")) + ) + ) + ) + ) } diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOTask.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOTask.scala new file mode 100644 index 000000000..3a488adbc --- /dev/null +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOTask.scala @@ -0,0 +1,83 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.examples.flytekitscala + +import org.flyte.flytekit.{SdkBindingData, SdkRunnableTask, SdkTransform} +import org.flyte.flytekitscala.{ + Description, + SdkBindingDataFactory, + SdkScalaType +} + +case class NestedNestedNested(string: String) +case class NestedNested(double: Double, nested: NestedNestedNested) +case class Nested( + boolean: Boolean, + byte: Byte, + short: Short, + int: Int, + long: Long, + float: Float, + double: Double, + string: String, + list: List[String], + map: Map[String, String], + optBoolean: Option[Boolean], + optByte: Option[Byte], + optList: Option[List[String]], + optMap: Option[Map[String, String]], + nested: NestedNested +) +case class NestedIOTaskInput( + @Description("the name of the person to be greeted") + name: SdkBindingData[String], + @Description("a nested input") + generic: SdkBindingData[Nested] +) +case class NestedIOTaskOutput( + @Description("the name of the person to be greeted") + name: SdkBindingData[String], + @Description("a nested input") + generic: SdkBindingData[Nested] +) + +/** Example Flyte task that takes a name as the input and outputs a simple + * greeting message. + */ +class NestedIOTask + extends SdkRunnableTask[ + NestedIOTaskInput, + NestedIOTaskOutput + ]( + SdkScalaType[NestedIOTaskInput], + SdkScalaType[NestedIOTaskOutput] + ) { + + /** Defines task behavior. This task takes a name as the input, wraps it in a + * welcome message, and outputs the message. + * + * @param input + * the name of the person to be greeted + * @return + * the welcome message + */ + override def run(input: NestedIOTaskInput): NestedIOTaskOutput = + NestedIOTaskOutput( + input.name, + input.generic + ) +} diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOWorkflow.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOWorkflow.scala new file mode 100644 index 000000000..dfe996650 --- /dev/null +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOWorkflow.scala @@ -0,0 +1,37 @@ +/* + * Copyright 2020-2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.examples.flytekitscala + +import org.flyte.flytekitscala.{ + SdkScalaType, + SdkScalaWorkflow, + SdkScalaWorkflowBuilder +} + +class NestedIOWorkflow + extends SdkScalaWorkflow[NestedIOTaskInput, Unit]( + SdkScalaType[NestedIOTaskInput], + SdkScalaType.unit + ) { + + override def expand( + builder: SdkScalaWorkflowBuilder, + input: NestedIOTaskInput + ): Unit = { + builder.apply(new NestedIOTask(), input) + } +} diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala index 1fc5060e5..6d53bceae 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala @@ -24,7 +24,18 @@ import org.flyte.flytekit.{ import java.time.{Duration, Instant} import scala.collection.JavaConverters._ -import scala.reflect.runtime.universe.{TypeTag, typeOf} +import scala.reflect.api.{Mirror, TypeCreator, Universe} +import scala.reflect.runtime.universe +import scala.reflect.{ClassTag, classTag} +import scala.reflect.runtime.universe.{ + NoPrefix, + Symbol, + Type, + TypeTag, + runtimeMirror, + termNames, + typeOf +} object SdkLiteralTypes { @@ -202,6 +213,185 @@ object SdkLiteralTypes { */ def durations(): SdkLiteralType[Duration] = SdkJavaLiteralTypes.durations() + /** Returns a [[SdkLiteralType]] for products. + * @return + * the [[SdkLiteralType]] + */ + def generics[T <: Product: TypeTag: ClassTag](): SdkLiteralType[T] = { + ScalaLiteralType[T]( + LiteralType.ofSimpleType(SimpleType.STRUCT), + (value: T) => Literal.ofScalar(Scalar.ofGeneric(toStruct(value))), + (x: Literal) => toProduct(x.scalar().generic()), + (v: T) => BindingData.ofScalar(Scalar.ofGeneric(toStruct(v))), + "generics" + ) + } + + private def toStruct(product: Product): Struct = { + def productToMap(product: Product): Map[String, Any] = { + // by spec getDeclaredFields is not ordered but in practice it works fine + // it's a lot better since Scala 2.13 because productElementNames was introduced + // (product.productElementNames zip product.productIterator).toMap + product.getClass.getDeclaredFields + .map(_.getName) + .zip(product.productIterator.toList) + .toMap + } + + def mapToStruct(map: Map[String, Any]): Struct = { + val fields = map.map({ case (key, value) => + (key, anyToStructValue(value)) + }) + Struct.of(fields.asJava) + } + + def anyToStructValue(value: Any): Struct.Value = { + def anyToStructureValue0(value: Any): Struct.Value = { + value match { + case s: String => Struct.Value.ofStringValue(s) + case n @ (_: Byte | _: Short | _: Int | _: Long | _: Float | + _: Double) => + Struct.Value.ofNumberValue(n.toString.toDouble) + case b: Boolean => Struct.Value.ofBoolValue(b) + case l: List[Any] => + Struct.Value.ofListValue(l.map(anyToStructValue).asJava) + case m: Map[_, _] => + Struct.Value.ofStructValue( + mapToStruct(m.asInstanceOf[Map[String, Any]]) + ) + case null => Struct.Value.ofNullValue() + case p: Product => + Struct.Value.ofStructValue(mapToStruct(productToMap(p))) + case _ => + throw new IllegalArgumentException( + s"Unsupported type: ${value.getClass}" + ) + } + } + + value match { + case Some(v) => anyToStructureValue0(v) + case None => Struct.Value.ofNullValue() + case _ => anyToStructureValue0(value) + } + } + + mapToStruct(productToMap(product)) + } + + private def toProduct[T <: Product: TypeTag: ClassTag]( + struct: Struct + ): T = { + def structToMap(struct: Struct): Map[String, Any] = { + struct + .fields() + .asScala + .map({ case (key, value) => + (key, structValueToAny(value)) + }) + .toMap + } + + def mapToProduct[S <: Product: TypeTag: ClassTag]( + map: Map[String, Any] + ): S = { + val mirror = runtimeMirror(classTag[S].runtimeClass.getClassLoader) + + def valueToParamValue(value: Any, param: Symbol): Any = { + def valueToParamValue0(value: Any, param: Symbol): Any = { + if (param.typeSignature =:= typeOf[Byte]) { + value.asInstanceOf[Double].toByte + } else if (param.typeSignature =:= typeOf[Short]) { + value.asInstanceOf[Double].toShort + } else if (param.typeSignature =:= typeOf[Int]) { + value.asInstanceOf[Double].toInt + } else if (param.typeSignature =:= typeOf[Long]) { + value.asInstanceOf[Double].toLong + } else if (param.typeSignature =:= typeOf[Float]) { + value.asInstanceOf[Double].toFloat + } else if (param.typeSignature <:< typeOf[Product]) { + val typeTag = createTypeTag(param.typeSignature) + val classTag = ClassTag( + typeTag.mirror.runtimeClass(param.typeSignature) + ) + mapToProduct(value.asInstanceOf[Map[String, Any]])( + typeTag, + classTag + ) + } else { + value + } + } + + if (param.typeSignature <:< typeOf[Option[Any]]) { + Some( + valueToParamValue0( + value, + param.typeSignature.dealias.typeArgs.head.typeSymbol + ) + ) + } else { + valueToParamValue0(value, param) + } + } + + def createTypeTag[U <: Product](tpe: Type): TypeTag[U] = { + val typSym = mirror.staticClass(tpe.typeSymbol.fullName) + // note: this uses internal API, otherwise we will need to depend on scala-compiler at runtime + val typeRef = + universe.internal.typeRef(NoPrefix, typSym, List.empty) + + TypeTag( + mirror, + new TypeCreator { + override def apply[V <: Universe with Singleton]( + m: Mirror[V] + ): V#Type = { + assert( + m == mirror, + s"TypeTag[$typeRef] defined in $mirror cannot be migrated to $m." + ) + typeRef.asInstanceOf[V#Type] + } + } + ) + } + + val clazz = typeOf[S].typeSymbol.asClass + val classMirror = mirror.reflectClass(clazz) + val constructor = typeOf[S].decl(termNames.CONSTRUCTOR).asMethod + val constructorMirror = classMirror.reflectConstructor(constructor) + + val constructorArgs = + constructor.paramLists.flatten.map((param: Symbol) => { + val paramName = param.name.toString + val value = map.getOrElse( + paramName, + throw new IllegalArgumentException( + s"Map is missing required parameter named $paramName" + ) + ) + valueToParamValue(value, param) + }) + + constructorMirror(constructorArgs: _*).asInstanceOf[S] + } + + def structValueToAny(value: Struct.Value): Any = { + value.kind() match { + case Struct.Value.Kind.STRING_VALUE => value.stringValue() + case Struct.Value.Kind.NUMBER_VALUE => value.numberValue() + case Struct.Value.Kind.BOOL_VALUE => value.boolValue() + case Struct.Value.Kind.LIST_VALUE => + value.listValue().asScala.map(structValueToAny).toList + case Struct.Value.Kind.STRUCT_VALUE => structToMap(value.structValue()) + case Struct.Value.Kind.NULL_VALUE => None + } + } + + mapToProduct[T](structToMap(struct)) + } + /** Returns a [[SdkLiteralType]] for blob. * * @return diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala index 25f03122a..b64e0ae71 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala @@ -30,6 +30,8 @@ import org.flyte.flytekit.{ import scala.annotation.implicitNotFound import scala.collection.JavaConverters._ +import scala.reflect.{ClassTag, classTag} +import scala.reflect.runtime.universe.{TypeTag, typeOf} /** Type class to map between Flyte `Variable` and `Literal` and Scala case * classes. @@ -231,6 +233,12 @@ object SdkScalaType { implicit def durationLiteralType: SdkScalaLiteralType[Duration] = DelegateLiteralType(SdkLiteralTypes.durations()) + // fixme: using Product is just an approximation for case class because Product + // is also super class of, for example, Option and Tuple + implicit def productLiteralType[T <: Product: TypeTag: ClassTag] + : SdkScalaLiteralType[T] = + DelegateLiteralType(SdkLiteralTypes.generics()) + // fixme: create blob type from annotation, or rethink how we could offer the offloaded data feature // https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.BlobType.html#flytekit-blobtype implicit def blobLiteralType: SdkScalaLiteralType[Blob] = diff --git a/integration-tests/src/test/java/org/flyte/AdditionalIT.java b/integration-tests/src/test/java/org/flyte/AdditionalIT.java index 00e50c27a..5355ddb71 100644 --- a/integration-tests/src/test/java/org/flyte/AdditionalIT.java +++ b/integration-tests/src/test/java/org/flyte/AdditionalIT.java @@ -16,25 +16,21 @@ */ package org.flyte; -import static org.flyte.FlyteContainer.CLIENT; +import static org.flyte.examples.FlyteEnvironment.STAGING_DOMAIN; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import flyteidl.core.Literals; +import flyteidl.core.Literals.LiteralMap; import org.flyte.utils.Literal; -import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; import org.testcontainers.shaded.com.google.common.collect.ImmutableMap; @TestInstance(TestInstance.Lifecycle.PER_CLASS) -public class AdditionalIT { - @BeforeAll - public static void beforeAll() { - CLIENT.registerWorkflows("integration-tests/target/lib"); - } - +class AdditionalIT extends Fixtures { @ParameterizedTest @CsvSource({ "0,0,0,0,a == b && c == d", @@ -47,7 +43,7 @@ public static void beforeAll() { "0,1,0,1,a < b && c < d", "1,0,0,1,a > b && c < d", }) - public void testBranchNodeWorkflow(long a, long b, long c, long d, String expected) { + void testBranchNodeWorkflow(long a, long b, long c, long d, String expected) { Literals.LiteralMap output = CLIENT.createExecution( "org.flyte.integrationtests.BranchNodeWorkflow", @@ -66,7 +62,7 @@ public void testBranchNodeWorkflow(long a, long b, long c, long d, String expect "table-exists,true", "non-existent,false", }) - public void testStructs(String name, boolean expected) { + void testStructs(String name, boolean expected) { Literals.LiteralMap output = CLIENT.createExecution( "org.flyte.integrationtests.structs.MockPipelineWorkflow", @@ -74,4 +70,12 @@ public void testStructs(String name, boolean expected) { assertThat(output, equalTo(Literal.ofBooleanMap(ImmutableMap.of("exists", expected)))); } + + @Test + void testStructsScala() { + Literals.LiteralMap output = + CLIENT.createExecution("NestedIOWorkflowLaunchPlan", STAGING_DOMAIN); + + assertThat(output, equalTo(LiteralMap.getDefaultInstance())); + } } diff --git a/integration-tests/src/test/java/org/flyte/FlyteContainer.java b/integration-tests/src/test/java/org/flyte/Fixtures.java similarity index 60% rename from integration-tests/src/test/java/org/flyte/FlyteContainer.java rename to integration-tests/src/test/java/org/flyte/Fixtures.java index a95fe7d76..d80bec77d 100644 --- a/integration-tests/src/test/java/org/flyte/FlyteContainer.java +++ b/integration-tests/src/test/java/org/flyte/Fixtures.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 Flyte Authors. + * Copyright 2023 Flyte Authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,8 +16,16 @@ */ package org.flyte; +import static org.flyte.examples.FlyteEnvironment.STAGING_DOMAIN; + import org.flyte.utils.FlyteSandboxClient; -public class FlyteContainer { - static final FlyteSandboxClient CLIENT = FlyteSandboxClient.create(); +class Fixtures { + protected static final FlyteSandboxClient CLIENT = FlyteSandboxClient.create(); + + static { + CLIENT.registerWorkflows("integration-tests/target/lib"); + CLIENT.registerWorkflows("flytekit-examples/target/lib"); + CLIENT.registerWorkflows("flytekit-examples-scala/target/lib", STAGING_DOMAIN); + } } diff --git a/integration-tests/src/test/java/org/flyte/JavaExamplesIT.java b/integration-tests/src/test/java/org/flyte/JavaExamplesIT.java index cbce85dee..c4ec8151e 100644 --- a/integration-tests/src/test/java/org/flyte/JavaExamplesIT.java +++ b/integration-tests/src/test/java/org/flyte/JavaExamplesIT.java @@ -16,29 +16,17 @@ */ package org.flyte; -import static org.flyte.FlyteContainer.CLIENT; -import static org.flyte.examples.FlyteEnvironment.STAGING_DOMAIN; import static org.flyte.utils.Literal.ofIntegerMap; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import flyteidl.core.Literals; -import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.testcontainers.shaded.com.google.common.collect.ImmutableMap; @TestInstance(TestInstance.Lifecycle.PER_CLASS) -public class JavaExamplesIT { - private static final String CLASSPATH_EXAMPLES = "flytekit-examples/target/lib"; - private static final String CLASSPATH_EXAMPLES_SCALA = "flytekit-examples-scala/target/lib"; - - @BeforeAll - public static void beforeAll() { - CLIENT.registerWorkflows(CLASSPATH_EXAMPLES); - CLIENT.registerWorkflows(CLASSPATH_EXAMPLES_SCALA, STAGING_DOMAIN); - } - +class JavaExamplesIT extends Fixtures { @Test public void testSumTask() { Literals.LiteralMap output = @@ -53,7 +41,7 @@ public void testSumTask() { } @Test - public void testFibonacciWorkflow() { + void testFibonacciWorkflow() { Literals.LiteralMap output = CLIENT.createExecution( "org.flyte.examples.FibonacciWorkflow", @@ -66,7 +54,7 @@ public void testFibonacciWorkflow() { } @Test - public void testDynamicFibonacciWorkflow() { + void testDynamicFibonacciWorkflow() { Literals.LiteralMap output = CLIENT.createExecution( "org.flyte.examples.DynamicFibonacciWorkflow", ofIntegerMap(ImmutableMap.of("n", 2L))); diff --git a/integration-tests/src/test/java/org/flyte/SerializeJavaIT.java b/integration-tests/src/test/java/org/flyte/SerializeJavaIT.java index 9d9caac47..888416893 100644 --- a/integration-tests/src/test/java/org/flyte/SerializeJavaIT.java +++ b/integration-tests/src/test/java/org/flyte/SerializeJavaIT.java @@ -16,7 +16,6 @@ */ package org.flyte; -import static org.flyte.FlyteContainer.CLIENT; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; @@ -29,13 +28,13 @@ import org.junit.jupiter.api.io.TempDir; @TestInstance(TestInstance.Lifecycle.PER_CLASS) -public class SerializeJavaIT { +class SerializeJavaIT extends Fixtures { private static final String CLASSPATH = "flytekit-examples/target/lib"; @TempDir Path managed; @Test - public void testSerializeWorkflows() { + void testSerializeWorkflows() { try { File current = new File("target/protos"); File tempDir = managed.resolve(current.getAbsolutePath()).toFile(); diff --git a/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxClient.java b/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxClient.java index f614ef596..a2c04d119 100644 --- a/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxClient.java +++ b/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxClient.java @@ -22,7 +22,9 @@ import flyteidl.admin.ExecutionOuterClass; import flyteidl.core.Execution; import flyteidl.core.IdentifierOuterClass; +import flyteidl.core.IdentifierOuterClass.ResourceType; import flyteidl.core.Literals; +import flyteidl.core.Literals.LiteralMap; import flyteidl.service.AdminServiceGrpc; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; @@ -55,6 +57,19 @@ public static FlyteSandboxClient create() { return new FlyteSandboxClient(version, stub); } + public Literals.LiteralMap createExecution(String name, String domain) { + return createExecution( + IdentifierOuterClass.Identifier.newBuilder() + .setResourceType(ResourceType.LAUNCH_PLAN) + .setDomain(domain) + .setProject(PROJECT) + .setName(name) + .setVersion(version) + .build(), + LiteralMap.getDefaultInstance(), + domain); + } + public Literals.LiteralMap createTaskExecution(String name, Literals.LiteralMap inputs) { return createExecution( IdentifierOuterClass.Identifier.newBuilder() @@ -64,7 +79,8 @@ public Literals.LiteralMap createTaskExecution(String name, Literals.LiteralMap .setName(name) .setVersion(version) .build(), - inputs); + inputs, + DEVELOPMENT_DOMAIN); } public Literals.LiteralMap createExecution(String name, Literals.LiteralMap inputs) { @@ -76,15 +92,16 @@ public Literals.LiteralMap createExecution(String name, Literals.LiteralMap inpu .setName(name) .setVersion(version) .build(), - inputs); + inputs, + DEVELOPMENT_DOMAIN); } private Literals.LiteralMap createExecution( - IdentifierOuterClass.Identifier id, Literals.LiteralMap inputs) { + IdentifierOuterClass.Identifier id, Literals.LiteralMap inputs, String domain) { ExecutionOuterClass.ExecutionCreateResponse response = stub.createExecution( ExecutionOuterClass.ExecutionCreateRequest.newBuilder() - .setDomain(DEVELOPMENT_DOMAIN) + .setDomain(domain) .setProject(PROJECT) .setInputs(inputs) .setSpec(ExecutionOuterClass.ExecutionSpec.newBuilder().setLaunchPlan(id).build()) diff --git a/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxContainer.java b/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxContainer.java index c64c948bb..85561ed69 100644 --- a/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxContainer.java +++ b/integration-tests/src/test/java/org/flyte/utils/FlyteSandboxContainer.java @@ -58,6 +58,8 @@ private static void startContainer() { IOUtils.copy(imageInputStream, outputStream); } + Thread.sleep(1000); + ExecResult execResult = INSTANCE.execInContainer( "docker", "load", "-i", "integration-tests/target/jflyte.tar.gz"); diff --git a/pom.xml b/pom.xml index 0b6208c8e..160e4fca9 100644 --- a/pom.xml +++ b/pom.xml @@ -374,7 +374,7 @@ net.java.dev.jna jna - 5.8.0 + 5.9.0 com.fasterxml.jackson