diff --git a/core/src/main/scala/caliban/validation/Validator.scala b/core/src/main/scala/caliban/validation/Validator.scala index cd63980e9..ac6718d02 100644 --- a/core/src/main/scala/caliban/validation/Validator.scala +++ b/core/src/main/scala/caliban/validation/Validator.scala @@ -185,6 +185,7 @@ object Validator { private def collectVariablesUsed(context: Context, selectionSet: List[Selection]): mutable.Set[String] = { val allValues = ListBuffer.empty[InputValue] val variables = mutable.Set.empty[String] + val seen = mutable.HashSet.empty[String] def collectValues(selectionSet: List[Selection]): Unit = { // ugly mutable code but it's worth it for the speed ;) @@ -193,22 +194,28 @@ object Validator { () } + def collectDirectives(d: List[Directive]) = + if (d.nonEmpty) d.foreach(d => add(d.arguments)) + selectionSet.foreach { case Field(_, _, arguments, directives, selectionSet, _) => add(arguments) - directives.foreach(d => add(d.arguments)) - collectValues(selectionSet) + collectDirectives(directives) + if (selectionSet.nonEmpty) collectValues(selectionSet) case FragmentSpread(name, directives) => - directives.foreach(d => add(d.arguments)) - context.fragments - .get(name) - .foreach { f => - f.directives.foreach(d => add(d.arguments)) - collectValues(f.selectionSet) - } + if (seen.add(name)) { + collectDirectives(directives) + context.fragments + .get(name) + .foreach { f => + collectDirectives(f.directives) + val set = f.selectionSet + if (set.nonEmpty) collectValues(set) + } + } case InlineFragment(_, directives, selectionSet) => - directives.foreach(d => add(d.arguments)) - collectValues(selectionSet) + collectDirectives(directives) + if (selectionSet.nonEmpty) collectValues(selectionSet) } } @@ -471,12 +478,14 @@ object Validator { case _: Selection.FragmentSpread => true } + private type ValidatedFragments = mutable.HashSet[(String, Option[String])] + private def validateSelectionSet( context: Context, selectionSet: List[Selection], currentType: __Type ): Either[ValidationError, Unit] = { - val v1 = validateFields(context, selectionSet, currentType) + val v1 = validateFields(context, selectionSet, currentType)(mutable.HashSet.empty) if (context.fragments.nonEmpty || containsFragments(selectionSet)) v1 *> FragmentValidator.findConflictsWithinSelectionSet(context, context.rootType.queryType, selectionSet) else v1 @@ -486,19 +495,21 @@ object Validator { context: Context, selectionSet: List[Selection], currentType: __Type - ): Either[ValidationError, Unit] = { + )(implicit checked: ValidatedFragments): Either[ValidationError, Unit] = { val v1 = validateAllDiscard(selectionSet) { case f: Field => validateField(context, f, currentType) case FragmentSpread(name, _) => context.fragments.getOrElse(name, null) match { - case null => + case null => failValidation( s"Fragment spread '$name' is not defined.", "Named fragment spreads must refer to fragments defined within the document. It is a validation error if the target of a spread is not defined." ) - case fragment => + case fragment if checked.add((name, currentType.name)) => validateSpread(context, Some(name), currentType, Some(fragment.typeCondition), fragment.selectionSet) + case _ => + unit } case InlineFragment(typeCondition, _, selectionSet) => validateSpread(context, None, currentType, typeCondition, selectionSet) @@ -513,7 +524,7 @@ object Validator { currentType: __Type, typeCondition: Option[NamedType], selectionSet: List[Selection] - ): Either[ValidationError, Unit] = + )(implicit v: ValidatedFragments): Either[ValidationError, Unit] = typeCondition.fold(currentType)(t => context.rootType.types.getOrElse(t.name, null)) match { case null => val typeConditionName = typeCondition.fold("?")(_.name) @@ -546,7 +557,11 @@ object Validator { case (Some(v1), Some(v2)) => Some(v1 *> v2) } - private def validateField(context: Context, field: Field, currentType: __Type): Either[ValidationError, Unit] = + private def validateField( + context: Context, + field: Field, + currentType: __Type + )(implicit v: ValidatedFragments): Either[ValidationError, Unit] = if (field.name != "__typename") { currentType.getFieldOrNull(field.name) match { case null =>