Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle Option correctly #266

Merged
merged 6 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ org.flyte.examples.flytekitscala.GreetTask
org.flyte.examples.flytekitscala.AddQuestionTask
org.flyte.examples.flytekitscala.NoInputsTask
org.flyte.examples.flytekitscala.NestedIOTask
org.flyte.examples.flytekitscala.NestedIOTaskNoop
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,19 @@ class LaunchPlanRegistry extends SimpleSdkLaunchPlanRegistry {
6.toDouble,
"hello",
List("1", "2"),
List(NestedNested(7.toDouble, NestedNestedNested("world"))),
List(NestedNested(7.toDouble, Some(NestedNestedNested("world")))),
Map("1" -> "1", "2" -> "2"),
Map("foo" -> NestedNested(7.toDouble, NestedNestedNested("world"))),
Map(
"foo" -> NestedNested(
7.toDouble,
Some(NestedNestedNested("world"))
)
),
Some(false),
None,
Some(List("3", "4")),
Some(Map("3" -> "3", "4" -> "4")),
NestedNested(7.toDouble, NestedNestedNested("world"))
NestedNested(7.toDouble, Some(NestedNestedNested("world")))
)
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.flyte.flytekitscala.{
}

case class NestedNestedNested(string: String)
case class NestedNested(double: Double, nested: NestedNestedNested)
case class NestedNested(double: Double, nested: Option[NestedNestedNested])
case class Nested(
boolean: Boolean,
byte: Byte,
Expand Down Expand Up @@ -57,9 +57,6 @@ case class NestedIOTaskOutput(
generic: SdkBindingData[Nested]
)

/** Example Flyte task that takes a name as the input and outputs a simple
* greeting message.
*/
class NestedIOTask
extends SdkRunnableTask[
NestedIOTaskInput,
Expand All @@ -69,17 +66,21 @@ class NestedIOTask
SdkScalaType[NestedIOTaskOutput]
) {

/** Defines task behavior. This task takes a name as the input, wraps it in a
* welcome message, and outputs the message.
*
* @param input
* the name of the person to be greeted
* @return
* the welcome message
*/
override def run(input: NestedIOTaskInput): NestedIOTaskOutput =
NestedIOTaskOutput(
input.name,
input.generic
)
}

class NestedIOTaskNoop
extends SdkRunnableTask[
NestedIOTaskOutput,
NestedIOTaskOutput
](
SdkScalaType[NestedIOTaskOutput],
SdkScalaType[NestedIOTaskOutput]
) {

override def run(input: NestedIOTaskOutput): NestedIOTaskOutput = input
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class NestedIOWorkflow
builder: SdkScalaWorkflowBuilder,
input: NestedIOTaskInput
): Unit = {
builder.apply(new NestedIOTask(), input)
val output = builder.apply(new NestedIOTask(), input)
builder.apply(new NestedIOTaskNoop(), output.getOutputs)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ import org.flyte.flytekitscala.SdkLiteralTypes.{
}

// The constructor is reflectedly invoked so it cannot be an inner class
case class ScalarNested(foo: String, bar: String)
case class ScalarNested(
foo: String,
bar: Option[String],
nestedNested: Option[ScalarNestedNested]
)
case class ScalarNestedNested(foo: String, bar: Option[String])

class SdkScalaTypeTest {

Expand Down Expand Up @@ -178,7 +183,15 @@ class SdkScalaTypeTest {
Struct.of(
Map(
"foo" -> Struct.Value.ofStringValue("foo"),
"bar" -> Struct.Value.ofStringValue("bar")
"bar" -> Struct.Value.ofNullValue(),
"nestedNested" -> Struct.Value.ofStructValue(
Struct.of(
Map(
"foo" -> Struct.Value.ofStringValue("foo"),
"bar" -> Struct.Value.ofStringValue("bar")
).asJava
)
)
).asJava
)
)
Expand All @@ -196,7 +209,11 @@ class SdkScalaTypeTest {
blob = SdkBindingDataFactory.of(blob),
generic = SdkBindingDataFactory.of(
SdkLiteralTypes.generics(),
ScalarNested("foo", "bar")
ScalarNested(
"foo",
None,
Some(ScalarNestedNested("foo", Some("bar")))
)
)
)

Expand All @@ -218,7 +235,11 @@ class SdkScalaTypeTest {
blob = SdkBindingDataFactory.of(blob),
generic = SdkBindingDataFactory.of(
SdkLiteralTypes.generics(),
ScalarNested("foo", "bar")
ScalarNested(
"foo",
Some("bar"),
Some(ScalarNestedNested("foo", Some("bar")))
)
)
)

Expand All @@ -245,7 +266,15 @@ class SdkScalaTypeTest {
Struct.of(
Map(
"foo" -> Struct.Value.ofStringValue("foo"),
"bar" -> Struct.Value.ofStringValue("bar")
"bar" -> Struct.Value.ofStringValue("bar"),
"nestedNested" -> Struct.Value.ofStructValue(
Struct.of(
Map(
"foo" -> Struct.Value.ofStringValue("foo"),
"bar" -> Struct.Value.ofStringValue("bar")
).asJava
)
)
).asJava
)
)
Expand Down Expand Up @@ -285,7 +314,11 @@ class SdkScalaTypeTest {
blob = SdkBindingDataFactory.of(blob),
generic = SdkBindingDataFactory.of(
SdkLiteralTypes.generics(),
ScalarNested("foo", "bar")
ScalarNested(
"foo",
Some("bar"),
Some(ScalarNestedNested("foo", Some("bar")))
)
)
)

Expand All @@ -301,7 +334,11 @@ class SdkScalaTypeTest {
"blob" -> SdkBindingDataFactory.of(blob),
"generic" -> SdkBindingDataFactory.of(
SdkLiteralTypes.generics[ScalarNested](),
ScalarNested("foo", "bar")
ScalarNested(
"foo",
Some("bar"),
Some(ScalarNestedNested("foo", Some("bar")))
)
)
).asJava

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,41 +297,39 @@ object SdkLiteralTypes {
): S = {
val mirror = runtimeMirror(classTag[S].runtimeClass.getClassLoader)

def valueToParamValue(value: Any, param: Symbol): Any = {
def valueToParamValue0(value: Any, param: Symbol): Any = {
if (param.typeSignature =:= typeOf[Byte]) {
value.asInstanceOf[Double].toByte
} else if (param.typeSignature =:= typeOf[Short]) {
value.asInstanceOf[Double].toShort
} else if (param.typeSignature =:= typeOf[Int]) {
value.asInstanceOf[Double].toInt
} else if (param.typeSignature =:= typeOf[Long]) {
value.asInstanceOf[Double].toLong
} else if (param.typeSignature =:= typeOf[Float]) {
value.asInstanceOf[Double].toFloat
} else if (param.typeSignature <:< typeOf[Product]) {
val typeTag = createTypeTag(param.typeSignature)
val classTag = ClassTag(
typeTag.mirror.runtimeClass(param.typeSignature)
)
mapToProduct(value.asInstanceOf[Map[String, Any]])(
typeTag,
classTag
)
def valueToParamValue(value: Any, tpe: Type): Any = {
if (tpe =:= typeOf[Byte]) {
value.asInstanceOf[Double].toByte
} else if (tpe =:= typeOf[Short]) {
value.asInstanceOf[Double].toShort
} else if (tpe =:= typeOf[Int]) {
value.asInstanceOf[Double].toInt
} else if (tpe =:= typeOf[Long]) {
value.asInstanceOf[Double].toLong
} else if (tpe =:= typeOf[Float]) {
value.asInstanceOf[Double].toFloat
} else if (tpe <:< typeOf[Option[Any]]) { // this has to be before Product check because Option is a Product
if (value == None) { // None is used to represent Struct.Value.Kind.NULL_VALUE when converting struct to map
None
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is another bug fix. We have been converting things into Some(None), due to type conformation test failure against T in Option[T], which is exactly what this PR set off to fix. Bug on top of bug :(

Hopefully the extended unit tests and ITs cover better.

} else {
value
}
}

if (param.typeSignature <:< typeOf[Option[Any]]) {
Some(
valueToParamValue0(
value,
param.typeSignature.dealias.typeArgs.head.typeSymbol
Some(
valueToParamValue(
value,
tpe.dealias.typeArgs.head
)
)
}
} else if (tpe <:< typeOf[Product]) {
val typeTag = createTypeTag(tpe)
val classTag = ClassTag(
typeTag.mirror.runtimeClass(tpe)
)
mapToProduct(value.asInstanceOf[Map[String, Any]])(
typeTag,
classTag
)
} else {
valueToParamValue0(value, param)
value
}
}

Expand Down Expand Up @@ -371,7 +369,7 @@ object SdkLiteralTypes {
s"Map is missing required parameter named $paramName"
)
)
valueToParamValue(value, param)
valueToParamValue(value, param.typeSignature.dealias)
})

constructorMirror(constructorArgs: _*).asInstanceOf[S]
Expand Down
Loading