diff --git a/core/src/main/scala/caliban/execution/Field.scala b/core/src/main/scala/caliban/execution/Field.scala index 6d7721e48..88320c8c5 100644 --- a/core/src/main/scala/caliban/execution/Field.scala +++ b/core/src/main/scala/caliban/execution/Field.scala @@ -97,20 +97,21 @@ object Field { variableDefinitions: List[VariableDefinition], variableValues: Map[String, InputValue] ): Map[String, InputValue] = { - def resolveVariable(value: InputValue): InputValue = + def resolveVariable(value: InputValue): Option[InputValue] = value match { - case InputValue.ListValue(values) => InputValue.ListValue(values.map(resolveVariable)) + case InputValue.ListValue(values) => + Some(InputValue.ListValue(values.flatMap(resolveVariable))) case InputValue.ObjectValue(fields) => - InputValue.ObjectValue(fields.map { case (k, v) => k -> resolveVariable(v) }) + Some(InputValue.ObjectValue(fields.flatMap { case (k, v) => resolveVariable(v).map(k -> _) })) case InputValue.VariableValue(name) => - (for { - definition <- variableDefinitions.find(_.name == name) - defaultValue = definition.defaultValue getOrElse NullValue - value = variableValues.getOrElse(name, defaultValue) - } yield value) getOrElse NullValue - case value: Value => value + for { + definition <- variableDefinitions.find(_.name == name) + value <- variableValues.get(name).orElse(definition.defaultValue) + } yield value + case value: Value => + Some(value) } - arguments.map { case (k, v) => k -> resolveVariable(v) } + arguments.flatMap { case (k, v) => resolveVariable(v).map(k -> _) } } private def subtypeNames(typeName: String, rootType: RootType): Option[List[String]] = diff --git a/core/src/test/scala/caliban/execution/ExecutionSpec.scala b/core/src/test/scala/caliban/execution/ExecutionSpec.scala index fcff4030e..58581e00a 100644 --- a/core/src/test/scala/caliban/execution/ExecutionSpec.scala +++ b/core/src/test/scala/caliban/execution/ExecutionSpec.scala @@ -6,7 +6,7 @@ import caliban.CalibanError.ExecutionError import caliban.GraphQL._ import caliban.Macros.gqldoc import caliban.TestUtils._ -import caliban.Value.{ BooleanValue, IntValue, StringValue } +import caliban.Value.{ BooleanValue, IntValue, NullValue, StringValue } import caliban.introspection.adt.__Type import caliban.parsing.adt.LocationInfo import caliban.schema.Annotations.{ GQLInterface, GQLName, GQLValueType } @@ -147,6 +147,62 @@ object ExecutionSpec extends DefaultRunnableSpec { api.interpreter.flatMap(_.execute(query, None, Map("term" -> StringValue("search")))).map(_.asJson.noSpaces) )(equalTo("""{"data":{"getId":null}}""")) }, + testM("respects variables that are not provided") { + sealed trait ThreeState + object ThreeState { + case object Undefined extends ThreeState + case object Null extends ThreeState + case object Value extends ThreeState + + def fromOption[T](o: Option[T]) = o.fold[ThreeState](Null)(_ => Value) + + implicit val schema: Schema[Any, ThreeState] = Schema.optionSchema(Schema.booleanSchema).contramap { + case Undefined => None + case Null => Some(false) + case Value => Some(true) + } + implicit val argBuilder: ArgBuilder[ThreeState] = new ArgBuilder[ThreeState] { + private val base = ArgBuilder.option(ArgBuilder.boolean) + + override def build(input: InputValue) = base.build(input).map(fromOption(_)) + override def buildMissing(default: Option[String]) = default match { + case None => Right(Undefined) + case Some(v) => base.buildMissing(Some(v)).map(fromOption(_)) + } + } + } + + case class Args(term: String, state: ThreeState) + case class Test(getState: Args => ThreeState) + val api = graphQL(RootResolver(Test(_.state))) + val query = """query test($term: String!, $state: Boolean) { getState(term: $term, state: $state) }""" + val queryDefault = + """query test($term: String!, $state: Boolean = null) { getState(term: $term, state: $state) }""" + + def execute(query: String, state: ThreeState) = { + val vars = Map( + "term" -> Some(StringValue("search")), + "state" -> (state match { + case ThreeState.Undefined => None + case ThreeState.Null => Some(NullValue) + case ThreeState.Value => Some(BooleanValue(false)) + }) + ).collect { case (k, Some(v)) => k -> v } + api.interpreter.flatMap(_.execute(query, None, vars)) + } + + for { + undefined <- execute(query, ThreeState.Undefined) + nul <- execute(query, ThreeState.Null) + value <- execute(query, ThreeState.Value) + default <- execute(queryDefault, ThreeState.Undefined) + defaultValue <- execute(queryDefault, ThreeState.Value) + } yield assertTrue(undefined.data.toString == """{"getState":null}""") && + assertTrue(nul.data.toString == """{"getState":false}""") && + assertTrue(value.data.toString == """{"getState":true}""") && + assertTrue(default.data.toString == """{"getState":false}""") && + assertTrue(defaultValue.data.toString == """{"getState":true}""") + }, testM("field function") { import io.circe.syntax._