Skip to content

Commit

Permalink
Avoid unnecessary .units in Validator (#2182)
Browse files Browse the repository at this point in the history
* Validator micro-optimizations

* Cleanup

* PR comment

* Remove unused type parameter
  • Loading branch information
kyri-petrou authored Apr 8, 2024
1 parent 5305eda commit 29a868b
Showing 1 changed file with 64 additions and 96 deletions.
160 changes: 64 additions & 96 deletions core/src/main/scala/caliban/validation/Validator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ object Validator {
variables: Map[String, InputValue],
validations: List[QueryValidation]
): Either[ValidationError, Map[String, FragmentDefinition]] = {
val (operations, fragments, _, _) = collectDefinitions(document)
val (operations, fragments) = collectDefinitions(document)
validateFragments(fragments).flatMap { fragmentMap =>
val buf = ListBuffer.empty[Selection]
operations.foreach(op => collectSelectionSets(buf)(op.selectionSet))
Expand All @@ -182,25 +182,16 @@ object Validator {
}
}

private def collectDefinitions(
document: Document
): (List[OperationDefinition], List[FragmentDefinition], List[TypeSystemDefinition], List[TypeSystemExtension]) =
private def collectDefinitions(document: Document): (List[OperationDefinition], List[FragmentDefinition]) =
document.definitions.foldLeft(
(
List.empty[OperationDefinition],
List.empty[FragmentDefinition],
List.empty[TypeSystemDefinition],
List.empty[TypeSystemExtension]
List.empty[FragmentDefinition]
)
) {
case ((operations, fragments, types, extensions), o: OperationDefinition) =>
(o :: operations, fragments, types, extensions)
case ((operations, fragments, types, extensions), f: FragmentDefinition) =>
(operations, f :: fragments, types, extensions)
case ((operations, fragments, types, extensions), t: TypeSystemDefinition) =>
(operations, fragments, t :: types, extensions)
case ((operations, fragments, types, extensions), e: TypeSystemExtension) =>
(operations, fragments, types, e :: extensions)
case ((operations, fragments), o: OperationDefinition) => (o :: operations, fragments)
case ((operations, fragments), f: FragmentDefinition) => (operations, f :: fragments)
case (t, _) => t
}

private def collectVariablesUsed(context: Context, selectionSet: List[Selection]): mutable.Set[String] = {
Expand Down Expand Up @@ -361,11 +352,9 @@ object Validator {
)
}
} *>
ZPure.when(!directive.locations.contains(location))(
failValidation(
s"Directive '${d.name}' is used in invalid location '$location'.",
"GraphQL servers define what directives they support and where they support them. For each usage of a directive, the directive must be used in a location that the server has declared support for."
)
failWhen(!directive.locations.contains(location))(
s"Directive '${d.name}' is used in invalid location '$location'.",
"GraphQL servers define what directives they support and where they support them. For each usage of a directive, the directive must be used in a location that the server has declared support for."
)
}
}
Expand All @@ -381,35 +370,29 @@ object Validator {
if (variableDefinitions.isEmpty && variableUsages.isEmpty) zunit
else
validateAll(op.variableDefinitions.groupBy(_.name)) { case (name, variables) =>
ZPure.when(variables.sizeCompare(1) > 0)(
failValidation(
s"Variable '$name' is defined more than once.",
"If any operation defines more than one variable with the same name, it is ambiguous and invalid. It is invalid even if the type of the duplicate variable is the same."
)
failWhen(variables.sizeCompare(1) > 0)(
s"Variable '$name' is defined more than once.",
"If any operation defines more than one variable with the same name, it is ambiguous and invalid. It is invalid even if the type of the duplicate variable is the same."
)
} *> validateAll(op.variableDefinitions) { v =>
val t = Type.innerType(v.variableType)
ZPure.whenCase(context.rootType.types.get(t).map(_.kind)) {
case Some(__TypeKind.OBJECT) | Some(__TypeKind.UNION) | Some(__TypeKind.INTERFACE) =>
failValidation(
s"Type of variable '${v.name}' is not a valid input type.",
"Variables can only be input types. Objects, unions, and interfaces cannot be used as inputs."
)
}
failWhen(context.rootType.types.get(t).map(_.kind).exists {
case __TypeKind.OBJECT | __TypeKind.UNION | __TypeKind.INTERFACE => true
case _ => false
})(
s"Type of variable '${v.name}' is not a valid input type.",
"Variables can only be input types. Objects, unions, and interfaces cannot be used as inputs."
)
} *> {
validateAll(variableUsages)(v =>
ZPure.when(!op.variableDefinitions.exists(_.name == v))(
failValidation(
s"Variable '$v' is not defined.",
"Variables are scoped on a per‐operation basis. That means that any variable used within the context of an operation must be defined at the top level of that operation"
)
failWhen(!op.variableDefinitions.exists(_.name == v))(
s"Variable '$v' is not defined.",
"Variables are scoped on a per‐operation basis. That means that any variable used within the context of an operation must be defined at the top level of that operation"
)
) *> validateAll(op.variableDefinitions)(v =>
ZPure.when(!variableUsages.contains(v.name))(
failValidation(
s"Variable '${v.name}' is not used.",
"All variables defined by an operation must be used in that operation or a fragment transitively included by that operation. Unused variables cause a validation error."
)
failWhen(!variableUsages.contains(v.name))(
s"Variable '${v.name}' is not used.",
"All variables defined by an operation must be used in that operation or a fragment transitively included by that operation. Unused variables cause a validation error."
)
)
}
Expand Down Expand Up @@ -645,15 +628,13 @@ object Validator {
)
}
} *> validateAll(inputFields)(inputField =>
ZPure.when(
failWhen(
inputField.defaultValue.isEmpty &&
inputField._type.kind == __TypeKind.NON_NULL &&
fields.getOrElse(inputField.name, NullValue) == NullValue
)(
failValidation(
s"Required field '${inputField.name}' on object '${inputType.name.getOrElse("?")}' was not provided.",
"Input object fields may be required. Much like a field may have required arguments, an input object may have required fields. An input field is required if it has a non‐null type and does not have a default value. Otherwise, the input object field is optional."
)
s"Required field '${inputField.name}' on object '${inputType.name.getOrElse("?")}' was not provided.",
"Input object fields may be required. Much like a field may have required arguments, an input object may have required fields. An input field is required if it has a non‐null type and does not have a default value. Otherwise, the input object field is optional."
)
)
case VariableValue(variableName) =>
Expand Down Expand Up @@ -731,15 +712,11 @@ object Validator {
explanation
)
case Type.NamedType(name, _) =>
ZPure
.when(!locationType.name.contains(name))(
failValidation(
s"Variable '$variableName' usage is not allowed because its type doesn't match the schema ($name instead of ${locationType.name
.getOrElse("")}).",
explanation
)
)
.unit
failWhen(!locationType.name.contains(name))(
s"Variable '$variableName' usage is not allowed because its type doesn't match the schema ($name instead of ${locationType.name
.getOrElse("")}).",
explanation
)
}
}

