Skip to content

Commit

Permalink
Propagate type condition from nested fragments to parent (#2081)
Browse files Browse the repository at this point in the history
  • Loading branch information
kyri-petrou authored Jan 16, 2024
1 parent fb6d7a4 commit b57487f
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 22 deletions.
50 changes: 31 additions & 19 deletions core/src/main/scala/caliban/execution/Field.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import caliban.schema.{ RootType, Types }
import caliban.{ InputValue, Value }

import scala.collection.mutable
import scala.jdk.CollectionConverters._

/**
* Represents a field used during the execution of a query
Expand Down Expand Up @@ -136,7 +135,7 @@ object Field {
directives: List[Directive],
rootType: RootType
): Field = {
val memoizedFragments = new mutable.HashMap[String, (List[Field], Option[String])]()
val memoizedFragments = new mutable.HashMap[String, (List[(Field, Option[String])])]()
val variableDefinitionsMap = variableDefinitions.map(v => v.name -> v).toMap

def loop(
Expand All @@ -145,7 +144,7 @@ object Field {
fragment: Option[Fragment],
targets: Option[Set[String]],
condition: Option[Set[String]]
): List[Field] = {
): List[(Field, Option[String])] = {
val map = new java.util.LinkedHashMap[(String, Option[String]), Field]()

def addField(f: Field, condition: Option[String]): Unit =
Expand Down Expand Up @@ -174,7 +173,7 @@ object Field {
t,
Some(innerType),
alias,
fields,
fields.map(_._1),
targets = targets,
arguments = resolveVariables(arguments, variableDefinitionsMap, variableValues),
directives = resolvedDirectives,
Expand All @@ -186,47 +185,60 @@ object Field {
)
}
case FragmentSpread(name, directives) =>
val (fields, condition) = memoizedFragments.getOrElseUpdate(
val fields = memoizedFragments.getOrElseUpdate(
name, {
val resolvedDirectives = directives.map(resolveDirectiveVariables(variableValues, variableDefinitionsMap))
val _fields = if (checkDirectives(resolvedDirectives)) {
fragments.get(name).map { f =>
val t = rootType.types.getOrElse(f.typeCondition.name, fieldType)
val typeCondName = f.typeCondition.name
val t = rootType.types.getOrElse(typeCondName, fieldType)
val subtypeNames0 = subtypeNames(typeCondName, rootType)
val isSubsetCondition = subtypeNames0.getOrElse(Set.empty)
loop(
f.selectionSet,
t,
fragment = Some(Fragment(Some(name), resolvedDirectives)),
targets = Some(Set(f.typeCondition.name)),
condition = subtypeNames(f.typeCondition.name, rootType)
) -> Some(f.typeCondition.name)
targets = Some(Set(typeCondName)),
condition = subtypeNames0
).map {
case t @ (_, Some(c)) if isSubsetCondition(c) => t
case (f1, _) => (f1, Some(typeCondName))
}
}
} else None
_fields.getOrElse(Nil -> None)
_fields.getOrElse(Nil)
}
)
fields.foreach(addField(_, condition))
fields.foreach((addField _).tupled)
case InlineFragment(typeCondition, directives, selectionSet) =>
val resolvedDirectives = directives.map(resolveDirectiveVariables(variableValues, variableDefinitionsMap))
if (checkDirectives(resolvedDirectives)) {
val t = innerType.possibleTypes
.flatMap(_.find(_.name.exists(typeCondition.map(_.name).contains)))
.orElse(typeCondition.flatMap(typeName => rootType.types.get(typeName.name)))
val typeName = typeCondition.map(_.name)
val t = innerType.possibleTypes
.flatMap(_.find(_.name.exists(typeName.contains)))
.orElse(typeName.flatMap(rootType.types.get))
.getOrElse(fieldType)
val typeName = typeCondition.map(_.name)
val subtypeNames0 = typeName.flatMap(subtypeNames(_, rootType))
val isSubsetCondition = subtypeNames0.getOrElse(Set.empty)
loop(
selectionSet,
t,
fragment = Some(Fragment(None, resolvedDirectives)),
targets = typeName.map(Set(_)),
condition = typeName.flatMap(subtypeNames(_, rootType))
).foreach(addField(_, typeName))
condition = subtypeNames0
).foreach { case (f, c) =>
if (c.exists(isSubsetCondition)) addField(f, c)
else addField(f, typeName)
}
}
}
map.values().asScala.toList
val builder = List.newBuilder[(Field, Option[String])]
map.forEach { case ((_, cond), field) => builder += ((field, cond)) }
builder.result()
}

val fields = loop(selectionSet, fieldType, None, None, None)
Field("", fieldType, None, fields = fields, directives = directives)
Field("", fieldType, None, fields = fields.map(_._1), directives = directives)
}

private def resolveDirectiveVariables(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ object SchemaDerivationIssuesSpec extends ZIOSpecDefault {
name
children {
total
nodes { name }
nodes { name bar }
}
}
... on WidgetB {
Expand All @@ -248,7 +248,7 @@ object SchemaDerivationIssuesSpec extends ZIOSpecDefault {
name
children {
total
nodes { name }
nodes { name bar }
}
}
... on WidgetB {
Expand Down Expand Up @@ -574,7 +574,7 @@ object i2076 {
}

@GQLName("WidgetAChild")
case class Child(name: String, foo: String)
case class Child(name: String, foo: String, bar: String)
object Child {
implicit val schema: Schema[Any, Child] = Schema.gen
}
Expand Down

0 comments on commit b57487f

Please sign in to comment.