Skip to content

Commit

Permalink
feat: Support @GQLDefault (#1043)
Browse files Browse the repository at this point in the history
* feat: Add @GQLDefault annotation

* fix: tudy up pattern match

* fix: improve error messages

* style: group imports

* style: cleanup imports

* style: cleanup imports

* style: cleanup imports

* fix: clearer test descriptions

* fix: remove unused import

* style: use whenCase

* style: use IO.unless

* fix: Rename DefaultValue -> DefaultValueValidator

* doc: Document @GQLDefault
  • Loading branch information
frekw authored Sep 15, 2021
1 parent 0c40977 commit 54eba28
Show file tree
Hide file tree
Showing 15 changed files with 394 additions and 21 deletions.
13 changes: 13 additions & 0 deletions core/src/main/scala-2/caliban/parsing/Parser.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package caliban.parsing

import caliban.CalibanError.ParsingError
import caliban.InputValue
import caliban.parsing.adt._
import fastparse._
import zio.{ IO, Task }

import scala.util.Try

object Parser {
import caliban.parsing.parsers.Parsers._

Expand All @@ -21,6 +24,16 @@ object Parser {
}
}

def parseInputValue(rawValue: String): Either[ParsingError, InputValue] = {
val sm = SourceMapper(rawValue)
Try(parse(rawValue, value(_))).toEither.left
.map(ex => ParsingError(s"Internal parsing error", innerThrowable = Some(ex)))
.flatMap {
case Parsed.Success(value, _) => Right(value)
case f: Parsed.Failure => Left(ParsingError(f.msg, Some(sm.getLocation(f.index))))
}
}

/**
* Checks if the query is valid, if not returns an error string.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package caliban.schema
import caliban.CalibanError.ExecutionError
import caliban.InputValue
import caliban.Value._
import caliban.parsing.Parser
import caliban.schema.Annotations.GQLDefault
import caliban.schema.Annotations.GQLName
import magnolia._
import mercator.Monadic
Expand Down Expand Up @@ -30,8 +32,9 @@ trait ArgBuilderDerivation {
ctx.constructMonadic { p =>
input match {
case InputValue.ObjectValue(fields) =>
val label = p.annotations.collectFirst { case GQLName(name) => name }.getOrElse(p.label)
fields.get(label).fold(p.typeclass.buildMissing)(p.typeclass.build)
val label = p.annotations.collectFirst { case GQLName(name) => name }.getOrElse(p.label)
val default = p.annotations.collectFirst { case GQLDefault(v) => v }
fields.get(label).fold(p.typeclass.buildMissing(default))(p.typeclass.build)
case value => p.typeclass.build(value)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ trait SchemaDerivation[R] extends LowPriorityDerivedSchema {
() =>
if (p.typeclass.optional) p.typeclass.toType_(isInput, isSubscription)
else makeNonNull(p.typeclass.toType_(isInput, isSubscription)),
None,
p.annotations.collectFirst { case GQLDefault(v) => v },
Some(p.annotations.collect { case GQLDirective(dir) => dir }.toList).filter(_.nonEmpty)
)
)
Expand Down
18 changes: 18 additions & 0 deletions core/src/main/scala-3/caliban/parsing/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import caliban.parsing.adt._
import cats.parse.{ Numbers, Parser => P }
import cats.parse._
import zio.{ IO, Task }
import scala.util.Try

object Parser {
private final val UnicodeBOM = '\uFEFF'
Expand Down Expand Up @@ -584,6 +585,23 @@ object Parser {
}
}

def parseInputValue(rawValue: String): Either[ParsingError, InputValue] = {
val sm = SourceMapper(rawValue)
Try(value.parse(rawValue)).toEither.left
.map(ex => ParsingError(s"Internal parsing error", innerThrowable = Some(ex)))
.flatMap {
case Left(error) =>
Left(
ParsingError(
s"Parsing error at offset ${error.failedAtOffset}, expected: ${error.expected.toList.mkString(";")}",
Some(sm.getLocation(error.failedAtOffset))
)
)

case Right(_, result) => Right(result)
}
}

/**
* Checks if the query is valid, if not returns an error string.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import caliban.CalibanError.ExecutionError
import caliban.InputValue
import caliban.Value._
import caliban.schema.macros.Macros
import caliban.schema.Annotations.GQLDefault
import caliban.schema.Annotations.GQLName

import scala.deriving.Mirror
Expand Down Expand Up @@ -54,7 +55,8 @@ trait ArgBuilderDerivation {
input match {
case InputValue.ObjectValue(fields) =>
val finalLabel = annotations.getOrElse(label, Nil).collectFirst { case GQLName(name) => name }.getOrElse(label)
fields.get(finalLabel).fold(builder.buildMissing)(builder.build)
val default = annotations.getOrElse(label, Nil).collectFirst { case GQLDefault(v) => v }
fields.get(finalLabel).fold(builder.buildMissing(default))(builder.build)
case value => builder.build(value)
}
}.foldRight[Either[ExecutionError, Tuple]](Right(EmptyTuple)) { case (item, acc) =>
Expand Down
7 changes: 5 additions & 2 deletions core/src/main/scala-3/caliban/schema/SchemaDerivation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,12 @@ trait SchemaDerivation[R] {
.map { case (label, _, schema, _) =>
val fieldAnnotations = paramAnnotations.getOrElse(label, Nil)
__InputValue(
getName(paramAnnotations.getOrElse(label, Nil), label),
getName(fieldAnnotations, label),
getDescription(fieldAnnotations),
() =>
if (schema.optional) schema.toType_(isInput, isSubscription)
else makeNonNull(schema.toType_(isInput, isSubscription)),
None,
getDefaultValue(fieldAnnotations),
Some(fieldAnnotations.collect { case GQLDirective(dir) => dir }).filter(_.nonEmpty)
)
},
Expand Down Expand Up @@ -216,5 +216,8 @@ trait SchemaDerivation[R] {
private def getDirectives(annotations: Seq[Any]): List[Directive] =
annotations.collect { case GQLDirective(dir) => dir }.toList

private def getDefaultValue(annotations: Seq[Any]): Option[String] =
annotations.collectFirst { case GQLDefault(v) => v }

inline given gen[A]: Schema[R, A] = derived
}
16 changes: 9 additions & 7 deletions core/src/main/scala/caliban/Rendering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -125,19 +125,21 @@ object Rendering {
else ""}${renderDirectives(field.directives)}"

private def renderInputValue(inputValue: __InputValue): String =
s"${inputValue.name}: ${renderTypeName(inputValue.`type`())}${inputValue.defaultValue
.fold("")(d => s" = $d")}${renderDirectives(inputValue.directives)}"
s"${inputValue.name}: ${renderTypeName(inputValue.`type`())}${renderDefaultValue(inputValue)}${renderDirectives(inputValue.directives)}"

private def renderEnumValue(v: __EnumValue): String =
s"${renderDescription(v.description)}${v.name}${if (v.isDeprecated)
s" @deprecated${v.deprecationReason.fold("")(reason => s"""(reason: "$reason")""")}"
else ""}"

private def renderArguments(arguments: List[__InputValue]): String = arguments match {
case Nil => ""
case list =>
s"(${list.map(a => s"${renderDescription(a.description, newline = false)}${a.name}: ${renderTypeName(a.`type`())}").mkString(", ")})"
}
private def renderDefaultValue(a: __InputValue): String = a.defaultValue.fold("")(d => s" = $d")

private def renderArguments(arguments: List[__InputValue]): String =
arguments match {
case Nil => ""
case list =>
s"(${list.map(a => s"${renderDescription(a.description, newline = false)}${a.name}: ${renderTypeName(a.`type`())}${renderDefaultValue(a)}").mkString(", ")})"
}

private def isBuiltinScalar(name: String): Boolean =
name == "Int" || name == "Float" || name == "String" || name == "Boolean" || name == "ID"
Expand Down
5 changes: 5 additions & 0 deletions core/src/main/scala/caliban/schema/Annotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,9 @@ object Annotations {
* Annotation to make a union or interface redirect to a value type
*/
case class GQLValueType() extends StaticAnnotation

/**
* Annotation to specify the default value of an input field
*/
case class GQLDefault(value: String) extends StaticAnnotation
}
8 changes: 7 additions & 1 deletion core/src/main/scala/caliban/schema/ArgBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package caliban.schema
import caliban.CalibanError.ExecutionError
import caliban.InputValue
import caliban.Value._
import caliban.parsing.Parser
import zio.Chunk

