Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid unnecessary .units in Validator #2182

Merged
merged 4 commits into from
Apr 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 64 additions & 96 deletions core/src/main/scala/caliban/validation/Validator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ object Validator {
variables: Map[String, InputValue],
validations: List[QueryValidation]
): Either[ValidationError, Map[String, FragmentDefinition]] = {
val (operations, fragments, _, _) = collectDefinitions(document)
val (operations, fragments) = collectDefinitions(document)
validateFragments(fragments).flatMap { fragmentMap =>
val buf = ListBuffer.empty[Selection]
operations.foreach(op => collectSelectionSets(buf)(op.selectionSet))
Expand All @@ -182,25 +182,16 @@ object Validator {
}
}

private def collectDefinitions(
document: Document
): (List[OperationDefinition], List[FragmentDefinition], List[TypeSystemDefinition], List[TypeSystemExtension]) =
private def collectDefinitions(document: Document): (List[OperationDefinition], List[FragmentDefinition]) =
document.definitions.foldLeft(
(
List.empty[OperationDefinition],
List.empty[FragmentDefinition],
List.empty[TypeSystemDefinition],
List.empty[TypeSystemExtension]
List.empty[FragmentDefinition]
)
) {
case ((operations, fragments, types, extensions), o: OperationDefinition) =>
(o :: operations, fragments, types, extensions)
case ((operations, fragments, types, extensions), f: FragmentDefinition) =>
(operations, f :: fragments, types, extensions)
case ((operations, fragments, types, extensions), t: TypeSystemDefinition) =>
(operations, fragments, t :: types, extensions)
case ((operations, fragments, types, extensions), e: TypeSystemExtension) =>
(operations, fragments, types, e :: extensions)
case ((operations, fragments), o: OperationDefinition) => (o :: operations, fragments)
case ((operations, fragments), f: FragmentDefinition) => (operations, f :: fragments)
case (t, _) => t
}

private def collectVariablesUsed(context: Context, selectionSet: List[Selection]): mutable.Set[String] = {
Expand Down Expand Up @@ -361,11 +352,9 @@ object Validator {
)
}
} *>
ZPure.when(!directive.locations.contains(location))(
failValidation(
s"Directive '${d.name}' is used in invalid location '$location'.",
"GraphQL servers define what directives they support and where they support them. For each usage of a directive, the directive must be used in a location that the server has declared support for."
)
failWhen(!directive.locations.contains(location))(
s"Directive '${d.name}' is used in invalid location '$location'.",
"GraphQL servers define what directives they support and where they support them. For each usage of a directive, the directive must be used in a location that the server has declared support for."
)
}
}
Expand All @@ -381,35 +370,29 @@ object Validator {
if (variableDefinitions.isEmpty && variableUsages.isEmpty) zunit
else
validateAll(op.variableDefinitions.groupBy(_.name)) { case (name, variables) =>
ZPure.when(variables.sizeCompare(1) > 0)(
failValidation(
s"Variable '$name' is defined more than once.",
"If any operation defines more than one variable with the same name, it is ambiguous and invalid. It is invalid even if the type of the duplicate variable is the same."
)
failWhen(variables.sizeCompare(1) > 0)(
s"Variable '$name' is defined more than once.",
"If any operation defines more than one variable with the same name, it is ambiguous and invalid. It is invalid even if the type of the duplicate variable is the same."
)
} *> validateAll(op.variableDefinitions) { v =>
val t = Type.innerType(v.variableType)
ZPure.whenCase(context.rootType.types.get(t).map(_.kind)) {
case Some(__TypeKind.OBJECT) | Some(__TypeKind.UNION) | Some(__TypeKind.INTERFACE) =>
failValidation(
s"Type of variable '${v.name}' is not a valid input type.",
"Variables can only be input types. Objects, unions, and interfaces cannot be used as inputs."
)
}
failWhen(context.rootType.types.get(t).map(_.kind).exists {
case __TypeKind.OBJECT | __TypeKind.UNION | __TypeKind.INTERFACE => true
case _ => false
})(
s"Type of variable '${v.name}' is not a valid input type.",
"Variables can only be input types. Objects, unions, and interfaces cannot be used as inputs."
)
} *> {
validateAll(variableUsages)(v =>
ZPure.when(!op.variableDefinitions.exists(_.name == v))(
failValidation(
s"Variable '$v' is not defined.",
"Variables are scoped on a per‐operation basis. That means that any variable used within the context of an operation must be defined at the top level of that operation"
)
failWhen(!op.variableDefinitions.exists(_.name == v))(
s"Variable '$v' is not defined.",
"Variables are scoped on a per‐operation basis. That means that any variable used within the context of an operation must be defined at the top level of that operation"
)
) *> validateAll(op.variableDefinitions)(v =>
ZPure.when(!variableUsages.contains(v.name))(
failValidation(
s"Variable '${v.name}' is not used.",
"All variables defined by an operation must be used in that operation or a fragment transitively included by that operation. Unused variables cause a validation error."
)
failWhen(!variableUsages.contains(v.name))(
s"Variable '${v.name}' is not used.",
"All variables defined by an operation must be used in that operation or a fragment transitively included by that operation. Unused variables cause a validation error."
)
)
}
Expand Down Expand Up @@ -645,15 +628,13 @@ object Validator {
)
}
} *> validateAll(inputFields)(inputField =>
ZPure.when(
failWhen(
inputField.defaultValue.isEmpty &&
inputField._type.kind == __TypeKind.NON_NULL &&
fields.getOrElse(inputField.name, NullValue) == NullValue
)(
failValidation(
s"Required field '${inputField.name}' on object '${inputType.name.getOrElse("?")}' was not provided.",
"Input object fields may be required. Much like a field may have required arguments, an input object may have required fields. An input field is required if it has a non‐null type and does not have a default value. Otherwise, the input object field is optional."
)
s"Required field '${inputField.name}' on object '${inputType.name.getOrElse("?")}' was not provided.",
"Input object fields may be required. Much like a field may have required arguments, an input object may have required fields. An input field is required if it has a non‐null type and does not have a default value. Otherwise, the input object field is optional."
)
)
case VariableValue(variableName) =>
Expand Down Expand Up @@ -731,15 +712,11 @@ object Validator {
explanation
)
case Type.NamedType(name, _) =>
ZPure
.when(!locationType.name.contains(name))(
failValidation(
s"Variable '$variableName' usage is not allowed because its type doesn't match the schema ($name instead of ${locationType.name
.getOrElse("")}).",
explanation
)
)
.unit
failWhen(!locationType.name.contains(name))(
s"Variable '$variableName' usage is not allowed because its type doesn't match the schema ($name instead of ${locationType.name
.getOrElse("")}).",
explanation
)
}
}

