From 0cd2855640eb399df731a98f3db963d5bb4d5415 Mon Sep 17 00:00:00 2001 From: kyri-petrou <67301607+kyri-petrou@users.noreply.github.com> Date: Fri, 15 Dec 2023 04:52:19 +0100 Subject: [PATCH] Optimize fragment validation (#2039) * Optimize object field execution & validation of fragments * Optimize Validator --- build.sbt | 13 +- .../scala/caliban/execution/Executor.scala | 32 +-- .../main/scala/caliban/execution/Field.scala | 69 +++--- .../caliban/introspection/adt/__Field.scala | 2 + .../caliban/introspection/adt/__Type.scala | 16 +- .../scala/caliban/parsing/adt/Selection.scala | 4 +- .../validation/FragmentValidator.scala | 26 +-- .../main/scala/caliban/validation/Utils.scala | 37 ++-- .../scala/caliban/validation/Validator.scala | 207 +++++++++--------- .../caliban/validation/ValueValidator.scala | 15 +- .../scala/caliban/validation/package.scala | 4 +- 11 files changed, 216 insertions(+), 209 deletions(-) diff --git a/build.sbt b/build.sbt index e279e5aa90..ee95cc523e 100644 --- a/build.sbt +++ b/build.sbt @@ -669,22 +669,13 @@ lazy val commonSettings = Def.settings( }) ) -lazy val enforceMimaCompatibility = true // Enable / disable failing CI on binary incompatibilities +lazy val enforceMimaCompatibility = false // Enable / disable failing CI on binary incompatibilities lazy val enableMimaSettingsJVM = Def.settings( mimaFailOnProblem := enforceMimaCompatibility, mimaPreviousArtifacts := previousStableVersion.value.map(organization.value %% moduleName.value % _).toSet, - mimaBinaryIssueFilters ++= Seq( - ProblemFilters.exclude[IncompatibleMethTypeProblem]("caliban.schema.Step#ObjectStep*"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("caliban.schema.Step#ObjectStep*"), - ProblemFilters.exclude[DirectMissingMethodProblem]("caliban.schema.Annotations*"), - ProblemFilters.exclude[MissingTypesProblem]("caliban.schema.Annotations*"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("caliban.schema.Annotations*"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("caliban.Quick*"), - ProblemFilters.exclude[DirectMissingMethodProblem]("caliban.Quick*"), - ProblemFilters.exclude[DirectMissingMethodProblem]("caliban.quick.package*") - ) + mimaBinaryIssueFilters ++= Seq() ) lazy val enableMimaSettingsJS = diff --git a/core/src/main/scala/caliban/execution/Executor.scala b/core/src/main/scala/caliban/execution/Executor.scala index 2eea4267d6..2739756e41 100644 --- a/core/src/main/scala/caliban/execution/Executor.scala +++ b/core/src/main/scala/caliban/execution/Executor.scala @@ -64,20 +64,26 @@ object Executor { ): ReducedStep[R] = { def reduceObjectStep(objectName: String, getFieldStep: String => Step[R]): ReducedStep[R] = { + def reduceField(f: Field): (String, ReducedStep[R], FieldInfo) = { + val field = + if (f.name == "__typename") PureStep(StringValue(objectName)) + else reduceStep(getFieldStep(f.name), f, f.arguments, Left(f.aliasedName) :: path) + (f.aliasedName, field, fieldInfo(f, path, f.directives)) + } + val filteredFields = mergeFields(currentField, objectName) - val (deferred, eager) = filteredFields.partitionMap { - case f @ Field("__typename", _, _, _, _, _, _, directives, _, _, _) => - Right((f.aliasedName, PureStep(StringValue(objectName)), fieldInfo(f, path, directives))) - case f @ Field(name, _, _, _, _, _, args, directives, _, _, fragment) => - val field = reduceStep(getFieldStep(name), f, args, Left(f.aliasedName) :: path) - val entry = (f.aliasedName, field, fieldInfo(f, path, directives)) - - fragment match { - // The defer spec provides some latitude on how we handle responses. Since it is more performant to return - // pure fields rather than spin up the defer machinery we return pure fields immediately to the caller. - case Some(IsDeferred(label)) if isDeferredEnabled && !field.isPure => Left((label, entry)) - case _ => Right(entry) + val (deferred, eager) = { + if (isDeferredEnabled) { + filteredFields.partitionMap { f => + val entry = reduceField(f) + f.fragment match { + // The defer spec provides some latitude on how we handle responses. Since it is more performant to return + // pure fields rather than spin up the defer machinery we return pure fields immediately to the caller. + case Some(IsDeferred(label)) if !entry._2.isPure => Left((label, entry)) + case _ => Right(entry) + } } + } else (Nil, filteredFields.map(reduceField)) } val eagerReduced = reduceObject(eager, wrapPureValues) @@ -303,7 +309,7 @@ object Executor { for { cache <- Cache.empty - reduced = reduceStep(plan, request.field, Map(), Nil) + reduced = reduceStep(plan, request.field, Map.empty, Nil) response <- runQuery(reduced, cache) } yield response } diff --git a/core/src/main/scala/caliban/execution/Field.scala b/core/src/main/scala/caliban/execution/Field.scala index a1a8c12363..ffd2bb18d5 100644 --- a/core/src/main/scala/caliban/execution/Field.scala +++ b/core/src/main/scala/caliban/execution/Field.scala @@ -121,7 +121,13 @@ object Field { val memoizedFragments = new mutable.HashMap[String, (List[Field], Option[String])]() val variableDefinitionsMap = variableDefinitions.map(v => v.name -> v).toMap - def loop(selectionSet: List[Selection], fieldType: __Type, fragment: Option[Fragment]): List[Field] = { + def loop( + selectionSet: List[Selection], + fieldType: __Type, + fragment: Option[Fragment], + targets: Option[Set[String]], + condition: Option[Set[String]] + ): List[Field] = { val map = new java.util.LinkedHashMap[(String, Option[String]), Field]() def addField(f: Field, condition: Option[String]): Unit = @@ -141,22 +147,22 @@ object Field { val t = selected.fold(Types.string)(_._type) // default only case where it's not found is __typename val fields = - if (selectionSet.nonEmpty) loop(selectionSet, t, None) + if (selectionSet.nonEmpty) loop(selectionSet, t, None, None, None) else Nil // Fragments apply on to the direct children of the fragment spread addField( - Field( + new Field( name, t, Some(innerType), alias, fields, - None, - resolveVariables(arguments, variableDefinitionsMap, variableValues), - resolvedDirectives, - None, - () => sourceMapper.getLocation(index), - fragment + targets = targets, + arguments = resolveVariables(arguments, variableDefinitionsMap, variableValues), + directives = resolvedDirectives, + _condition = condition, + _locationInfo = () => sourceMapper.getLocation(index), + fragment = fragment ), None ) @@ -165,17 +171,16 @@ object Field { val (fields, condition) = memoizedFragments.getOrElseUpdate( name, { val resolvedDirectives = directives.map(resolveDirectiveVariables(variableValues, variableDefinitionsMap)) - - val _fields = if (checkDirectives(resolvedDirectives)) { + val _fields = if (checkDirectives(resolvedDirectives)) { fragments.get(name).map { f => - val t = rootType.types.getOrElse(f.typeCondition.name, fieldType) - val _targets = Some(Set(f.typeCondition.name)) - val _condition = subtypeNames(f.typeCondition.name, rootType) - - loop(f.selectionSet, t, Some(Fragment(Some(name), resolvedDirectives))).map { field => - if (field._condition.isDefined) field - else field.copy(targets = _targets, _condition = _condition) - } -> Some(f.typeCondition.name) + val t = rootType.types.getOrElse(f.typeCondition.name, fieldType) + loop( + f.selectionSet, + t, + fragment = Some(Fragment(Some(name), resolvedDirectives)), + targets = Some(Set(f.typeCondition.name)), + condition = subtypeNames(f.typeCondition.name, rootType) + ) -> Some(f.typeCondition.name) } } else None _fields.getOrElse(Nil -> None) @@ -184,31 +189,25 @@ object Field { fields.foreach(addField(_, condition)) case InlineFragment(typeCondition, directives, selectionSet) => val resolvedDirectives = directives.map(resolveDirectiveVariables(variableValues, variableDefinitionsMap)) - if (checkDirectives(resolvedDirectives)) { - val t = innerType.possibleTypes + val t = innerType.possibleTypes .flatMap(_.find(_.name.exists(typeCondition.map(_.name).contains))) .orElse(typeCondition.flatMap(typeName => rootType.types.get(typeName.name))) .getOrElse(fieldType) - val fields = loop(selectionSet, t, Some(Fragment(None, resolvedDirectives))) - typeCondition match { - case None => fields.map(addField(_, None)) - case Some(typeName) => - val _targets = Some(Set(typeName.name)) - val _condition = subtypeNames(typeName.name, rootType) - fields.foreach { field => - val _field = - if (field._condition.isDefined) field - else field.copy(targets = _targets, _condition = _condition) - addField(_field, Some(typeName.name)) - } - } + val typeName = typeCondition.map(_.name) + loop( + selectionSet, + t, + fragment = Some(Fragment(None, resolvedDirectives)), + targets = typeName.map(Set(_)), + condition = typeName.flatMap(subtypeNames(_, rootType)) + ).foreach(addField(_, typeName)) } } map.values().asScala.toList } - val fields = loop(selectionSet, fieldType, None) + val fields = loop(selectionSet, fieldType, None, None, None) Field("", fieldType, None, fields = fields, directives = directives) } diff --git a/core/src/main/scala/caliban/introspection/adt/__Field.scala b/core/src/main/scala/caliban/introspection/adt/__Field.scala index b6d85d8668..63cbef7305 100644 --- a/core/src/main/scala/caliban/introspection/adt/__Field.scala +++ b/core/src/main/scala/caliban/introspection/adt/__Field.scala @@ -14,6 +14,8 @@ case class __Field( deprecationReason: Option[String] = None, @GQLExcluded directives: Option[List[Directive]] = None ) { + final override lazy val hashCode: Int = super.hashCode() + def toFieldDefinition: FieldDefinition = { val allDirectives = (if (isDeprecated) List( diff --git a/core/src/main/scala/caliban/introspection/adt/__Type.scala b/core/src/main/scala/caliban/introspection/adt/__Type.scala index 7e6c19d57f..0ef41bbfbb 100644 --- a/core/src/main/scala/caliban/introspection/adt/__Type.scala +++ b/core/src/main/scala/caliban/introspection/adt/__Type.scala @@ -22,6 +22,8 @@ case class __Type( @GQLExcluded directives: Option[List[Directive]] = None, @GQLExcluded origin: Option[String] = None ) { self => + final override lazy val hashCode: Int = super.hashCode() + def |+|(that: __Type): __Type = __Type( kind, (name ++ that.name).reduceOption((_, b) => b), @@ -130,8 +132,18 @@ case class __Type( lazy val allEnumValues: List[__EnumValue] = enumValues(__DeprecatedArgs(Some(true))).getOrElse(Nil) - private[caliban] lazy val allFieldsMap: Map[String, __Field] = - allFields.map(f => f.name -> f).toMap + private[caliban] lazy val allFieldsMap: collection.Map[String, __Field] = { + val map = collection.mutable.HashMap.empty[String, __Field] + allFields.foreach(f => map.update(f.name, f)) + map + } lazy val innerType: __Type = Types.innerType(this) + + private[caliban] lazy val possibleTypeNames: Set[String] = + kind match { + case __TypeKind.OBJECT => name.fold(Set.empty[String])(Set(_)) + case __TypeKind.INTERFACE | __TypeKind.UNION => possibleTypes.fold(Set.empty[String])(_.flatMap(_.name).toSet) + case _ => Set.empty + } } diff --git a/core/src/main/scala/caliban/parsing/adt/Selection.scala b/core/src/main/scala/caliban/parsing/adt/Selection.scala index 1a1cf0ac0d..b994a21159 100644 --- a/core/src/main/scala/caliban/parsing/adt/Selection.scala +++ b/core/src/main/scala/caliban/parsing/adt/Selection.scala @@ -3,7 +3,9 @@ package caliban.parsing.adt import caliban.InputValue import caliban.parsing.adt.Type.NamedType -sealed trait Selection +sealed trait Selection { + final override lazy val hashCode: Int = super.hashCode() +} object Selection { diff --git a/core/src/main/scala/caliban/validation/FragmentValidator.scala b/core/src/main/scala/caliban/validation/FragmentValidator.scala index 929cab6c79..f0c1eb4c7c 100644 --- a/core/src/main/scala/caliban/validation/FragmentValidator.scala +++ b/core/src/main/scala/caliban/validation/FragmentValidator.scala @@ -4,12 +4,12 @@ import caliban.CalibanError.ValidationError import caliban.introspection.adt._ import caliban.parsing.adt.Selection import caliban.validation.Utils._ -import caliban.validation.Utils.syntax._ import zio.Chunk -import zio.prelude.EReader +import zio.prelude._ import zio.prelude.fx.ZPure import scala.collection.mutable +import scala.util.hashing.MurmurHash3 object FragmentValidator { def findConflictsWithinSelectionSet( @@ -24,13 +24,13 @@ object FragmentValidator { val groupsCache = mutable.Map.empty[Int, Chunk[Set[SelectedField]]] def sameResponseShapeByName(set: Iterable[Selection]): Chunk[String] = { - val keyHash = set.hashCode() + val keyHash = MurmurHash3.unorderedHash(set) shapeCache.get(keyHash) match { case Some(value) => value case None => val fields = FieldMap(context, parentType, set) val res = Chunk.fromIterable(fields.flatMap { case (name, values) => - cross(values).flatMap { case (f1, f2) => + cross(values, includeIdentity = true).flatMap { case (f1, f2) => if (doTypesConflict(f1.fieldDef._type, f2.fieldDef._type)) { Chunk( s"$name has conflicting types: ${f1.parentType.name.getOrElse("")}.${f1.fieldDef.name} and ${f2.parentType.name @@ -46,7 +46,7 @@ object FragmentValidator { } def sameForCommonParentsByName(set: Iterable[Selection]): Chunk[String] = { - val keyHash = set.hashCode() + val keyHash = MurmurHash3.unorderedHash(set) parentsCache.get(keyHash) match { case Some(value) => value case None => @@ -81,14 +81,14 @@ object FragmentValidator { false def requireSameNameAndArguments(fields: Set[SelectedField]) = - cross(fields).flatMap { case (f1, f2) => + cross(fields, includeIdentity = false).flatMap { case (f1, f2) => if (f1.fieldDef.name != f2.fieldDef.name) { - List( + Some( s"${f1.parentType.name.getOrElse("")}.${f1.fieldDef.name} and ${f2.parentType.name.getOrElse("")}.${f2.fieldDef.name} are different fields." ) } else if (f1.selection.arguments != f2.selection.arguments) - List(s"${f1.fieldDef.name} and ${f2.fieldDef.name} have different arguments") - else List() + Some(s"${f1.fieldDef.name} and ${f2.fieldDef.name} have different arguments") + else None } def groupByCommonParents(fields: Set[SelectedField]): Chunk[Set[SelectedField]] = { @@ -109,13 +109,7 @@ object FragmentValidator { _, _ ) if isConcrete(field.parentType) => - concreteGroups.get(name) match { - case Some(v) => v += field - case None => - val sb = Set.newBuilder ++= abstractGroup - sb += field - concreteGroups.update(name, sb) - } + concreteGroups.getOrElseUpdate(name, Set.newBuilder ++= abstractGroup) += field case _ => () } diff --git a/core/src/main/scala/caliban/validation/Utils.scala b/core/src/main/scala/caliban/validation/Utils.scala index 5caeab4aef..51949a363d 100644 --- a/core/src/main/scala/caliban/validation/Utils.scala +++ b/core/src/main/scala/caliban/validation/Utils.scala @@ -5,6 +5,8 @@ import caliban.introspection.adt.__TypeKind._ import caliban.parsing.adt.Type.NamedType import zio.Chunk +import scala.collection.compat._ + object Utils { def isObjectType(t: __Type): Boolean = t.kind match { @@ -56,22 +58,25 @@ object Utils { def getType(t: NamedType, context: Context): Option[__Type] = context.rootType.types.get(t.name) - def cross[A](a: Iterable[A]): Chunk[(A, A)] = - Chunk.fromIterable(for (xs <- a; ys <- a) yield (xs, ys)) - - def cross[A](a: Iterable[A], b: Iterable[A]): Chunk[(A, A)] = - Chunk.fromIterable(for (xs <- a; ys <- b) yield (xs, ys)) - - object syntax { - implicit class OptionSyntax[+A](val self: Option[A]) extends AnyVal { - def zip[B](that: Option[B]): Option[(A, B)] = - self.flatMap(a => that.map(b => (a, b))) + /** + * For an iterable, produce a Chunk containing tuples of all possible unique combinations, optionally including the identity + */ + def cross[A](a: Iterable[A], includeIdentity: Boolean): Chunk[(A, A)] = { + val ca = Chunk.fromIterable(a) + val size = ca.size + val cb = Chunk.newBuilder[(A, A)] + var i1, i2 = 0 + val modifier = if (includeIdentity) 0 else 1 + while (i1 < size - modifier) { + i2 = i1 + modifier + val l = ca(i1) + while (i2 < size) { + cb += ((l, ca(i2))) + i2 += 1 + } + i1 += 1 } - - implicit class Tuple2Syntax[+A, +B](val self: (Option[A], Option[B])) extends AnyVal { - def mapN[C](f: (A, B) => C): Option[C] = - self._1.flatMap(a => self._2.map(b => f(a, b))) - } - + cb.result() } + } diff --git a/core/src/main/scala/caliban/validation/Validator.scala b/core/src/main/scala/caliban/validation/Validator.scala index d8e3ece40a..a4d33ba868 100644 --- a/core/src/main/scala/caliban/validation/Validator.scala +++ b/core/src/main/scala/caliban/validation/Validator.scala @@ -7,7 +7,6 @@ import caliban.execution.{ ExecutionRequest, Field => F } import caliban.introspection.Introspector import caliban.introspection.adt._ import caliban.introspection.adt.__TypeKind._ -import caliban.parsing.SourceMapper import caliban.parsing.adt.Definition.ExecutableDefinition.{ FragmentDefinition, OperationDefinition } import caliban.parsing.adt.Definition.TypeSystemDefinition.DirectiveDefinition import caliban.parsing.adt.Definition.{ TypeSystemDefinition, TypeSystemExtension } @@ -15,14 +14,14 @@ import caliban.parsing.adt.OperationType._ import caliban.parsing.adt.Selection.{ Field, FragmentSpread, InlineFragment } import caliban.parsing.adt.Type.NamedType import caliban.parsing.adt._ -import caliban.parsing.Parser +import caliban.parsing.{ Parser, SourceMapper } import caliban.rendering.DocumentRenderer import caliban.schema._ import caliban.validation.Utils.isObjectType -import caliban.{ Configurator, InputValue, Value } -import zio.{ IO, ZIO } +import caliban.{ Configurator, InputValue } import zio.prelude._ import zio.prelude.fx.ZPure +import zio.{ IO, ZIO } import scala.annotation.tailrec import scala.collection.mutable @@ -69,6 +68,8 @@ object Validator { validateRootQuery(schema) }.runEither + private val zunit = ZPure.unit[Unit] + private[caliban] def validateType(t: __Type): EReader[Any, ValidationError, Unit] = ZPure.forEach(t.name)(name => checkName(name, s"Type '$name'")) *> (t.kind match { @@ -77,7 +78,7 @@ object Validator { case __TypeKind.INTERFACE => validateInterface(t) case __TypeKind.INPUT_OBJECT => validateInputObject(t) case __TypeKind.OBJECT => validateObject(t) - case _ => ZPure.unit + case _ => zunit }) def failValidation(msg: String, explanatoryText: String): EReader[Any, ValidationError, Nothing] = @@ -161,7 +162,7 @@ object Validator { validateFragments(fragments).flatMap { fragmentMap => val selectionSets = collectSelectionSets(operations.flatMap(_.selectionSet) ++ fragments.flatMap(_.selectionSet)) val context = Context(document, rootType, operations, fragmentMap, selectionSets, variables) - ZPure.foreachDiscard(validations)(identity).provideService(context) as fragmentMap + validateAll(validations)(identity).provideService(context) as fragmentMap } } @@ -304,7 +305,7 @@ object Validator { // it's a minor optimization to short-circuit the length check on a List for the off-chance that list is long (v.lengthCompare(1) > 0) && !directiveDefinitions.get(n).exists(_.exists(_.isRepeatable)) } match { - case None => ZPure.unit + case None => zunit case Some((name, _)) => failValidation( s"Directive '$name' is defined more than once.", @@ -315,7 +316,7 @@ object Validator { lazy val validateDirectives: QueryValidation = ZPure.serviceWithPure { context => for { directives <- collectAllDirectives(context) - _ <- ZPure.foreachDiscard(directives) { case (d, location) => + _ <- validateAll(directives) { case (d, location) => (context.rootType.additionalDirectives ::: Introspector.directives).find(_.name == d.name) match { case None => failValidation( @@ -323,13 +324,8 @@ object Validator { "GraphQL servers define what directives they support. For each usage of a directive, the directive must be available on that server." ) case Some(directive) => - ZPure.foreachDiscard(d.arguments) { case (arg, argValue) => + validateAll(d.arguments) { case (arg, argValue) => directive.allArgs.find(_.name == arg) match { - case None => - failValidation( - s"Argument '$arg' is not defined on directive '${d.name}' ($location).", - "Every argument provided to a field or directive must be defined in the set of possible arguments of that field or directive." - ) case Some(inputValue) => validateInputValues( inputValue, @@ -337,6 +333,11 @@ object Validator { context, s"InputValue '${inputValue.name}' of Directive '${d.name}'" ) + case None => + failValidation( + s"Argument '$arg' is not defined on directive '${d.name}' ($location).", + "Every argument provided to a field or directive must be defined in the set of possible arguments of that field or directive." + ) } } *> ZPure.when(!directive.locations.contains(location))( @@ -353,14 +354,14 @@ object Validator { lazy val validateVariables: QueryValidation = ZPure.serviceWithPure { context => ZPure.foreachDiscard(context.operations)(op => - ZPure.foreachDiscard(op.variableDefinitions.groupBy(_.name)) { case (name, variables) => + validateAll(op.variableDefinitions.groupBy(_.name)) { case (name, variables) => ZPure.when(variables.length > 1)( failValidation( s"Variable '$name' is defined more than once.", "If any operation defines more than one variable with the same name, it is ambiguous and invalid. It is invalid even if the type of the duplicate variable is the same." ) ) - } *> ZPure.foreachDiscard(op.variableDefinitions) { v => + } *> validateAll(op.variableDefinitions) { v => val t = Type.innerType(v.variableType) ZPure.whenCase(context.rootType.types.get(t).map(_.kind)) { case Some(__TypeKind.OBJECT) | Some(__TypeKind.UNION) | Some(__TypeKind.INTERFACE) => @@ -371,14 +372,14 @@ object Validator { } } *> { val variableUsages = collectVariablesUsed(context, op.selectionSet) - ZPure.foreachDiscard(variableUsages)(v => + validateAll(variableUsages)(v => ZPure.when(!op.variableDefinitions.exists(_.name == v))( failValidation( s"Variable '$v' is not defined.", "Variables are scoped on a per‐operation basis. That means that any variable used within the context of an operation must be defined at the top level of that operation" ) ) - ) *> ZPure.foreachDiscard(op.variableDefinitions)(v => + ) *> validateAll(op.variableDefinitions)(v => ZPure.when(!variableUsages.contains(v.name))( failValidation( s"Variable '${v.name}' is not used.", @@ -397,19 +398,18 @@ object Validator { ZPure.serviceWithPure { context => val spreads = collectFragmentSpreads(context.selectionSets) val spreadNames = spreads.map(_.name).toSet - ZPure.foreachDiscard(context.fragments.values)(f => + validateAll(context.fragments.values)(f => if (!spreadNames.contains(f.name)) failValidation( s"Fragment '${f.name}' is not used in any spread.", "Defined fragments must be used within a document." ) - else - ZPure.when(detectCycles(context, f))( - failValidation( - s"Fragment '${f.name}' forms a cycle.", - "The graph of fragment spreads must not form any cycles including spreading itself. Otherwise an operation could infinitely spread or infinitely execute on cycles in the underlying data." - ) + else if (detectCycles(context, f)) + failValidation( + s"Fragment '${f.name}' forms a cycle.", + "The graph of fragment spreads must not form any cycles including spreading itself. Otherwise an operation could infinitely spread or infinitely execute on cycles in the underlying data." ) + else zunit ) } @@ -423,12 +423,11 @@ object Validator { } lazy val validateDocumentFields: QueryValidation = ZPure.serviceWithPure { context => - ZPure.foreachDiscard(context.document.definitions) { + validateAll(context.document.definitions) { case OperationDefinition(opType, _, _, _, selectionSet) => opType match { - case OperationType.Query => + case OperationType.Query => validateSelectionSet(context, selectionSet, context.rootType.queryType) - case OperationType.Mutation => context.rootType.mutationType.fold[EReader[Any, ValidationError, Unit]]( failValidation("Mutation operations are not supported on this schema.", "") @@ -438,9 +437,9 @@ object Validator { failValidation("Subscription operations are not supported on this schema.", "") )(validateSelectionSet(context, selectionSet, _)) } - case _: FragmentDefinition => ZPure.unit - case _: TypeSystemDefinition => ZPure.unit - case _: TypeSystemExtension => ZPure.unit + case _: FragmentDefinition => zunit + case _: TypeSystemDefinition => zunit + case _: TypeSystemExtension => zunit } } @@ -448,26 +447,28 @@ object Validator { context: Context, selectionSet: List[Selection], currentType: __Type - ): EReader[Any, ValidationError, Unit] = - validateFields(context, selectionSet, currentType) *> - FragmentValidator.findConflictsWithinSelectionSet(context, context.rootType.queryType, selectionSet) + ): EReader[Any, ValidationError, Unit] = { + val v1 = validateFields(context, selectionSet, currentType) + val v2 = FragmentValidator.findConflictsWithinSelectionSet(context, context.rootType.queryType, selectionSet) + v1 *> v2 + } private def validateFields( context: Context, selectionSet: List[Selection], currentType: __Type ): EReader[Any, ValidationError, Unit] = - ZPure.foreachDiscard(selectionSet) { + validateAll(selectionSet) { case f: Field => validateField(context, f, currentType) case FragmentSpread(name, _) => context.fragments.get(name) match { + case Some(fragment) => + validateSpread(context, Some(name), currentType, Some(fragment.typeCondition), fragment.selectionSet) case None => failValidation( s"Fragment spread '$name' is not defined.", "Named fragment spreads must refer to fragments defined within the document. It is a validation error if the target of a spread is not defined." ) - case Some(fragment) => - validateSpread(context, Some(name), currentType, Some(fragment.typeCondition), fragment.selectionSet) } case InlineFragment(typeCondition, _, selectionSet) => validateSpread(context, None, currentType, typeCondition, selectionSet) @@ -483,16 +484,16 @@ object Validator { typeCondition.fold[Option[__Type]](Some(currentType))(t => context.rootType.types.get(t.name)) match { case Some(fragmentType) => validateFragmentType(name, fragmentType) *> { - val possibleTypes = getPossibleTypeNames(currentType) - val possibleFragmentTypes = getPossibleTypeNames(fragmentType) + val possibleTypes = currentType.possibleTypeNames + val possibleFragmentTypes = fragmentType.possibleTypeNames val applicableTypes = possibleTypes intersect possibleFragmentTypes - ZPure.when(applicableTypes.isEmpty)( + if (applicableTypes.isEmpty) failValidation( s"${name.fold("Inline fragment spread")(n => s"Fragment spread '$n'")} is not possible: possible types are '${possibleTypes .mkString(", ")}' and possible fragment types are '${possibleFragmentTypes.mkString(", ")}'.", "Fragments are declared on a type and will only apply when the runtime object type matches the type condition. They also are spread within the context of a parent type. A fragment spread is only valid if its type condition could ever apply within the parent type." ) - ) *> validateFields(context, selectionSet, fragmentType) + else validateFields(context, selectionSet, fragmentType) } case None => lazy val typeConditionName = typeCondition.fold("?")(_.name) @@ -502,30 +503,19 @@ object Validator { ) } - private def getPossibleTypeNames(t: __Type): Set[String] = - t.kind match { - case __TypeKind.OBJECT => t.name.fold(Set.empty[String])(Set(_)) - case __TypeKind.INTERFACE | __TypeKind.UNION => t.possibleTypes.fold(Set.empty[String])(_.flatMap(_.name).toSet) - case _ => Set.empty - } - private def validateField(context: Context, field: Field, currentType: __Type): EReader[Any, ValidationError, Unit] = - ZPure - .when(field.name != "__typename") { - ZPure - .fromOption(currentType.allFieldsMap.get(field.name)) - .orElseFail( - ValidationError( - s"Field '${field.name}' does not exist on type '${DocumentRenderer.renderTypeName(currentType)}'.", - "The target field of a field selection must be defined on the scoped type of the selection set. There are no limitations on alias names." - ) + if (field.name != "__typename") { + currentType.allFieldsMap.get(field.name) match { + case Some(f) => + validateFields(context, field.selectionSet, f._type.innerType) *> + validateArguments(field, f, currentType, context) + case None => + failValidation( + s"Field '${field.name}' does not exist on type '${DocumentRenderer.renderTypeName(currentType)}'.", + "The target field of a field selection must be defined on the scoped type of the selection set. There are no limitations on alias names." ) - .flatMap { f => - validateFields(context, field.selectionSet, f._type.innerType) *> - validateArguments(field, f, currentType, context) - } } - .unit + } else zunit private def validateArguments( field: Field, @@ -533,7 +523,7 @@ object Validator { currentType: __Type, context: Context ): EReader[Any, ValidationError, Unit] = - ZPure.foreachDiscard(f.allArgs.filter(_._type.kind == __TypeKind.NON_NULL))(arg => + validateAll(f.allArgs.filter(_._type.kind == __TypeKind.NON_NULL))(arg => (arg.defaultValue, field.arguments.get(arg.name)) match { case (None, None) | (None, Some(NullValue)) => failValidation( @@ -541,32 +531,30 @@ object Validator { .getOrElse("")}'.", "Arguments can be required. An argument is required if the argument type is non‐null and does not have a default value. Otherwise, the argument is optional." ) - - case (Some(_), Some(NullValue)) => + case (Some(_), Some(NullValue)) => failValidation( s"Required argument '${arg.name}' is null on '${field.name}' of type '${currentType.name .getOrElse("")}'.", "Arguments can be required. An argument is required if the argument type is non‐null and does not have a default value. Otherwise, the argument is optional." ) - case _ => ZPure.unit[Unit] + case _ => zunit } - ) *> - ZPure.foreachDiscard(field.arguments) { case (arg, argValue) => - f.allArgs.find(_.name == arg) match { - case None => - failValidation( - s"Argument '$arg' is not defined on field '${field.name}' of type '${currentType.name.getOrElse("")}'.", - "Every argument provided to a field or directive must be defined in the set of possible arguments of that field or directive." - ) - case Some(inputValue) => - validateInputValues( - inputValue, - argValue, - context, - s"InputValue '${inputValue.name}' of Field '${field.name}'" - ) - } + ) *> validateAll(field.arguments) { case (arg, argValue) => + f.allArgs.find(_.name == arg) match { + case Some(inputValue) => + validateInputValues( + inputValue, + argValue, + context, + s"InputValue '${inputValue.name}' of Field '${field.name}'" + ) + case None => + failValidation( + s"Argument '$arg' is not defined on field '${field.name}' of type '${currentType.name.getOrElse("")}'.", + "Every argument provided to a field or directive must be defined in the set of possible arguments of that field or directive." + ) } + } private[caliban] def validateInputValues( inputValue: __InputValue, @@ -580,7 +568,7 @@ object Validator { argValue match { case InputValue.ObjectValue(fields) if inputType.kind == __TypeKind.INPUT_OBJECT => - ZPure.foreachDiscard(fields) { case (k, v) => + validateAll(fields) { case (k, v) => inputFields.find(_.name == k) match { case None => failValidation( @@ -595,7 +583,7 @@ object Validator { s"InputValue '${inputValue.name}' of Field '$k' of InputObject '${t.name.getOrElse("")}'" ) } - } *> ZPure.foreachDiscard(inputFields)(inputField => + } *> validateAll(inputFields)(inputField => ZPure.when( inputField.defaultValue.isEmpty && inputField._type.kind == __TypeKind.NON_NULL && @@ -616,7 +604,7 @@ object Validator { "Variables are scoped on a per‐operation basis. That means that any variable used within the context of an operation must be defined at the top level of that operation" ) } - case _ => ZPure.unit[Unit] + case _ => zunit } } *> ValueValidator.validateInputTypes(inputValue, argValue, context, errorContext) @@ -709,7 +697,7 @@ object Validator { s"Field selection is mandatory on type '${currentType.name.getOrElse("")}'.", "Leaf selections on objects, interfaces, and unions without subfields are disallowed." ) - case _ => ZPure.unit + case _ => zunit } lazy val validateOperationNameUniqueness: QueryValidation = ZPure.serviceWithPure { context => @@ -793,7 +781,7 @@ object Validator { private def validateFragmentType(name: Option[String], targetType: __Type): EReader[Any, ValidationError, Unit] = targetType.kind match { - case __TypeKind.UNION | __TypeKind.INTERFACE | __TypeKind.OBJECT => ZPure.unit + case __TypeKind.UNION | __TypeKind.INTERFACE | __TypeKind.OBJECT => zunit case _ => val targetTypeName = targetType.name.getOrElse("") failValidation( @@ -804,7 +792,7 @@ object Validator { private[caliban] def validateEnum(t: __Type): EReader[Any, ValidationError, Unit] = t.allEnumValues match { - case _ :: _ => ZPure.unit + case _ :: _ => zunit case Nil => failValidation( s"Enum ${t.name.getOrElse("")} doesn't contain any values", @@ -825,7 +813,7 @@ object Validator { types.filterNot(isObjectType).map(_.name.getOrElse("")).filterNot(_.isEmpty).mkString("", ", ", "."), s"The member types of a Union type must all be Object base types." ) - case _ => ZPure.unit + case _ => zunit } private[caliban] def validateInputObject(t: __Type): EReader[Any, ValidationError, Unit] = { @@ -842,7 +830,7 @@ object Validator { } def validateFields(fields: List[__InputValue]): EReader[Any, ValidationError, Unit] = - ZPure.foreachDiscard(fields)(validateInputValue(_, inputObjectContext)) *> + validateAll(fields)(validateInputValue(_, inputObjectContext)) *> noDuplicateInputValueName(fields, inputObjectContext) t.allInputFields match { @@ -924,11 +912,11 @@ object Validator { isNonNullableSubtype(supertypeFieldType, objectFieldType) } - ZPure.foreachDiscard(objectFields) { objField => + validateAll(objectFields) { objField => lazy val fieldContext = s"Field '${objField.name}'" supertypeFields.find(_.name == objField.name) match { - case None => ZPure.unit + case None => zunit case Some(superField) => val superArgs = superField.allArgs.map(arg => (arg.name, arg)).toMap val extraArgs = objField.allArgs.filter { arg => @@ -963,7 +951,7 @@ object Validator { s"$fieldContext with extra non-nullable arg(s) '$argNames' in $objectContext is invalid", "Any additional field arguments must not be of a non-nullable type." ) - case _ => ZPure.unit + case _ => zunit } } } @@ -1005,18 +993,18 @@ object Validator { s"${errorType.name.getOrElse("")} of $errorContext is of kind ${errorType.kind}, must be an InputType", """The input field must accept a type where IsInputType(type) returns true, https://spec.graphql.org/June2018/#IsInputType()""" ) - case Right(_) => ZPure.unit + case Right(_) => zunit } } private[caliban] def validateFields(fields: List[__Field], context: => String): EReader[Any, ValidationError, Unit] = noDuplicateFieldName(fields, context) <* - ZPure.foreachDiscard(fields) { field => + validateAll(fields) { field => lazy val fieldContext = s"Field '${field.name}' of $context" for { _ <- checkName(field.name, fieldContext) _ <- onlyOutputType(field._type, fieldContext) - _ <- ZPure.foreachDiscard(field.allArgs)(validateInputValue(_, fieldContext)) + _ <- validateAll(field.allArgs)(validateInputValue(_, fieldContext)) } yield () } @@ -1044,7 +1032,7 @@ object Validator { s"${errorType.name.getOrElse("")} of $errorContext is of kind ${errorType.kind}, must be an OutputType", """The input field must accept a type where IsOutputType(type) returns true, https://spec.graphql.org/June2018/#IsInputType()""" ) - case Right(_) => ZPure.unit + case Right(_) => zunit } } @@ -1057,7 +1045,7 @@ object Validator { listOfNamed .groupBy(nameExtractor(_)) .collectFirst { case (_, f :: _ :: _) => f } - .fold[EReader[Any, ValidationError, Unit]](ZPure.unit)(duplicate => + .fold[EReader[Any, ValidationError, Unit]](zunit)(duplicate => failValidation(messageBuilder(duplicate), explanatoryText) ) @@ -1110,7 +1098,7 @@ object Validator { "The mutation root operation is not an object type.", "The mutation root operation type is optional; if it is not provided, the service does not support mutations. If it is provided, it must be an Object type." ) - case _ => ZPure.unit + case _ => zunit } private[caliban] def validateRootSubscription[R](schema: RootSchemaBuilder[R]): EReader[Any, ValidationError, Unit] = @@ -1120,13 +1108,13 @@ object Validator { "The mutation root subscription is not an object type.", "The mutation root subscription type is optional; if it is not provided, the service does not support subscriptions. If it is provided, it must be an Object type." ) - case _ => ZPure.unit + case _ => zunit } private[caliban] def validateClashingTypes(types: List[__Type]): EReader[Any, ValidationError, Unit] = { val check = types.groupBy(_.name).collectFirst { case (Some(name), v) if v.size > 1 => (name, v) } check match { - case None => ZPure.unit + case None => zunit case Some((name, values)) => failValidation( s"Type '$name' is defined multiple times (${values @@ -1145,7 +1133,7 @@ object Validator { errorContext: => String ): EReader[Any, ValidationError, Unit] = { val argumentErrorContextBuilder = (name: String) => s"Argument '$name' of $errorContext" - ZPure.foreachDiscard(args.keys)(argName => checkName(argName, argumentErrorContextBuilder(argName))) + validateAll(args.keys)(argName => checkName(argName, argumentErrorContextBuilder(argName))) } def validateDirective(directive: Directive, errorContext: => String) = { @@ -1159,14 +1147,14 @@ object Validator { directives: Option[List[Directive]], errorContext: => String ): EReader[Any, ValidationError, Unit] = - ZPure.foreachDiscard(directives.getOrElse(List.empty))(validateDirective(_, errorContext)) + validateAll(directives.getOrElse(List.empty))(validateDirective(_, errorContext)) def validateInputValueDirectives( inputValues: List[__InputValue], errorContext: => String ): EReader[Any, ValidationError, Unit] = { val inputValueErrorContextBuilder = (name: String) => s"InputValue '$name' of $errorContext" - ZPure.foreachDiscard(inputValues)(iv => validateDirectives(iv.directives, inputValueErrorContextBuilder(iv.name))) + validateAll(inputValues)(iv => validateDirectives(iv.directives, inputValueErrorContextBuilder(iv.name))) } def validateFieldDirectives( @@ -1178,13 +1166,20 @@ object Validator { validateInputValueDirectives(field.allArgs, fieldErrorContext) } - ZPure.foreachDiscard(types) { t => + validateAll(types) { t => lazy val typeErrorContext = s"Type '${t.name.getOrElse("")}'" for { _ <- validateDirectives(t.directives, typeErrorContext) _ <- validateInputValueDirectives(t.allInputFields, typeErrorContext) - _ <- ZPure.foreachDiscard(t.allFields)(validateFieldDirectives(_, typeErrorContext)) + _ <- validateAll(t.allFields)(validateFieldDirectives(_, typeErrorContext)) } yield () } } + + // Pure's implementation doesn't check if the Iterable is empty, and this is causing some performance degradation + private def validateAll[R, A, B]( + in: Iterable[A] + )(f: A => EReader[R, ValidationError, B]): EReader[R, ValidationError, Unit] = + if (in.isEmpty) zunit + else ZPure.foreachDiscard(in)(f) } diff --git a/core/src/main/scala/caliban/validation/ValueValidator.scala b/core/src/main/scala/caliban/validation/ValueValidator.scala index debc17543a..26ee0d6d90 100644 --- a/core/src/main/scala/caliban/validation/ValueValidator.scala +++ b/core/src/main/scala/caliban/validation/ValueValidator.scala @@ -9,7 +9,6 @@ import caliban.parsing.Parser import caliban.{ InputValue, Value } import zio.prelude.EReader import zio.prelude.fx.ZPure -import zio.prelude._ object ValueValidator { def validateDefaultValue(field: __InputValue, errorContext: => String): EReader[Any, ValidationError, Unit] = @@ -68,7 +67,7 @@ object ValueValidator { case LIST => argValue match { case ListValue(values) => - values.forEach_(v => + ZPure.foreachDiscard(values)(v => validateType(inputType.ofType.getOrElse(inputType), v, context, s"List item in $errorContext") ) case NullValue => @@ -81,14 +80,14 @@ object ValueValidator { case INPUT_OBJECT => argValue match { case ObjectValue(fields) => - inputType.allInputFields.forEach_ { f => + ZPure.foreachDiscard(inputType.allInputFields) { f => fields.collectFirst { case (name, fieldValue) if name == f.name => fieldValue } match { - case Some(value) => + case Some(value) => validateType(f._type, value, context, s"Field ${f.name} in $errorContext") - case None => - ZPure.when(f.defaultValue.isEmpty) { - validateType(f._type, NullValue, context, s"Field ${f.name} in $errorContext") - } + case None if f.defaultValue.isEmpty => + validateType(f._type, NullValue, context, s"Field ${f.name} in $errorContext") + case _ => + ZPure.unit } } case NullValue => diff --git a/core/src/main/scala/caliban/validation/package.scala b/core/src/main/scala/caliban/validation/package.scala index b6789b85ee..a046b51889 100644 --- a/core/src/main/scala/caliban/validation/package.scala +++ b/core/src/main/scala/caliban/validation/package.scala @@ -12,7 +12,9 @@ package object validation { parentType: __Type, selection: Field, fieldDef: __Field - ) + ) { + final override lazy val hashCode: Int = super.hashCode() + } type FieldMap = Map[String, Set[SelectedField]]