Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid repeating validations on Fragments #2315

Merged
merged 2 commits into from
Jun 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 32 additions & 17 deletions core/src/main/scala/caliban/validation/Validator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ object Validator {
private def collectVariablesUsed(context: Context, selectionSet: List[Selection]): mutable.Set[String] = {
val allValues = ListBuffer.empty[InputValue]
val variables = mutable.Set.empty[String]
val seen = mutable.HashSet.empty[String]

def collectValues(selectionSet: List[Selection]): Unit = {
// ugly mutable code but it's worth it for the speed ;)
Expand All @@ -193,22 +194,28 @@ object Validator {
()
}

def collectDirectives(d: List[Directive]) =
if (d.nonEmpty) d.foreach(d => add(d.arguments))

selectionSet.foreach {
case Field(_, _, arguments, directives, selectionSet, _) =>
add(arguments)
directives.foreach(d => add(d.arguments))
collectValues(selectionSet)
collectDirectives(directives)
if (selectionSet.nonEmpty) collectValues(selectionSet)
case FragmentSpread(name, directives) =>
directives.foreach(d => add(d.arguments))
context.fragments
.get(name)
.foreach { f =>
f.directives.foreach(d => add(d.arguments))
collectValues(f.selectionSet)
}
if (seen.add(name)) {
collectDirectives(directives)
context.fragments
.get(name)
.foreach { f =>
collectDirectives(f.directives)
val set = f.selectionSet
if (set.nonEmpty) collectValues(set)
}
}
case InlineFragment(_, directives, selectionSet) =>
directives.foreach(d => add(d.arguments))
collectValues(selectionSet)
collectDirectives(directives)
if (selectionSet.nonEmpty) collectValues(selectionSet)
}
}

Expand Down Expand Up @@ -471,12 +478,14 @@ object Validator {
case _: Selection.FragmentSpread => true
}

private type ValidatedFragments = mutable.HashSet[(String, Option[String])]

private def validateSelectionSet(
context: Context,
selectionSet: List[Selection],
currentType: __Type
): Either[ValidationError, Unit] = {
val v1 = validateFields(context, selectionSet, currentType)
val v1 = validateFields(context, selectionSet, currentType)(mutable.HashSet.empty)
if (context.fragments.nonEmpty || containsFragments(selectionSet))
v1 *> FragmentValidator.findConflictsWithinSelectionSet(context, context.rootType.queryType, selectionSet)
else v1
Expand All @@ -486,19 +495,21 @@ object Validator {
context: Context,
selectionSet: List[Selection],
currentType: __Type
): Either[ValidationError, Unit] = {
)(implicit checked: ValidatedFragments): Either[ValidationError, Unit] = {
val v1 = validateAllDiscard(selectionSet) {
case f: Field =>
validateField(context, f, currentType)
case FragmentSpread(name, _) =>
context.fragments.getOrElse(name, null) match {
case null =>
case null =>
failValidation(
s"Fragment spread '$name' is not defined.",
"Named fragment spreads must refer to fragments defined within the document. It is a validation error if the target of a spread is not defined."
)
case fragment =>
case fragment if checked.add((name, currentType.name)) =>
validateSpread(context, Some(name), currentType, Some(fragment.typeCondition), fragment.selectionSet)
case _ =>
unit
}
case InlineFragment(typeCondition, _, selectionSet) =>
validateSpread(context, None, currentType, typeCondition, selectionSet)
Expand All @@ -513,7 +524,7 @@ object Validator {
currentType: __Type,
typeCondition: Option[NamedType],
selectionSet: List[Selection]
): Either[ValidationError, Unit] =
)(implicit v: ValidatedFragments): Either[ValidationError, Unit] =
typeCondition.fold(currentType)(t => context.rootType.types.getOrElse(t.name, null)) match {
case null =>
val typeConditionName = typeCondition.fold("?")(_.name)
Expand Down Expand Up @@ -546,7 +557,11 @@ object Validator {
case (Some(v1), Some(v2)) => Some(v1 *> v2)
}

private def validateField(context: Context, field: Field, currentType: __Type): Either[ValidationError, Unit] =
private def validateField(
context: Context,
field: Field,
currentType: __Type
)(implicit v: ValidatedFragments): Either[ValidationError, Unit] =
if (field.name != "__typename") {
currentType.getFieldOrNull(field.name) match {
case null =>
Expand Down