Expand Down Expand Up @@ -770,28 +747,20 @@ object Validator {
val operations = context.operations
val names = operations.flatMap(_.name).groupBy(identity)
val repeatedNames = names.collect { case (name, items) if items.length > 1 => name }
ZPure
.when(repeatedNames.nonEmpty)(
failValidation(
s"Multiple operations have the same name: ${repeatedNames.mkString(", ")}.",
"Each named operation definition must be unique within a document when referred to by its name."
)
)
.unit
failWhen(repeatedNames.nonEmpty)(
s"Multiple operations have the same name: ${repeatedNames.mkString(", ")}.",
"Each named operation definition must be unique within a document when referred to by its name."
)
}

lazy val validateLoneAnonymousOperation: QueryValidation = ZPure.environmentWithPure { env =>
val context = env.get[Context]
val operations = context.operations
val anonymous = operations.filter(_.name.isEmpty)
ZPure
.when(operations.length > 1 && anonymous.nonEmpty)(
failValidation(
"Found both anonymous and named operations.",
"GraphQL allows a short‐hand form for defining query operations when only that one operation exists in the document."
)
)
.unit
failWhen(operations.length > 1 && anonymous.nonEmpty)(
"Found both anonymous and named operations.",
"GraphQL allows a short‐hand form for defining query operations when only that one operation exists in the document."
)
}

