Skip to content

Commit

Permalink
fix: recursively translate input value strings to enums (#1136)
Browse files Browse the repository at this point in the history
  • Loading branch information
frekw authored Nov 11, 2021
1 parent 7cbe359 commit 93aa19f
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 21 deletions.
82 changes: 61 additions & 21 deletions core/src/main/scala/caliban/parsing/VariablesUpdater.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package caliban.parsing

import caliban.GraphQLRequest
import caliban.InputValue.ListValue
import caliban.Value.StringValue
import caliban.introspection.adt._
import caliban.parsing.adt.Definition.ExecutableDefinition.OperationDefinition
import caliban.parsing.adt.Type.{ ListType, NamedType }
import caliban.parsing.adt._
import caliban.schema.RootType
import caliban.{ InputValue, Value }
Expand All @@ -16,38 +19,75 @@ object VariablesUpdater {
): GraphQLRequest = {
val variableDefinitions = doc.operationDefinitions.flatMap(_.variableDefinitions)
val updated = req.variables.getOrElse(Map.empty).map { case (key, value) =>
val v = variableDefinitions.find(_.name == key).map(resolveEnumValues(value, _, rootType)).getOrElse(value)
val v =
variableDefinitions
.find(_.name == key)
.map { definition =>
rewriteValues(value, definition.variableType, rootType)
}
.getOrElse(value)

key -> v
}

req.copy(variables = Some(updated))
}

private def rewriteValues(value: InputValue, `type`: Type, rootType: RootType): InputValue =
`type` match {
case ListType(ofType, _) =>
value match {
case ListValue(values) =>
ListValue(values.map(v => rewriteValues(v, ofType, rootType)))
case _ => value
}
case NamedType(name, _) =>
rootType.types.get(name).map(t => resolveEnumValues(value, t, rootType)).getOrElse(value)
}

// Since we cannot separate a String from an Enum when variables
// are parsed, we need to translate from strings to enums here
// if we have a valid enum field.
private def resolveEnumValues(
value: InputValue,
definition: VariableDefinition,
typ: __Type,
rootType: RootType
): InputValue = {
val t = Type
.innerType(definition.variableType)

rootType.types
.get(t)
.map(_.kind)
.flatMap { kind =>
(kind, value) match {
case (__TypeKind.ENUM, InputValue.ListValue(v)) =>
Some(
InputValue.ListValue(v.map(resolveEnumValues(_, definition, rootType)))
)
case (__TypeKind.ENUM, Value.StringValue(v)) =>
Some(Value.EnumValue(v))
case _ => None
): InputValue =
typ.kind match {
case __TypeKind.INPUT_OBJECT =>
value match {
case InputValue.ObjectValue(fields) =>
val defs = typ.inputFields.getOrElse(List.empty)
InputValue.ObjectValue(fields.map { case (k, v) =>
val updated =
defs.find(_.name == k).map(field => resolveEnumValues(v, field.`type`(), rootType)).getOrElse(value)

(k, updated)
})
case _ =>
value
}
}
.getOrElse(value)
}

case __TypeKind.LIST =>
value match {
case ListValue(values) =>
typ.ofType
.map(innerType => ListValue(values.map(value => resolveEnumValues(value, innerType, rootType))))
.getOrElse(value)
case _ => value
}

case __TypeKind.NON_NULL =>
typ.ofType
.map(innerType => resolveEnumValues(value, innerType, rootType))
.getOrElse(value)

case __TypeKind.ENUM =>
value match {
case StringValue(value) => Value.EnumValue(value)
case _ => value
}
case _ =>
value
}
}
38 changes: 38 additions & 0 deletions core/src/test/scala/caliban/execution/FieldArgsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,44 @@ object FieldArgsSpec extends DefaultRunnableSpec {
interpreter <- api.interpreter
res <- interpreter.execute(query)
} yield assert(res.errors.headOption)(isSome(anything))
},
testM("it correctly handles lists of objects with enums") {
case class QueryInput(filter: List[Filter])
case class Filter(color: COLOR)
case class Query(query: QueryInput => String)
val query =
"""query MyQuery($filter: [FilterInput!]!) {
| query(filter: $filter)
|}""".stripMargin

val api = graphQL(
RootResolver(
Query(
query = q => q.filter.headOption.map(_.color.toString).getOrElse("Missing")
)
)
)

for {
interpreter <- api.interpreter
res <- interpreter.executeRequest(
request = GraphQLRequest(
query = Some(query),
variables = Some(
Map(
"filter" ->
InputValue.ListValue(
List(
InputValue.ObjectValue(
Map("color" -> Value.StringValue("BLUE"))
)
)
)
)
)
)
)
} yield assertTrue(res.data.toString == "{\"query\":\"BLUE\"}")
}
)
}

0 comments on commit 93aa19f

Please sign in to comment.