Skip to content

Commit

Permalink
Avoid repeating validations on Fragments (#2315)
Browse files Browse the repository at this point in the history
* Cache different validations of fragments

* fmt
  • Loading branch information
kyri-petrou authored Jun 27, 2024
1 parent 806a8d0 commit 82ce42d
Showing 1 changed file with 32 additions and 17 deletions.
49 changes: 32 additions & 17 deletions core/src/main/scala/caliban/validation/Validator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ object Validator {
private def collectVariablesUsed(context: Context, selectionSet: List[Selection]): mutable.Set[String] = {
val allValues = ListBuffer.empty[InputValue]
val variables = mutable.Set.empty[String]
val seen = mutable.HashSet.empty[String]

def collectValues(selectionSet: List[Selection]): Unit = {
// ugly mutable code but it's worth it for the speed ;)
Expand All @@ -193,22 +194,28 @@ object Validator {
()
}

def collectDirectives(d: List[Directive]) =
if (d.nonEmpty) d.foreach(d => add(d.arguments))

selectionSet.foreach {
case Field(_, _, arguments, directives, selectionSet, _) =>
add(arguments)
directives.foreach(d => add(d.arguments))
collectValues(selectionSet)
collectDirectives(directives)
if (selectionSet.nonEmpty) collectValues(selectionSet)
case FragmentSpread(name, directives) =>
directives.foreach(d => add(d.arguments))
context.fragments
.get(name)
.foreach { f =>
f.directives.foreach(d => add(d.arguments))
collectValues(f.selectionSet)
}
if (seen.add(name)) {
collectDirectives(directives)
context.fragments
.get(name)
.foreach { f =>
collectDirectives(f.directives)
val set = f.selectionSet
if (set.nonEmpty) collectValues(set)
}
}
case InlineFragment(_, directives, selectionSet) =>
directives.foreach(d => add(d.arguments))
collectValues(selectionSet)
collectDirectives(directives)
if (selectionSet.nonEmpty) collectValues(selectionSet)
}
}

Expand Down Expand Up @@ -471,12 +478,14 @@ object Validator {
case _: Selection.FragmentSpread => true
}

private type ValidatedFragments = mutable.HashSet[(String, Option[String])]

private def validateSelectionSet(
context: Context,
selectionSet: List[Selection],
currentType: __Type
): Either[ValidationError, Unit] = {
val v1 = validateFields(context, selectionSet, currentType)
val v1 = validateFields(context, selectionSet, currentType)(mutable.HashSet.empty)
if (context.fragments.nonEmpty || containsFragments(selectionSet))
v1 *> FragmentValidator.findConflictsWithinSelectionSet(context, context.rootType.queryType, selectionSet)
else v1
Expand All @@ -486,19 +495,21 @@ object Validator {
context: Context,
selectionSet: List[Selection],
currentType: __Type
): Either[ValidationError, Unit] = {
)(implicit checked: ValidatedFragments): Either[ValidationError, Unit] = {
val v1 = validateAllDiscard(selectionSet) {
case f: Field =>
validateField(context, f, currentType)
case FragmentSpread(name, _) =>
context.fragments.getOrElse(name, null) match {
case null =>
case null =>
failValidation(
s"Fragment spread '$name' is not defined.",
"Named fragment spreads must refer to fragments defined within the document. It is a validation error if the target of a spread is not defined."
)
case fragment =>
case fragment if checked.add((name, currentType.name)) =>
validateSpread(context, Some(name), currentType, Some(fragment.typeCondition), fragment.selectionSet)
case _ =>
unit
}
case InlineFragment(typeCondition, _, selectionSet) =>
validateSpread(context, None, currentType, typeCondition, selectionSet)
Expand All @@ -513,7 +524,7 @@ object Validator {
currentType: __Type,
typeCondition: Option[NamedType],
selectionSet: List[Selection]
): Either[ValidationError, Unit] =
)(implicit v: ValidatedFragments): Either[ValidationError, Unit] =
typeCondition.fold(currentType)(t => context.rootType.types.getOrElse(t.name, null)) match {
case null =>
val typeConditionName = typeCondition.fold("?")(_.name)
Expand Down Expand Up @@ -546,7 +557,11 @@ object Validator {
case (Some(v1), Some(v2)) => Some(v1 *> v2)
}

private def validateField(context: Context, field: Field, currentType: __Type): Either[ValidationError, Unit] =
private def validateField(
context: Context,
field: Field,
currentType: __Type
)(implicit v: ValidatedFragments): Either[ValidationError, Unit] =
if (field.name != "__typename") {
currentType.getFieldOrNull(field.name) match {
case null =>
Expand Down

0 comments on commit 82ce42d

Please sign in to comment.