private def validateFragments(
Expand Down Expand Up @@ -852,7 +821,7 @@ object Validator {
}
} yield error
}
ZPure.fromOption(error).flip.unit
ZPure.fromOption(error).flip
}

private def validateFragmentType(
Expand Down Expand Up @@ -962,15 +931,13 @@ object Validator {
val interfaceFieldNames = supertype.map(fieldNames).toSet.flatten
val isMissingFields = objectFieldNames.union(interfaceFieldNames) != objectFieldNames

ZPure
.when(interfaceFieldNames.nonEmpty && isMissingFields) {
failWhen(interfaceFieldNames.nonEmpty && isMissingFields)(
{
val missingFields = interfaceFieldNames.diff(objectFieldNames).toList.sorted
failValidation(
s"$objectContext is missing field(s): ${missingFields.mkString(", ")}",
"An Object type must include a field of the same name for every field defined in an interface"
)
}
.unit
s"$objectContext is missing field(s): ${missingFields.mkString(", ")}"
},
"An Object type must include a field of the same name for every field defined in an interface"
)
}

def checkForInvalidSubtypeFields(): EReader[Any, ValidationError, Unit] = {
Expand Down Expand Up @@ -1132,7 +1099,7 @@ object Validator {

private[caliban] def checkName(name: String, fieldContext: => String): EReader[Any, ValidationError, Unit] =
ZPure
.fromEither(Parser.parseName(name).unit)
.fromEither(Parser.parseName(name))
.mapError(e =>
ValidationError(
s"$fieldContext is not a valid name.",
Expand All @@ -1144,14 +1111,10 @@ object Validator {
name: String,
errorContext: => String
): EReader[Any, ValidationError, Unit] =
ZPure
.when(name.startsWith("__"))(
failValidation(
s"$errorContext can't start with '__'",
"""Names can not begin with the characters "__" (two underscores)"""
)
)
.unit
failWhen(name.startsWith("__"))(
s"$errorContext can't start with '__'",
"""Names can not begin with the characters "__" (two underscores)"""
)

private[caliban] def validateRootQuery[R](
schema: RootSchemaBuilder[R]
Expand Down Expand Up @@ -1260,24 +1223,29 @@ object Validator {
/**
* Wrapper around `ZPure.foreachDiscard` optimized for cases where the input is empty or has only one element.
*/
private def validateAll[R, A, B](
private def validateAll[R, A](
in: Iterable[A]
)(f: A => EReader[R, ValidationError, B]): EReader[R, ValidationError, Unit] =
)(f: A => EReader[R, ValidationError, Unit]): EReader[R, ValidationError, Unit] =
in.sizeCompare(1) match {
case -1 => zunit
case 0 => f(in.head).unit
case 0 => f(in.head)
case _ => ZPure.foreachDiscard(in)(f)
}

private def validateAllNonEmpty[A, B](
private def validateAllNonEmpty[A](
in: Iterable[A]
)(f: A => EReader[Any, ValidationError, B]): OptionalValidation =
)(f: A => EReader[Any, ValidationError, Unit]): OptionalValidation =
in.sizeCompare(1) match {
case -1 => None
case 0 => Some(f(in.head).unit)
case 0 => Some(f(in.head))
case _ => Some(ZPure.foreachDiscard(in)(f))
}

private def failWhen(
condition: Boolean
)(msg: => String, explanatoryText: => String): EReader[Any, ValidationError, Unit] =
if (condition) failValidation(msg, explanatoryText) else zunit

private implicit class EnrichedListBufferOps[A](private val lb: ListBuffer[A]) extends AnyVal {
// This method doesn't exist in Scala 2.12 so we just use `.map` for it instead
def addOne(elem: A): ListBuffer[A] = lb += elem
Expand Down

0 comments on commit 29a868b

Please sign in to comment.