From 68024178a160060bbcc58efd8f13d434bcbda2e0 Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Sun, 12 Dec 2021 09:10:16 +0900 Subject: [PATCH] Fix fields merging and improve performance (#1199) --- .../scala/caliban/execution/Executor.scala | 31 ++------ .../main/scala/caliban/execution/Field.scala | 71 ++++++++++++++----- .../scala/caliban/validation/Validator.scala | 54 ++++++++------ .../caliban/execution/ExecutionSpec.scala | 20 +++--- 4 files changed, 102 insertions(+), 74 deletions(-) diff --git a/core/src/main/scala/caliban/execution/Executor.scala b/core/src/main/scala/caliban/execution/Executor.scala index bcaf04b97..6e6992d68 100644 --- a/core/src/main/scala/caliban/execution/Executor.scala +++ b/core/src/main/scala/caliban/execution/Executor.scala @@ -1,7 +1,6 @@ package caliban.execution import scala.annotation.tailrec -import scala.collection.mutable.ArrayBuffer import caliban.CalibanError.ExecutionError import caliban.ResponseValue._ import caliban.Value._ @@ -52,7 +51,7 @@ object Executor { value match { case EnumValue(v) => // special case of an hybrid union containing case objects, those should return an object instead of a string - val obj = mergeFields(currentField, v).collectFirst { + val obj = filterFields(currentField, v).collectFirst { case f: Field if f.name == "__typename" => ObjectValue(List(f.alias.getOrElse(f.name) -> StringValue(v))) case f: Field if f.name == "_" => @@ -71,8 +70,8 @@ object Executor { Types.listOf(currentField.fieldType).fold(false)(_.isNullable) ) case ObjectStep(objectName, fields) => - val mergedFields = mergeFields(currentField, objectName) - val items = mergedFields.map { + val filteredFields = filterFields(currentField, objectName) + val items = filteredFields.map { case f @ Field(name @ "__typename", _, _, alias, _, _, _, _, directives) => (alias.getOrElse(name), PureStep(StringValue(objectName)), fieldInfo(f, path, directives)) case f @ Field(name, _, _, alias, _, _, args, _, directives) => @@ -173,28 +172,8 @@ object Executor { private[caliban] def fail(error: CalibanError): UIO[GraphQLResponse[CalibanError]] = IO.succeed(GraphQLResponse(NullValue, List(error))) - private[caliban] def mergeFields(field: Field, typeName: String): List[Field] = { - // ugly mutable code but it's worth it for the speed ;) - val array = ArrayBuffer.empty[Field] - val map = collection.mutable.Map.empty[String, Int] - - field.fields.foreach { field => - if (field.condition.forall(_.contains(typeName))) { - val name = field.alias.getOrElse(field.name) - map.get(name) match { - case None => - // first time we see this field, add it to the array - array += field - case Some(index) => - // field already existed, merge it - val f = array(index) - array(index) = f.copy(fields = f.fields ::: field.fields) - } - } - } - - array.toList - } + private[caliban] def filterFields(field: Field, typeName: String): List[Field] = + field.fields.filter(_.condition.forall(_.contains(typeName))) private def fieldInfo(field: Field, path: List[Either[String, Int]], fieldDirectives: List[Directive]): FieldInfo = FieldInfo(field.alias.getOrElse(field.name), field, path, fieldDirectives) diff --git a/core/src/main/scala/caliban/execution/Field.scala b/core/src/main/scala/caliban/execution/Field.scala index 88320c8c5..73b445083 100644 --- a/core/src/main/scala/caliban/execution/Field.scala +++ b/core/src/main/scala/caliban/execution/Field.scala @@ -1,13 +1,14 @@ package caliban.execution -import caliban.{ InputValue, Value } -import caliban.Value.{ BooleanValue, NullValue } +import scala.collection.mutable.ArrayBuffer +import caliban.Value.BooleanValue import caliban.introspection.adt.{ __DeprecatedArgs, __Type } import caliban.parsing.SourceMapper import caliban.parsing.adt.Definition.ExecutableDefinition.FragmentDefinition import caliban.parsing.adt.Selection.{ Field => F, FragmentSpread, InlineFragment } import caliban.parsing.adt.{ Directive, LocationInfo, Selection, VariableDefinition } import caliban.schema.{ RootType, Types } +import caliban.{ InputValue, Value } case class Field( name: String, @@ -15,7 +16,7 @@ case class Field( parentType: Option[__Type], alias: Option[String] = None, fields: List[Field] = Nil, - condition: Option[List[String]] = None, + condition: Option[Set[String]] = None, arguments: Map[String, InputValue] = Map(), _locationInfo: () => LocationInfo = () => LocationInfo.origin, directives: List[Directive] = List.empty @@ -35,7 +36,33 @@ object Field { rootType: RootType ): Field = { def loop(selectionSet: List[Selection], fieldType: __Type): Field = { - val fieldList = List.newBuilder[Field] + val fieldList = ArrayBuffer.empty[Field] + val map = collection.mutable.Map.empty[String, Int] + var fieldIndex = 0 + + def addField(f: Field): Unit = { + val name = f.alias.getOrElse(f.name) + map.get(name) match { + case None => + // first time we see this field, add it to the array + fieldList += f + map.update(name, fieldIndex) + fieldIndex = fieldIndex + 1 + case Some(index) => + // field already existed, merge it + val existing = fieldList(index) + fieldList(index) = existing.copy( + fields = existing.fields ::: f.fields, + condition = (existing.condition, f.condition) match { + case (Some(v1), Some(v2)) => if (v1 == v2) existing.condition else Some(v1 ++ v2) + case (Some(_), None) => existing.condition + case (None, Some(_)) => f.condition + case (None, None) => None + } + ) + } + } + val innerType = Types.innerType(fieldType) selectionSet.foreach { case F(alias, name, arguments, directives, selectionSet, index) @@ -49,7 +76,8 @@ object Field { val t = selected.fold(Types.string)(_.`type`()) // default only case where it's not found is __typename val field = loop(selectionSet, t) - fieldList += + + addField( Field( name, t, @@ -61,16 +89,19 @@ object Field { () => sourceMapper.getLocation(index), directives ++ schemaDirectives ) + ) case FragmentSpread(name, directives) if checkDirectives(directives, variableValues) => fragments .get(name) .foreach { f => val t = innerType.possibleTypes.flatMap(_.find(_.name.contains(f.typeCondition.name))).getOrElse(fieldType) - fieldList ++= loop(f.selectionSet, t).fields.map(field => - if (field.condition.isDefined) field - else field.copy(condition = subtypeNames(f.typeCondition.name, rootType)) - ) + loop(f.selectionSet, t).fields + .map(field => + if (field.condition.isDefined) field + else field.copy(condition = subtypeNames(f.typeCondition.name, rootType)) + ) + .foreach(addField) } case InlineFragment(typeCondition, directives, selectionSet) if checkDirectives(directives, variableValues) => val t = innerType.possibleTypes @@ -78,15 +109,18 @@ object Field { .getOrElse(fieldType) val field = loop(selectionSet, t) typeCondition match { - case None => fieldList ++= field.fields + case None => if (field.fields.nonEmpty) fieldList ++= field.fields case Some(typeName) => - fieldList ++= field.fields.map(field => - if (field.condition.isDefined) field else field.copy(condition = subtypeNames(typeName.name, rootType)) - ) + field.fields + .map(field => + if (field.condition.isDefined) field + else field.copy(condition = subtypeNames(typeName.name, rootType)) + ) + .foreach(addField) } case _ => } - Field("", fieldType, None, fields = fieldList.result()) + Field("", fieldType, None, fields = fieldList.toList) } loop(selectionSet, fieldType).copy(directives = directives) @@ -114,13 +148,14 @@ object Field { arguments.flatMap { case (k, v) => resolveVariable(v).map(k -> _) } } - private def subtypeNames(typeName: String, rootType: RootType): Option[List[String]] = + private def subtypeNames(typeName: String, rootType: RootType): Option[Set[String]] = rootType.types .get(typeName) .map(t => - typeName :: - t.possibleTypes - .fold(List.empty[String])(_.flatMap(_.name.map(subtypeNames(_, rootType).getOrElse(Nil))).flatten) + t.possibleTypes + .fold(Set.empty[String])( + _.map(_.name.map(subtypeNames(_, rootType).getOrElse(Set.empty))).toSet.flatten.flatten + ) + typeName ) private def checkDirectives(directives: List[Directive], variableValues: Map[String, InputValue]): Boolean = diff --git a/core/src/main/scala/caliban/validation/Validator.scala b/core/src/main/scala/caliban/validation/Validator.scala index 493df9bbe..df7178f4c 100644 --- a/core/src/main/scala/caliban/validation/Validator.scala +++ b/core/src/main/scala/caliban/validation/Validator.scala @@ -161,23 +161,25 @@ object Validator { private def collectVariablesUsed(context: Context, selectionSet: List[Selection]): Set[String] = { def collectValues(selectionSet: List[Selection]): List[InputValue] = { // ugly mutable code but it's worth it for the speed ;) - val inputValues = List.newBuilder[InputValue] + val inputValues = List.newBuilder[InputValue] + def add(list: Iterable[InputValue]) = if (list.nonEmpty) inputValues ++= list + selectionSet.foreach { case FragmentSpread(name, directives) => - directives.foreach(inputValues ++= _.arguments.values) + directives.foreach(d => add(d.arguments.values)) context.fragments .get(name) .foreach { f => - f.directives.foreach(inputValues ++= _.arguments.values) - inputValues ++= collectValues(f.selectionSet) + f.directives.foreach(d => add(d.arguments.values)) + add(collectValues(f.selectionSet)) } case Field(_, _, arguments, directives, selectionSet, _) => - inputValues ++= arguments.values - directives.foreach(inputValues ++= _.arguments.values) - inputValues ++= collectValues(selectionSet) + add(arguments.values) + directives.foreach(d => add(d.arguments.values)) + add(collectValues(selectionSet)) case InlineFragment(_, directives, selectionSet) => - directives.foreach(inputValues ++= _.arguments.values) - inputValues ++= collectValues(selectionSet) + directives.foreach(d => add(d.arguments.values)) + add(collectValues(selectionSet)) } inputValues.result() } @@ -197,7 +199,7 @@ object Validator { private def collectSelectionSets(selectionSet: List[Selection]): List[Selection] = { val sets = List.newBuilder[Selection] def loop(selectionSet: List[Selection]): Unit = { - sets ++= selectionSet + if (selectionSet.nonEmpty) sets ++= selectionSet selectionSet.foreach { case f: Field => loop(f.selectionSet) case f: InlineFragment => loop(f.selectionSet) @@ -227,17 +229,27 @@ object Validator { private def collectDirectives( selectionSet: List[Selection] - ): IO[ValidationError, List[(Directive, __DirectiveLocation)]] = - IO.foreach(selectionSet) { - case FragmentSpread(_, directives) => - checkDirectivesUniqueness(directives).as(directives.map((_, __DirectiveLocation.FRAGMENT_SPREAD))) - case Field(_, _, _, directives, selectionSet, _) => - checkDirectivesUniqueness(directives) *> - collectDirectives(selectionSet).map(directives.map((_, __DirectiveLocation.FIELD)) ++ _) - case InlineFragment(_, directives, selectionSet) => - checkDirectivesUniqueness(directives) *> - collectDirectives(selectionSet).map(directives.map((_, __DirectiveLocation.INLINE_FRAGMENT)) ++ _) - }.map(_.flatten) + ): IO[ValidationError, List[(Directive, __DirectiveLocation)]] = { + val builder = List.newBuilder[List[(Directive, __DirectiveLocation)]] + + def loop(selectionSet: List[Selection]): Unit = + selectionSet.foreach { + case FragmentSpread(_, directives) => + if (directives.nonEmpty) + builder += directives.map((_, __DirectiveLocation.FRAGMENT_SPREAD)) + case Field(_, _, _, directives, selectionSet, _) => + if (directives.nonEmpty) + builder += directives.map((_, __DirectiveLocation.FIELD)) + loop(selectionSet) + case InlineFragment(_, directives, selectionSet) => + if (directives.nonEmpty) + builder += directives.map((_, __DirectiveLocation.INLINE_FRAGMENT)) + loop(selectionSet) + } + loop(selectionSet) + val directiveLists = builder.result() + IO.foreach_(directiveLists)(list => checkDirectivesUniqueness(list.map(_._1))).as(directiveLists.flatten) + } private def checkDirectivesUniqueness(directives: List[Directive]): IO[ValidationError, Unit] = IO.whenCase(directives.groupBy(_.name).find { case (_, v) => v.length > 1 }) { case Some((name, _)) => diff --git a/core/src/test/scala/caliban/execution/ExecutionSpec.scala b/core/src/test/scala/caliban/execution/ExecutionSpec.scala index 58581e00a..6f531e865 100644 --- a/core/src/test/scala/caliban/execution/ExecutionSpec.scala +++ b/core/src/test/scala/caliban/execution/ExecutionSpec.scala @@ -357,8 +357,8 @@ object ExecutionSpec extends DefaultRunnableSpec { for { interpreter <- api.interpreter result <- interpreter.mapError(_ => "my custom error").execute(query) - } yield assert(result.errors)(equalTo(List("my custom error"))) && - assert(result.asJson.noSpaces)(equalTo("""{"data":null,"errors":[{"message":"my custom error"}]}""")) + } yield assertTrue(result.errors == List("my custom error")) && + assertTrue(result.asJson.noSpaces == """{"data":null,"errors":[{"message":"my custom error"}]}""") }, testM("customErrorEffectSchema") { import io.circe.syntax._ @@ -689,11 +689,10 @@ object ExecutionSpec extends DefaultRunnableSpec { interpreter .flatMap(_.execute(query)) .map(result => - assert(result.data.toString)( - equalTo("""{"user1":{"name":"user","friends":["friend"]},"user2":null}""") - ) && - assert(result.errors.collectFirst { case e: ExecutionError => e }.map(_.path))( - isSome(equalTo(List(Left("user2"), Left("friends")))) + assertTrue(result.data.toString == """{"user1":{"name":"user","friends":["friend"]},"user2":null}""") && + assertTrue( + result.errors.collectFirst { case e: ExecutionError => e }.map(_.path).get == + List(Left("user2"), Left("friends")) ) ) }, @@ -715,7 +714,7 @@ object ExecutionSpec extends DefaultRunnableSpec { |}""".stripMargin interpreter .flatMap(_.execute(query)) - .map(result => assert(result.data.toString)(equalTo("""{"user":null}"""))) + .map(result => assertTrue(result.data.toString == """{"user":null}""")) }, testM("failure in ArgBuilder, non optional field") { case class UserArgs(id: Int) @@ -733,7 +732,7 @@ object ExecutionSpec extends DefaultRunnableSpec { |}""".stripMargin interpreter .flatMap(_.execute(query)) - .map(result => assert(result.data.toString)(equalTo("""null"""))) + .map(result => assertTrue(result.data.toString == """null""")) }, testM("die inside a nullable list") { case class Queries(test: List[Task[String]]) @@ -871,6 +870,9 @@ object ExecutionSpec extends DefaultRunnableSpec { | ... on Human { | height | } + | ... on Human { + | height + | } | ... on Droid { | primaryFunction | }