Skip to content

Commit

Permalink
Resolve directive argument variables (ghostdogpr#1263)
Browse files Browse the repository at this point in the history
  • Loading branch information
ghostdogpr authored and Fluxx committed Jan 27, 2022
1 parent b569d98 commit ad0db0b
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 65 deletions.
134 changes: 69 additions & 65 deletions core/src/main/scala/caliban/execution/Field.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,61 +66,76 @@ object Field {

val innerType = Types.innerType(fieldType)
selectionSet.foreach {
case F(alias, name, arguments, directives, selectionSet, index)
if checkDirectives(directives, variableValues) =>
case F(alias, name, arguments, directives, selectionSet, index) =>
val selected = innerType
.fields(__DeprecatedArgs(Some(true)))
.flatMap(_.find(_.name == name))

val schemaDirectives = selected.flatMap(_.directives).getOrElse(Nil)

val t = selected.fold(Types.string)(_.`type`()) // default only case where it's not found is __typename

val field = loop(selectionSet, t)

addField(
Field(
name,
t,
Some(innerType),
alias,
field.fields,
None,
resolveVariables(arguments, variableDefinitions, variableValues),
() => sourceMapper.getLocation(index),
directives ++ schemaDirectives
),
None
val schemaDirectives = selected.flatMap(_.directives).getOrElse(Nil)
val resolvedDirectives = (directives ++ schemaDirectives).map(directive =>
directive.copy(arguments = resolveVariables(directive.arguments, variableDefinitions, variableValues))
)
case FragmentSpread(name, directives) if checkDirectives(directives, variableValues) =>
fragments
.get(name)
.foreach { f =>
val t =
innerType.possibleTypes.flatMap(_.find(_.name.contains(f.typeCondition.name))).getOrElse(fieldType)
loop(f.selectionSet, t).fields
.map(field =>
if (field.condition.isDefined) field
else field.copy(condition = subtypeNames(f.typeCondition.name, rootType))
)
.foreach(addField(_, Some(f.typeCondition.name)))

if (checkDirectives(resolvedDirectives)) {
val t = selected.fold(Types.string)(_.`type`()) // default only case where it's not found is __typename

val field = loop(selectionSet, t)

addField(
Field(
name,
t,
Some(innerType),
alias,
field.fields,
None,
resolveVariables(arguments, variableDefinitions, variableValues),
() => sourceMapper.getLocation(index),
resolvedDirectives
),
None
)
}
case FragmentSpread(name, directives) =>
val resolvedDirectives = directives.map(directive =>
directive.copy(arguments = resolveVariables(directive.arguments, variableDefinitions, variableValues))
)

if (checkDirectives(resolvedDirectives)) {
fragments
.get(name)
.foreach { f =>
val t =
innerType.possibleTypes.flatMap(_.find(_.name.contains(f.typeCondition.name))).getOrElse(fieldType)
loop(f.selectionSet, t).fields
.map(field =>
if (field.condition.isDefined) field
else field.copy(condition = subtypeNames(f.typeCondition.name, rootType))
)
.foreach(addField(_, Some(f.typeCondition.name)))
}
}
case InlineFragment(typeCondition, directives, selectionSet) =>
val resolvedDirectives = directives.map(directive =>
directive.copy(arguments = resolveVariables(directive.arguments, variableDefinitions, variableValues))
)

if (checkDirectives(resolvedDirectives)) {
val t = innerType.possibleTypes
.flatMap(_.find(_.name.exists(typeCondition.map(_.name).contains)))
.getOrElse(fieldType)
val field = loop(selectionSet, t)
typeCondition match {
case None => if (field.fields.nonEmpty) fieldList ++= field.fields
case Some(typeName) =>
field.fields
.map(field =>
if (field.condition.isDefined) field
else field.copy(condition = subtypeNames(typeName.name, rootType))
)
.foreach(addField(_, Some(typeName.name)))
}
case InlineFragment(typeCondition, directives, selectionSet) if checkDirectives(directives, variableValues) =>
val t = innerType.possibleTypes
.flatMap(_.find(_.name.exists(typeCondition.map(_.name).contains)))
.getOrElse(fieldType)
val field = loop(selectionSet, t)
typeCondition match {
case None => if (field.fields.nonEmpty) fieldList ++= field.fields
case Some(typeName) =>
field.fields
.map(field =>
if (field.condition.isDefined) field
else field.copy(condition = subtypeNames(typeName.name, rootType))
)
.foreach(addField(_, Some(typeName.name)))
}
case _ =>
}
Field("", fieldType, None, fields = fieldList.toList)
}
Expand Down Expand Up @@ -160,26 +175,15 @@ object Field {
) + typeName
)

private def checkDirectives(directives: List[Directive], variableValues: Map[String, InputValue]): Boolean =
!checkDirective("skip", default = false, directives, variableValues) &&
checkDirective("include", default = true, directives, variableValues)
private def checkDirectives(directives: List[Directive]): Boolean =
!checkDirective("skip", default = false, directives) &&
checkDirective("include", default = true, directives)

private def checkDirective(
name: String,
default: Boolean,
directives: List[Directive],
variableValues: Map[String, InputValue]
): Boolean =
private def checkDirective(name: String, default: Boolean, directives: List[Directive]): Boolean =
directives
.find(_.name == name)
.flatMap(_.arguments.get("if")) match {
case Some(BooleanValue(value)) => value
case Some(InputValue.VariableValue(name)) =>
variableValues
.get(name) match {
case Some(BooleanValue(value)) => value
case _ => default
}
case _ => default
case Some(BooleanValue(value)) => value
case _ => default
}
}
21 changes: 21 additions & 0 deletions core/src/test/scala/caliban/execution/ExecutionSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,27 @@ object ExecutionSpec extends DefaultRunnableSpec {
api.interpreter.flatMap(_.execute(query, None, Map("term" -> StringValue("search")))).map(_.asJson.noSpaces)
)(equalTo("""{"data":{"getId":null}}"""))
},
testM("default values for variables in directives") {
import io.circe.syntax._

case class TestQuery(field1: String, field2: String)
case class Query(test: TestQuery)
val api = graphQL(RootResolver(Query(TestQuery(field1 = "1234", field2 = "5421"))))

val query =
"""
|query ($a: Boolean = true, $b: Boolean = false) {
| test {
| field1 @include(if: $a)
| field2 @include(if: $b)
| }
|}
|""".stripMargin

assertM(
api.interpreter.flatMap(_.execute(query, None, Map())).map(_.asJson.noSpaces)
)(equalTo("""{"data":{"test":{"field1":"1234"}}}"""))
},
testM("respects variables that are not provided") {
sealed trait ThreeState
object ThreeState {
Expand Down

0 comments on commit ad0db0b

Please sign in to comment.