Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix fields merging and improve performance #1199

Merged
merged 2 commits into from
Dec 12, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 5 additions & 26 deletions core/src/main/scala/caliban/execution/Executor.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package caliban.execution

import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer
import caliban.CalibanError.ExecutionError
import caliban.ResponseValue._
import caliban.Value._
@@ -52,7 +51,7 @@ object Executor {
value match {
case EnumValue(v) =>
// special case of an hybrid union containing case objects, those should return an object instead of a string
val obj = mergeFields(currentField, v).collectFirst {
val obj = filterFields(currentField, v).collectFirst {
case f: Field if f.name == "__typename" =>
ObjectValue(List(f.alias.getOrElse(f.name) -> StringValue(v)))
case f: Field if f.name == "_" =>
@@ -71,8 +70,8 @@ object Executor {
Types.listOf(currentField.fieldType).fold(false)(_.isNullable)
)
case ObjectStep(objectName, fields) =>
val mergedFields = mergeFields(currentField, objectName)
val items = mergedFields.map {
val filteredFields = filterFields(currentField, objectName)
val items = filteredFields.map {
case f @ Field(name @ "__typename", _, _, alias, _, _, _, _, directives) =>
(alias.getOrElse(name), PureStep(StringValue(objectName)), fieldInfo(f, path, directives))
case f @ Field(name, _, _, alias, _, _, args, _, directives) =>
@@ -173,28 +172,8 @@ object Executor {
private[caliban] def fail(error: CalibanError): UIO[GraphQLResponse[CalibanError]] =
IO.succeed(GraphQLResponse(NullValue, List(error)))

private[caliban] def mergeFields(field: Field, typeName: String): List[Field] = {
// ugly mutable code but it's worth it for the speed ;)
val array = ArrayBuffer.empty[Field]
val map = collection.mutable.Map.empty[String, Int]

field.fields.foreach { field =>
if (field.condition.forall(_.contains(typeName))) {
val name = field.alias.getOrElse(field.name)
map.get(name) match {
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

map was never updated 🤦‍♂️

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

case None =>
// first time we see this field, add it to the array
array += field
case Some(index) =>
// field already existed, merge it
val f = array(index)
array(index) = f.copy(fields = f.fields ::: field.fields)
}
}
}

array.toList
}
private[caliban] def filterFields(field: Field, typeName: String): List[Field] =
field.fields.filter(_.condition.forall(_.contains(typeName)))

private def fieldInfo(field: Field, path: List[Either[String, Int]], fieldDirectives: List[Directive]): FieldInfo =
FieldInfo(field.alias.getOrElse(field.name), field, path, fieldDirectives)
71 changes: 53 additions & 18 deletions core/src/main/scala/caliban/execution/Field.scala
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
package caliban.execution

import caliban.{ InputValue, Value }
import caliban.Value.{ BooleanValue, NullValue }
import scala.collection.mutable.ArrayBuffer
import caliban.Value.BooleanValue
import caliban.introspection.adt.{ __DeprecatedArgs, __Type }
import caliban.parsing.SourceMapper
import caliban.parsing.adt.Definition.ExecutableDefinition.FragmentDefinition
import caliban.parsing.adt.Selection.{ Field => F, FragmentSpread, InlineFragment }
import caliban.parsing.adt.{ Directive, LocationInfo, Selection, VariableDefinition }
import caliban.schema.{ RootType, Types }
import caliban.{ InputValue, Value }

case class Field(
name: String,
fieldType: __Type,
parentType: Option[__Type],
alias: Option[String] = None,
fields: List[Field] = Nil,
condition: Option[List[String]] = None,
condition: Option[Set[String]] = None,
arguments: Map[String, InputValue] = Map(),
_locationInfo: () => LocationInfo = () => LocationInfo.origin,
directives: List[Directive] = List.empty
@@ -35,7 +36,33 @@ object Field {
rootType: RootType
): Field = {
def loop(selectionSet: List[Selection], fieldType: __Type): Field = {
val fieldList = List.newBuilder[Field]
val fieldList = ArrayBuffer.empty[Field]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we give this a size hint as well? I'd expect that to provide a minor improvement in 99% of cases where people don't duplicate their fields (and really only penalizes us in the pathological case).

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But we don't know the size without going through fragments recursively.

val map = collection.mutable.Map.empty[String, Int]
var fieldIndex = 0

def addField(f: Field): Unit = {
val name = f.alias.getOrElse(f.name)
map.get(name) match {
case None =>
// first time we see this field, add it to the array
fieldList += f
map.update(name, fieldIndex)
fieldIndex = fieldIndex + 1
case Some(index) =>
// field already existed, merge it
val existing = fieldList(index)
fieldList(index) = existing.copy(
fields = existing.fields ::: f.fields,
condition = (existing.condition, f.condition) match {
case (Some(v1), Some(v2)) => if (v1 == v2) existing.condition else Some(v1 ++ v2)
case (Some(_), None) => existing.condition
case (None, Some(_)) => f.condition
case (None, None) => None
}
)
}
}

val innerType = Types.innerType(fieldType)
selectionSet.foreach {
case F(alias, name, arguments, directives, selectionSet, index)
@@ -49,7 +76,8 @@ object Field {
val t = selected.fold(Types.string)(_.`type`()) // default only case where it's not found is __typename

val field = loop(selectionSet, t)
fieldList +=

addField(
Field(
name,
t,
@@ -61,32 +89,38 @@ object Field {
() => sourceMapper.getLocation(index),
directives ++ schemaDirectives
)
)
case FragmentSpread(name, directives) if checkDirectives(directives, variableValues) =>
fragments
.get(name)
.foreach { f =>
val t =
innerType.possibleTypes.flatMap(_.find(_.name.contains(f.typeCondition.name))).getOrElse(fieldType)
fieldList ++= loop(f.selectionSet, t).fields.map(field =>
if (field.condition.isDefined) field
else field.copy(condition = subtypeNames(f.typeCondition.name, rootType))
)
loop(f.selectionSet, t).fields
.map(field =>
if (field.condition.isDefined) field
else field.copy(condition = subtypeNames(f.typeCondition.name, rootType))
)
.foreach(addField)
}
case InlineFragment(typeCondition, directives, selectionSet) if checkDirectives(directives, variableValues) =>
val t = innerType.possibleTypes
.flatMap(_.find(_.name.exists(typeCondition.map(_.name).contains)))
.getOrElse(fieldType)
val field = loop(selectionSet, t)
typeCondition match {
case None => fieldList ++= field.fields
case None => if (field.fields.nonEmpty) fieldList ++= field.fields
case Some(typeName) =>
fieldList ++= field.fields.map(field =>
if (field.condition.isDefined) field else field.copy(condition = subtypeNames(typeName.name, rootType))
)
field.fields
.map(field =>
if (field.condition.isDefined) field
else field.copy(condition = subtypeNames(typeName.name, rootType))
)
.foreach(addField)
}
case _ =>
}
Field("", fieldType, None, fields = fieldList.result())
Field("", fieldType, None, fields = fieldList.toList)
}

loop(selectionSet, fieldType).copy(directives = directives)
@@ -114,13 +148,14 @@ object Field {
arguments.flatMap { case (k, v) => resolveVariable(v).map(k -> _) }
}

private def subtypeNames(typeName: String, rootType: RootType): Option[List[String]] =
private def subtypeNames(typeName: String, rootType: RootType): Option[Set[String]] =
rootType.types
.get(typeName)
.map(t =>
typeName ::
t.possibleTypes
.fold(List.empty[String])(_.flatMap(_.name.map(subtypeNames(_, rootType).getOrElse(Nil))).flatten)
t.possibleTypes
.fold(Set.empty[String])(
_.map(_.name.map(subtypeNames(_, rootType).getOrElse(Set.empty))).toSet.flatten.flatten
) + typeName
)

private def checkDirectives(directives: List[Directive], variableValues: Map[String, InputValue]): Boolean =
54 changes: 33 additions & 21 deletions core/src/main/scala/caliban/validation/Validator.scala
Original file line number Diff line number Diff line change
@@ -161,23 +161,25 @@ object Validator {
private def collectVariablesUsed(context: Context, selectionSet: List[Selection]): Set[String] = {
def collectValues(selectionSet: List[Selection]): List[InputValue] = {
// ugly mutable code but it's worth it for the speed ;)
val inputValues = List.newBuilder[InputValue]
val inputValues = List.newBuilder[InputValue]
def add(list: Iterable[InputValue]) = if (list.nonEmpty) inputValues ++= list

selectionSet.foreach {
case FragmentSpread(name, directives) =>
directives.foreach(inputValues ++= _.arguments.values)
directives.foreach(d => add(d.arguments.values))
context.fragments
.get(name)
.foreach { f =>
f.directives.foreach(inputValues ++= _.arguments.values)
inputValues ++= collectValues(f.selectionSet)
f.directives.foreach(d => add(d.arguments.values))
add(collectValues(f.selectionSet))
}
case Field(_, _, arguments, directives, selectionSet, _) =>
inputValues ++= arguments.values
directives.foreach(inputValues ++= _.arguments.values)
inputValues ++= collectValues(selectionSet)
add(arguments.values)
directives.foreach(d => add(d.arguments.values))
add(collectValues(selectionSet))
case InlineFragment(_, directives, selectionSet) =>
directives.foreach(inputValues ++= _.arguments.values)
inputValues ++= collectValues(selectionSet)
directives.foreach(d => add(d.arguments.values))
add(collectValues(selectionSet))
}
inputValues.result()
}
@@ -197,7 +199,7 @@ object Validator {
private def collectSelectionSets(selectionSet: List[Selection]): List[Selection] = {
val sets = List.newBuilder[Selection]
def loop(selectionSet: List[Selection]): Unit = {
sets ++= selectionSet
if (selectionSet.nonEmpty) sets ++= selectionSet
selectionSet.foreach {
case f: Field => loop(f.selectionSet)
case f: InlineFragment => loop(f.selectionSet)
@@ -227,17 +229,27 @@ object Validator {

private def collectDirectives(
selectionSet: List[Selection]
): IO[ValidationError, List[(Directive, __DirectiveLocation)]] =
IO.foreach(selectionSet) {
case FragmentSpread(_, directives) =>
checkDirectivesUniqueness(directives).as(directives.map((_, __DirectiveLocation.FRAGMENT_SPREAD)))
case Field(_, _, _, directives, selectionSet, _) =>
checkDirectivesUniqueness(directives) *>
collectDirectives(selectionSet).map(directives.map((_, __DirectiveLocation.FIELD)) ++ _)
case InlineFragment(_, directives, selectionSet) =>
checkDirectivesUniqueness(directives) *>
collectDirectives(selectionSet).map(directives.map((_, __DirectiveLocation.INLINE_FRAGMENT)) ++ _)
}.map(_.flatten)
): IO[ValidationError, List[(Directive, __DirectiveLocation)]] = {
val builder = List.newBuilder[List[(Directive, __DirectiveLocation)]]

def loop(selectionSet: List[Selection]): Unit =
selectionSet.foreach {
case FragmentSpread(_, directives) =>
if (directives.nonEmpty)
builder += directives.map((_, __DirectiveLocation.FRAGMENT_SPREAD))
case Field(_, _, _, directives, selectionSet, _) =>
if (directives.nonEmpty)
builder += directives.map((_, __DirectiveLocation.FIELD))
loop(selectionSet)
case InlineFragment(_, directives, selectionSet) =>
if (directives.nonEmpty)
builder += directives.map((_, __DirectiveLocation.INLINE_FRAGMENT))
loop(selectionSet)
}
loop(selectionSet)
val directiveLists = builder.result()
IO.foreach_(directiveLists)(list => checkDirectivesUniqueness(list.map(_._1))).as(directiveLists.flatten)
}

private def checkDirectivesUniqueness(directives: List[Directive]): IO[ValidationError, Unit] =
IO.whenCase(directives.groupBy(_.name).find { case (_, v) => v.length > 1 }) { case Some((name, _)) =>
20 changes: 11 additions & 9 deletions core/src/test/scala/caliban/execution/ExecutionSpec.scala
Original file line number Diff line number Diff line change
@@ -357,8 +357,8 @@ object ExecutionSpec extends DefaultRunnableSpec {
for {
interpreter <- api.interpreter
result <- interpreter.mapError(_ => "my custom error").execute(query)
} yield assert(result.errors)(equalTo(List("my custom error"))) &&
assert(result.asJson.noSpaces)(equalTo("""{"data":null,"errors":[{"message":"my custom error"}]}"""))
} yield assertTrue(result.errors == List("my custom error")) &&
assertTrue(result.asJson.noSpaces == """{"data":null,"errors":[{"message":"my custom error"}]}""")
},
testM("customErrorEffectSchema") {
import io.circe.syntax._
@@ -689,11 +689,10 @@ object ExecutionSpec extends DefaultRunnableSpec {
interpreter
.flatMap(_.execute(query))
.map(result =>
assert(result.data.toString)(
equalTo("""{"user1":{"name":"user","friends":["friend"]},"user2":null}""")
) &&
assert(result.errors.collectFirst { case e: ExecutionError => e }.map(_.path))(
isSome(equalTo(List(Left("user2"), Left("friends"))))
assertTrue(result.data.toString == """{"user1":{"name":"user","friends":["friend"]},"user2":null}""") &&
assertTrue(
result.errors.collectFirst { case e: ExecutionError => e }.map(_.path).get ==
List(Left("user2"), Left("friends"))
)
)
},
@@ -715,7 +714,7 @@ object ExecutionSpec extends DefaultRunnableSpec {
|}""".stripMargin
interpreter
.flatMap(_.execute(query))
.map(result => assert(result.data.toString)(equalTo("""{"user":null}""")))
.map(result => assertTrue(result.data.toString == """{"user":null}"""))
},
testM("failure in ArgBuilder, non optional field") {
case class UserArgs(id: Int)
@@ -733,7 +732,7 @@ object ExecutionSpec extends DefaultRunnableSpec {
|}""".stripMargin
interpreter
.flatMap(_.execute(query))
.map(result => assert(result.data.toString)(equalTo("""null""")))
.map(result => assertTrue(result.data.toString == """null"""))
},
testM("die inside a nullable list") {
case class Queries(test: List[Task[String]])
@@ -871,6 +870,9 @@ object ExecutionSpec extends DefaultRunnableSpec {
| ... on Human {
| height
| }
| ... on Human {
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this intentionally to make sure the result has the field only once.

| height
| }
| ... on Droid {
| primaryFunction
| }