Expand Down Expand Up @@ -770,28 +747,20 @@ object Validator {
val operations = context.operations
val names = operations.flatMap(_.name).groupBy(identity)
val repeatedNames = names.collect { case (name, items) if items.length > 1 => name }
ZPure
.when(repeatedNames.nonEmpty)(
failValidation(
s"Multiple operations have the same name: ${repeatedNames.mkString(", ")}.",
"Each named operation definition must be unique within a document when referred to by its name."
)
)
.unit
failWhen(repeatedNames.nonEmpty)(
s"Multiple operations have the same name: ${repeatedNames.mkString(", ")}.",
"Each named operation definition must be unique within a document when referred to by its name."
)
}

lazy val validateLoneAnonymousOperation: QueryValidation = ZPure.environmentWithPure { env =>
val context = env.get[Context]
val operations = context.operations
val anonymous = operations.filter(_.name.isEmpty)
ZPure
.when(operations.length > 1 && anonymous.nonEmpty)(
failValidation(
"Found both anonymous and named operations.",
"GraphQL allows a short‐hand form for defining query operations when only that one operation exists in the document."
)
)
.unit
failWhen(operations.length > 1 && anonymous.nonEmpty)(
"Found both anonymous and named operations.",
"GraphQL allows a short‐hand form for defining query operations when only that one operation exists in the document."
)
}

