Skip to content

Commit

Permalink
Fix fields merging and improve performance (#1199)
Browse files Browse the repository at this point in the history
  • Loading branch information
ghostdogpr authored Dec 12, 2021
1 parent 868eba0 commit 6802417
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 74 deletions.
31 changes: 5 additions & 26 deletions core/src/main/scala/caliban/execution/Executor.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package caliban.execution

import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer
import caliban.CalibanError.ExecutionError
import caliban.ResponseValue._
import caliban.Value._
Expand Down Expand Up @@ -52,7 +51,7 @@ object Executor {
value match {
case EnumValue(v) =>
// special case of an hybrid union containing case objects, those should return an object instead of a string
val obj = mergeFields(currentField, v).collectFirst {
val obj = filterFields(currentField, v).collectFirst {
case f: Field if f.name == "__typename" =>
ObjectValue(List(f.alias.getOrElse(f.name) -> StringValue(v)))
case f: Field if f.name == "_" =>
Expand All @@ -71,8 +70,8 @@ object Executor {
Types.listOf(currentField.fieldType).fold(false)(_.isNullable)
)
case ObjectStep(objectName, fields) =>
val mergedFields = mergeFields(currentField, objectName)
val items = mergedFields.map {
val filteredFields = filterFields(currentField, objectName)
val items = filteredFields.map {
case f @ Field(name @ "__typename", _, _, alias, _, _, _, _, directives) =>
(alias.getOrElse(name), PureStep(StringValue(objectName)), fieldInfo(f, path, directives))
case f @ Field(name, _, _, alias, _, _, args, _, directives) =>
Expand Down Expand Up @@ -173,28 +172,8 @@ object Executor {
private[caliban] def fail(error: CalibanError): UIO[GraphQLResponse[CalibanError]] =
IO.succeed(GraphQLResponse(NullValue, List(error)))

private[caliban] def mergeFields(field: Field, typeName: String): List[Field] = {
// ugly mutable code but it's worth it for the speed ;)
val array = ArrayBuffer.empty[Field]
val map = collection.mutable.Map.empty[String, Int]

field.fields.foreach { field =>
if (field.condition.forall(_.contains(typeName))) {
val name = field.alias.getOrElse(field.name)
map.get(name) match {
case None =>
// first time we see this field, add it to the array
array += field
case Some(index) =>
// field already existed, merge it
val f = array(index)
array(index) = f.copy(fields = f.fields ::: field.fields)
}
}
}

array.toList
}
private[caliban] def filterFields(field: Field, typeName: String): List[Field] =
field.fields.filter(_.condition.forall(_.contains(typeName)))

private def fieldInfo(field: Field, path: List[Either[String, Int]], fieldDirectives: List[Directive]): FieldInfo =
FieldInfo(field.alias.getOrElse(field.name), field, path, fieldDirectives)
Expand Down
71 changes: 53 additions & 18 deletions core/src/main/scala/caliban/execution/Field.scala
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
package caliban.execution

import caliban.{ InputValue, Value }
import caliban.Value.{ BooleanValue, NullValue }
import scala.collection.mutable.ArrayBuffer
import caliban.Value.BooleanValue
import caliban.introspection.adt.{ __DeprecatedArgs, __Type }
import caliban.parsing.SourceMapper
import caliban.parsing.adt.Definition.ExecutableDefinition.FragmentDefinition
import caliban.parsing.adt.Selection.{ Field => F, FragmentSpread, InlineFragment }
import caliban.parsing.adt.{ Directive, LocationInfo, Selection, VariableDefinition }
import caliban.schema.{ RootType, Types }
import caliban.{ InputValue, Value }

case class Field(
name: String,
fieldType: __Type,
parentType: Option[__Type],
alias: Option[String] = None,
fields: List[Field] = Nil,
condition: Option[List[String]] = None,
condition: Option[Set[String]] = None,
arguments: Map[String, InputValue] = Map(),
_locationInfo: () => LocationInfo = () => LocationInfo.origin,
directives: List[Directive] = List.empty
Expand All @@ -35,7 +36,33 @@ object Field {
rootType: RootType
): Field = {
def loop(selectionSet: List[Selection], fieldType: __Type): Field = {
val fieldList = List.newBuilder[Field]
val fieldList = ArrayBuffer.empty[Field]
val map = collection.mutable.Map.empty[String, Int]
var fieldIndex = 0

def addField(f: Field): Unit = {
val name = f.alias.getOrElse(f.name)
map.get(name) match {
case None =>
// first time we see this field, add it to the array
fieldList += f
map.update(name, fieldIndex)
fieldIndex = fieldIndex + 1
case Some(index) =>
// field already existed, merge it
val existing = fieldList(index)
fieldList(index) = existing.copy(
fields = existing.fields ::: f.fields,
condition = (existing.condition, f.condition) match {
case (Some(v1), Some(v2)) => if (v1 == v2) existing.condition else Some(v1 ++ v2)
case (Some(_), None) => existing.condition
case (None, Some(_)) => f.condition
case (None, None) => None
}
)
}
}

val innerType = Types.innerType(fieldType)
selectionSet.foreach {
case F(alias, name, arguments, directives, selectionSet, index)
Expand All @@ -49,7 +76,8 @@ object Field {
val t = selected.fold(Types.string)(_.`type`()) // default only case where it's not found is __typename

val field = loop(selectionSet, t)
fieldList +=

addField(
Field(
name,
t,
Expand All @@ -61,32 +89,38 @@ object Field {
() => sourceMapper.getLocation(index),
directives ++ schemaDirectives
)
)
case FragmentSpread(name, directives) if checkDirectives(directives, variableValues) =>
fragments
.get(name)
.foreach { f =>
val t =
innerType.possibleTypes.flatMap(_.find(_.name.contains(f.typeCondition.name))).getOrElse(fieldType)
fieldList ++= loop(f.selectionSet, t).fields.map(field =>
if (field.condition.isDefined) field
else field.copy(condition = subtypeNames(f.typeCondition.name, rootType))
)
loop(f.selectionSet, t).fields
.map(field =>
if (field.condition.isDefined) field
else field.copy(condition = subtypeNames(f.typeCondition.name, rootType))
)
.foreach(addField)
}
case InlineFragment(typeCondition, directives, selectionSet) if checkDirectives(directives, variableValues) =>
val t = innerType.possibleTypes
.flatMap(_.find(_.name.exists(typeCondition.map(_.name).contains)))
.getOrElse(fieldType)
val field = loop(selectionSet, t)
typeCondition match {
case None => fieldList ++= field.fields
case None => if (field.fields.nonEmpty) fieldList ++= field.fields
case Some(typeName) =>
fieldList ++= field.fields.map(field =>
if (field.condition.isDefined) field else field.copy(condition = subtypeNames(typeName.name, rootType))
)
field.fields
.map(field =>
if (field.condition.isDefined) field
else field.copy(condition = subtypeNames(typeName.name, rootType))
)
.foreach(addField)
}
case _ =>
}
Field("", fieldType, None, fields = fieldList.result())
Field("", fieldType, None, fields = fieldList.toList)
}

loop(selectionSet, fieldType).copy(directives = directives)
Expand Down Expand Up @@ -114,13 +148,14 @@ object Field {
arguments.flatMap { case (k, v) => resolveVariable(v).map(k -> _) }
}

private def subtypeNames(typeName: String, rootType: RootType): Option[List[String]] =
private def subtypeNames(typeName: String, rootType: RootType): Option[Set[String]] =
rootType.types
.get(typeName)
.map(t =>
typeName ::
t.possibleTypes
.fold(List.empty[String])(_.flatMap(_.name.map(subtypeNames(_, rootType).getOrElse(Nil))).flatten)
t.possibleTypes
.fold(Set.empty[String])(
_.map(_.name.map(subtypeNames(_, rootType).getOrElse(Set.empty))).toSet.flatten.flatten
) + typeName
)

private def checkDirectives(directives: List[Directive], variableValues: Map[String, InputValue]): Boolean =
Expand Down
54 changes: 33 additions & 21 deletions core/src/main/scala/caliban/validation/Validator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -161,23 +161,25 @@ object Validator {
private def collectVariablesUsed(context: Context, selectionSet: List[Selection]): Set[String] = {
def collectValues(selectionSet: List[Selection]): List[InputValue] = {
// ugly mutable code but it's worth it for the speed ;)
val inputValues = List.newBuilder[InputValue]
val inputValues = List.newBuilder[InputValue]
def add(list: Iterable[InputValue]) = if (list.nonEmpty) inputValues ++= list

selectionSet.foreach {
case FragmentSpread(name, directives) =>
directives.foreach(inputValues ++= _.arguments.values)
directives.foreach(d => add(d.arguments.values))
context.fragments
.get(name)
.foreach { f =>
f.directives.foreach(inputValues ++= _.arguments.values)
inputValues ++= collectValues(f.selectionSet)
f.directives.foreach(d => add(d.arguments.values))
add(collectValues(f.selectionSet))
}
case Field(_, _, arguments, directives, selectionSet, _) =>
inputValues ++= arguments.values
directives.foreach(inputValues ++= _.arguments.values)
inputValues ++= collectValues(selectionSet)
add(arguments.values)
directives.foreach(d => add(d.arguments.values))
add(collectValues(selectionSet))
case InlineFragment(_, directives, selectionSet) =>
directives.foreach(inputValues ++= _.arguments.values)
inputValues ++= collectValues(selectionSet)
directives.foreach(d => add(d.arguments.values))
add(collectValues(selectionSet))
}
inputValues.result()
}
Expand All @@ -197,7 +199,7 @@ object Validator {
private def collectSelectionSets(selectionSet: List[Selection]): List[Selection] = {
val sets = List.newBuilder[Selection]
def loop(selectionSet: List[Selection]): Unit = {
sets ++= selectionSet
if (selectionSet.nonEmpty) sets ++= selectionSet
selectionSet.foreach {
case f: Field => loop(f.selectionSet)
case f: InlineFragment => loop(f.selectionSet)
Expand Down Expand Up @@ -227,17 +229,27 @@ object Validator {

private def collectDirectives(
selectionSet: List[Selection]
): IO[ValidationError, List[(Directive, __DirectiveLocation)]] =
IO.foreach(selectionSet) {
case FragmentSpread(_, directives) =>
checkDirectivesUniqueness(directives).as(directives.map((_, __DirectiveLocation.FRAGMENT_SPREAD)))
case Field(_, _, _, directives, selectionSet, _) =>
checkDirectivesUniqueness(directives) *>
collectDirectives(selectionSet).map(directives.map((_, __DirectiveLocation.FIELD)) ++ _)
case InlineFragment(_, directives, selectionSet) =>
checkDirectivesUniqueness(directives) *>
collectDirectives(selectionSet).map(directives.map((_, __DirectiveLocation.INLINE_FRAGMENT)) ++ _)
}.map(_.flatten)
): IO[ValidationError, List[(Directive, __DirectiveLocation)]] = {
val builder = List.newBuilder[List[(Directive, __DirectiveLocation)]]

def loop(selectionSet: List[Selection]): Unit =
selectionSet.foreach {
case FragmentSpread(_, directives) =>
if (directives.nonEmpty)
builder += directives.map((_, __DirectiveLocation.FRAGMENT_SPREAD))
case Field(_, _, _, directives, selectionSet, _) =>
if (directives.nonEmpty)
builder += directives.map((_, __DirectiveLocation.FIELD))
loop(selectionSet)
case InlineFragment(_, directives, selectionSet) =>
if (directives.nonEmpty)
builder += directives.map((_, __DirectiveLocation.INLINE_FRAGMENT))
loop(selectionSet)
}
loop(selectionSet)
val directiveLists = builder.result()
IO.foreach_(directiveLists)(list => checkDirectivesUniqueness(list.map(_._1))).as(directiveLists.flatten)
}

private def checkDirectivesUniqueness(directives: List[Directive]): IO[ValidationError, Unit] =
IO.whenCase(directives.groupBy(_.name).find { case (_, v) => v.length > 1 }) { case Some((name, _)) =>
Expand Down
20 changes: 11 additions & 9 deletions core/src/test/scala/caliban/execution/ExecutionSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,8 @@ object ExecutionSpec extends DefaultRunnableSpec {
for {
interpreter <- api.interpreter
result <- interpreter.mapError(_ => "my custom error").execute(query)
} yield assert(result.errors)(equalTo(List("my custom error"))) &&
assert(result.asJson.noSpaces)(equalTo("""{"data":null,"errors":[{"message":"my custom error"}]}"""))
} yield assertTrue(result.errors == List("my custom error")) &&
assertTrue(result.asJson.noSpaces == """{"data":null,"errors":[{"message":"my custom error"}]}""")
},
testM("customErrorEffectSchema") {
import io.circe.syntax._
Expand Down Expand Up @@ -689,11 +689,10 @@ object ExecutionSpec extends DefaultRunnableSpec {
interpreter
.flatMap(_.execute(query))
.map(result =>
assert(result.data.toString)(
equalTo("""{"user1":{"name":"user","friends":["friend"]},"user2":null}""")
) &&
assert(result.errors.collectFirst { case e: ExecutionError => e }.map(_.path))(
isSome(equalTo(List(Left("user2"), Left("friends"))))
assertTrue(result.data.toString == """{"user1":{"name":"user","friends":["friend"]},"user2":null}""") &&
assertTrue(
result.errors.collectFirst { case e: ExecutionError => e }.map(_.path).get ==
List(Left("user2"), Left("friends"))
)
)
},
Expand All @@ -715,7 +714,7 @@ object ExecutionSpec extends DefaultRunnableSpec {
|}""".stripMargin
interpreter
.flatMap(_.execute(query))
.map(result => assert(result.data.toString)(equalTo("""{"user":null}""")))
.map(result => assertTrue(result.data.toString == """{"user":null}"""))
},
testM("failure in ArgBuilder, non optional field") {
case class UserArgs(id: Int)
Expand All @@ -733,7 +732,7 @@ object ExecutionSpec extends DefaultRunnableSpec {
|}""".stripMargin
interpreter
.flatMap(_.execute(query))
.map(result => assert(result.data.toString)(equalTo("""null""")))
.map(result => assertTrue(result.data.toString == """null"""))
},
testM("die inside a nullable list") {
case class Queries(test: List[Task[String]])
Expand Down Expand Up @@ -871,6 +870,9 @@ object ExecutionSpec extends DefaultRunnableSpec {
| ... on Human {
| height
| }
| ... on Human {
| height
| }
| ... on Droid {
| primaryFunction
| }
Expand Down

0 comments on commit 6802417

Please sign in to comment.