import java.time.format.DateTimeFormatter
Expand Down Expand Up @@ -38,7 +39,12 @@ trait ArgBuilder[T] { self =>
* By default, this delegates to [[build]], passing it NullValue.
* Fails with an [[caliban.CalibanError.ExecutionError]] if it was impossible to build the value.
*/
def buildMissing: Either[ExecutionError, T] = build(NullValue)
def buildMissing(default: Option[String]): Either[ExecutionError, T] =
default
.map(
Parser.parseInputValue(_).flatMap(build(_)).left.map(e => ExecutionError(e.getMessage()))
)
.getOrElse(build(NullValue))

/**
* Builds a new `ArgBuilder` of `A` from an existing `ArgBuilder` of `T` and a function from `T` to `A`.
Expand Down
7 changes: 4 additions & 3 deletions core/src/main/scala/caliban/schema/Schema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -360,9 +360,9 @@ trait GenericSchema[R] extends SchemaDerivation[R] with TemporalSchema {
ev2: Schema[RB, B]
): Schema[RA with RB, A => B] =
new Schema[RA with RB, A => B] {
private lazy val inputType = ev1.toType_(true)
private val unwrappedArgumentName = "value"
override def arguments: List[__InputValue] =
private lazy val inputType = ev1.toType_(true)
private val unwrappedArgumentName = "value"
override def arguments: List[__InputValue] =
inputType.inputFields.getOrElse(
handleInput(List.empty[__InputValue])(
List(
Expand All @@ -375,6 +375,7 @@ trait GenericSchema[R] extends SchemaDerivation[R] with TemporalSchema {
)
)
)

override def optional: Boolean = ev2.optional
override def toType(isInput: Boolean, isSubscription: Boolean): __Type = ev2.toType_(isInput, isSubscription)

Expand Down
132 changes: 132 additions & 0 deletions core/src/main/scala/caliban/validation/DefaultValueValidator.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package caliban.validation

import caliban.CalibanError.ValidationError
import caliban.InputValue
import caliban.InputValue._
import caliban.Value
import caliban.Value._
import caliban.introspection.adt._
import caliban.introspection.adt.__TypeKind._
import caliban.parsing.Parser
import zio.IO

object DefaultValueValidator {
def validateDefaultValue(field: __InputValue, errorContext: String): IO[ValidationError, Unit] =
IO.whenCase(field.defaultValue) { case Some(v) =>
for {
value <-
IO.fromEither(Parser.parseInputValue(v))
.mapError(e =>
ValidationError(
s"$errorContext failed to parse default value: ${e.msg}",
"The default value for a field must be written using GraphQL input syntax."
)
)
_ <- Validator.validateInputValues(field, value)
_ <- validateInputTypes(field, value, errorContext)
} yield ()
}

def validateInputTypes(
inputValue: __InputValue,
argValue: InputValue,
errorContext: String
): IO[ValidationError, Unit] = validateType(inputValue.`type`(), argValue, errorContext)

def validateType(inputType: __Type, argValue: InputValue, errorContext: String): IO[ValidationError, Unit] =
inputType.kind match {
case NON_NULL =>
argValue match {
case NullValue =>
failValidation(s"$errorContext is null", "Input field was null but was supposed to be non-null.")
case x => validateType(inputType.ofType.getOrElse(inputType), x, errorContext)
}
case LIST =>
argValue match {
case ListValue(values) =>
IO.foreach_(values)(v =>
validateType(inputType.ofType.getOrElse(inputType), v, s"List item in $errorContext")
)
case _ =>
failValidation(s"$errorContext has invalid type: $argValue", "Input field was supposed to be a list.")
}

case INPUT_OBJECT =>
argValue match {
case ObjectValue(fields) =>
IO.foreach_(inputType.inputFields.getOrElse(List.empty)) { f =>
val value =
fields.collectFirst({ case (name, fieldValue) if name == f.name => fieldValue }).getOrElse(NullValue)
validateType(f.`type`(), value, s"Field ${f.name} in $errorContext")
}
case _ =>
failValidation(
s"$errorContext has invalid type: $argValue",
"Input field was supposed to be an input object."
)
}
case ENUM =>
argValue match {
case EnumValue(value) =>
val possible = inputType
.enumValues(__DeprecatedArgs(Some(true)))
.getOrElse(List.empty)
.map(_.name)
val exists = possible.exists(_ == value)

IO.unless(exists)(
failValidation(
s"$errorContext has invalid enum value: $value",
s"Was supposed to be one of ${possible.mkString(", ")}"
)
)
case _ =>
failValidation(
s"$errorContext has invalid type: $argValue",
"Input field was supposed to be an enum value."
)
}
case SCALAR => validateScalar(inputType, argValue, errorContext)
case x =>
failValidation(
s"$errorContext has invalid type $inputType",
"Input value is invalid, should be a scalar, list or input object."
)
}

def validateScalar(inputType: __Type, argValue: InputValue, errorContext: String) =
inputType.name.getOrElse("") match {
case "String" =>
argValue match {
case StringValue(value) =>
IO.unit
case t => failValidation(s"$errorContext has invalid type $t", "Expected 'String'")
}
case "ID" =>
argValue match {
case StringValue(value) =>
IO.unit
case t => failValidation(s"$errorContext has invalid type $t", "Expected 'ID'")
}
case "Int" =>
argValue match {
case _: Value.IntValue => IO.unit
case t => failValidation(s"$errorContext has invalid type $t", "Expected 'Int'")
}
case "Float" =>
argValue match {
case _: Value.FloatValue => IO.unit
case t => failValidation(s"$errorContext has invalid type $t", "Expected 'Float'")
}
case "Boolean" =>
argValue match {
case BooleanValue(value) => IO.unit
case t => failValidation(s"$errorContext has invalid type $t", "Expected 'Boolean'")
}
// We can't really validate custom scalars here (since we can't summon a correct ArgBuilder instance), so just pass them along
case x => IO.unit
}

def failValidation[T](msg: String, explanatoryText: String): IO[ValidationError, T] =
IO.fail(ValidationError(msg, explanatoryText))
}
10 changes: 7 additions & 3 deletions core/src/main/scala/caliban/validation/Validator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,10 @@ object Validator {
)
)

private def validateInputValues(inputValue: __InputValue, argValue: InputValue): IO[ValidationError, Unit] = {
private[caliban] def validateInputValues(
inputValue: __InputValue,
argValue: InputValue
): IO[ValidationError, Unit] = {
val t = inputValue.`type`()
val inputType = if (t.kind == __TypeKind.NON_NULL) t.ofType.getOrElse(t) else t
val inputFields = inputType.inputFields.getOrElse(Nil)
Expand Down Expand Up @@ -603,8 +606,8 @@ object Validator {
}

def validateFields(fields: List[__InputValue]): IO[ValidationError, Unit] =
noDuplicateInputValueName(fields, inputObjectContext) <*
IO.foreach_(fields)(validateInputValue(_, inputObjectContext))
IO.foreach_(fields)(validateInputValue(_, inputObjectContext)) &>
noDuplicateInputValueName(fields, inputObjectContext)

t.inputFields match {
case None | Some(Nil) =>
Expand All @@ -619,6 +622,7 @@ object Validator {
private[caliban] def validateInputValue(inputValue: __InputValue, errorContext: String): IO[ValidationError, Unit] = {
val fieldContext = s"InputValue '${inputValue.name}' of $errorContext"
for {
_ <- DefaultValueValidator.validateDefaultValue(inputValue, fieldContext)
_ <- doesNotStartWithUnderscore(inputValue, fieldContext)
_ <- onlyInputType(inputValue.`type`(), fieldContext)
} yield ()
Expand Down
Loading

0 comments on commit 54eba28

Please sign in to comment.