private def validateFragments(
Expand Down Expand Up @@ -852,7 +821,7 @@ object Validator {
}
} yield error
}
ZPure.fromOption(error).flip.unit
ZPure.fromOption(error).flip
}

private def validateFragmentType(
Expand Down Expand Up @@ -962,15 +931,13 @@ object Validator {
val interfaceFieldNames = supertype.map(fieldNames).toSet.flatten
val isMissingFields = objectFieldNames.union(interfaceFieldNames) != objectFieldNames

ZPure
.when(interfaceFieldNames.nonEmpty && isMissingFields) {
failWhen(interfaceFieldNames.nonEmpty && isMissingFields)(
{
val missingFields = interfaceFieldNames.diff(objectFieldNames).toList.sorted
failValidation(
s"$objectContext is missing field(s): ${missingFields.mkString(", ")}",
"An Object type must include a field of the same name for every field defined in an interface"
)
}
.unit
s"$objectContext is missing field(s): ${missingFields.mkString(", ")}"
},
"An Object type must include a field of the same name for every field defined in an interface"
)
}

def checkForInvalidSubtypeFields(): EReader[Any, ValidationError, Unit] = {
Expand Down Expand Up @@ -1132,7 +1099,7 @@ object Validator {

private[caliban] def checkName(name: String, fieldContext: => String): EReader[Any, ValidationError, Unit] =
ZPure
.fromEither(Parser.parseName(name).unit)
.fromEither(Parser.parseName(name))
.mapError(e =>
ValidationError(
s"$fieldContext is not a valid name.",
Expand All @@ -1144,14 +1111,10 @@ object Validator {
name: String,
errorContext: => String
): EReader[Any, ValidationError, Unit] =
ZPure
.when(name.startsWith("__"))(
failValidation(
s"$errorContext can't start with '__'",
"""Names can not begin with the characters "__" (two underscores)"""
)
)
.unit
failWhen(name.startsWith("__"))(
s"$errorContext can't start with '__'",
"""Names can not begin with the characters "__" (two underscores)"""
)

private[caliban] def validateRootQuery[R](
schema: RootSchemaBuilder[R]
Expand Down Expand Up @@ -1260,24 +1223,29 @@ object Validator {
/**
* Wrapper around `ZPure.foreachDiscard` optimized for cases where the input is empty or has only one element.
*/
private def validateAll[R, A, B](
private def validateAll[R, A](
in: Iterable[A]
)(f: A => EReader[R, ValidationError, B]): EReader[R, ValidationError, Unit] =
)(f: A => EReader[R, ValidationError, Unit]): EReader[R, ValidationError, Unit] =
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can remove the B type param

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops 😅 Done

in.sizeCompare(1) match {
case -1 => zunit
case 0 => f(in.head).unit
case 0 => f(in.head)
case _ => ZPure.foreachDiscard(in)(f)
}

private def validateAllNonEmpty[A, B](
private def validateAllNonEmpty[A](
in: Iterable[A]
)(f: A => EReader[Any, ValidationError, B]): OptionalValidation =
)(f: A => EReader[Any, ValidationError, Unit]): OptionalValidation =
in.sizeCompare(1) match {
case -1 => None
case 0 => Some(f(in.head).unit)
case 0 => Some(f(in.head))
case _ => Some(ZPure.foreachDiscard(in)(f))
}

private def failWhen(
condition: Boolean
)(msg: => String, explanatoryText: => String): EReader[Any, ValidationError, Unit] =
if (condition) failValidation(msg, explanatoryText) else zunit

private implicit class EnrichedListBufferOps[A](private val lb: ListBuffer[A]) extends AnyVal {
// This method doesn't exist in Scala 2.12 so we just use `.map` for it instead
def addOne(elem: A): ListBuffer[A] = lb += elem
Expand Down