diff --git a/core/src/main/scala/caliban/CalibanError.scala b/core/src/main/scala/caliban/CalibanError.scala index 5b203e147..31b6cad4a 100644 --- a/core/src/main/scala/caliban/CalibanError.scala +++ b/core/src/main/scala/caliban/CalibanError.scala @@ -26,10 +26,17 @@ object CalibanError { /** * Describes an error that happened while executing a query. */ - case class ExecutionError(msg: String, fieldName: Option[String] = None, innerThrowable: Option[Throwable] = None) - extends CalibanError { + case class ExecutionError( + msg: String, + path: List[Either[String, Int]] = Nil, + innerThrowable: Option[Throwable] = None + ) extends CalibanError { override def toString: String = { - val field = fieldName.fold("")(f => s" on field '$f'") + val pathString = path.map { + case Left(value) => value + case Right(value) => value.toString + }.mkString(" > ") + val field = if (pathString.isEmpty) "" else s" on field '$pathString'" val inner = innerThrowable.fold("")(e => s" with ${e.toString}") s"Execution error$field: $msg$inner" } diff --git a/core/src/main/scala/caliban/GraphQL.scala b/core/src/main/scala/caliban/GraphQL.scala index 270ddf88a..e8db2d653 100644 --- a/core/src/main/scala/caliban/GraphQL.scala +++ b/core/src/main/scala/caliban/GraphQL.scala @@ -2,12 +2,14 @@ package caliban import caliban.Rendering.renderTypes import caliban.execution.Executor -import caliban.execution.QueryAnalyzer.QueryAnalyzer import caliban.introspection.Introspector import caliban.parsing.Parser +import caliban.parsing.adt.OperationType import caliban.schema.RootSchema.Operation import caliban.schema._ import caliban.validation.Validator +import caliban.wrappers.Wrapper +import caliban.wrappers.Wrapper._ import zio.{ IO, URIO } /** @@ -19,7 +21,7 @@ import zio.{ IO, URIO } trait GraphQL[-R] { self => protected val schema: RootSchema[R] - protected val queryAnalyzers: List[QueryAnalyzer[R]] + protected val wrappers: List[Wrapper[R]] private lazy val rootType: RootType = RootType(schema.query.opType, schema.mutation.map(_.opType), schema.subscription.map(_.opType)) @@ -31,20 +33,23 @@ trait GraphQL[-R] { self => operationName: Option[String], variables: Map[String, InputValue], skipValidation: Boolean - ): URIO[R, GraphQLResponse[CalibanError]] = { - - val prepare = for { - document <- Parser.parseQuery(query) - intro = Introspector.isIntrospection(document) - typeToValidate = if (intro) introspectionRootType else rootType - schemaToExecute = if (intro) introspectionRootSchema else schema - _ <- IO.when(!skipValidation)(Validator.validate(document, typeToValidate)) - } yield (document, schemaToExecute) - - prepare.foldM( - Executor.fail, - req => Executor.executeRequest(req._1, req._2, operationName, variables, queryAnalyzers) - ) + ): URIO[R, GraphQLResponse[CalibanError]] = decompose(wrappers).flatMap { + case (overallWrappers, parsingWrappers, validationWrappers, executionWrappers, fieldWrappers) => + wrap((for { + doc <- wrap(Parser.parseQuery(query))(parsingWrappers, query) + intro = Introspector.isIntrospection(doc) + typeToValidate = if (intro) introspectionRootType else rootType + schemaToExecute = if (intro) introspectionRootSchema else schema + validate = Validator.prepare(doc, typeToValidate, schemaToExecute, operationName, variables, skipValidation) + request <- wrap(validate)(validationWrappers, doc) + op = request.operationType match { + case OperationType.Query => schemaToExecute.query + case OperationType.Mutation => schemaToExecute.mutation.getOrElse(schemaToExecute.query) + case OperationType.Subscription => schemaToExecute.subscription.getOrElse(schemaToExecute.query) + } + execute = Executor.executeRequest(request, op.plan, variables, fieldWrappers) + result <- wrap(execute)(executionWrappers, request) + } yield result).catchAll(Executor.fail))(overallWrappers, query) } /** @@ -74,16 +79,22 @@ trait GraphQL[-R] { self => self.execute(query, operationName, variables, skipValidation) /** - * Attaches a function that will analyze each query before execution, possibly modify or reject it. - * @param queryAnalyzer a function from `Field` to `ZIO[R, CalibanError, Field]` + * Attaches a function that will wrap one of the stages of query processing + * (parsing, validation, execution, field execution or overall). + * @param wrapper a wrapping function * @return a new GraphQL API */ - final def withQueryAnalyzer[R2 <: R](queryAnalyzer: QueryAnalyzer[R2]): GraphQL[R2] = + final def withWrapper[R2 <: R](wrapper: Wrapper[R2]): GraphQL[R2] = new GraphQL[R2] { - override val schema: RootSchema[R2] = self.schema - override val queryAnalyzers: List[QueryAnalyzer[R2]] = queryAnalyzer :: self.queryAnalyzers + override val schema: RootSchema[R2] = self.schema + override val wrappers: List[Wrapper[R2]] = wrapper :: self.wrappers } + /** + * A symbolic alias for `withWrapper`. + */ + final def @@[R2 <: R](wrapper: Wrapper[R2]): GraphQL[R2] = withWrapper(wrapper) + /** * Merges this GraphQL API with another GraphQL API. * In case of conflicts (same field declared on both APIs), fields from `that` API will be used. @@ -92,8 +103,8 @@ trait GraphQL[-R] { self => */ final def combine[R1 <: R](that: GraphQL[R1]): GraphQL[R1] = new GraphQL[R1] { - override val schema: RootSchema[R1] = self.schema |+| that.schema - override val queryAnalyzers: List[QueryAnalyzer[R1]] = self.queryAnalyzers ++ that.queryAnalyzers + override val schema: RootSchema[R1] = self.schema |+| that.schema + override protected val wrappers: List[Wrapper[R1]] = self.wrappers ++ that.wrappers } /** @@ -124,7 +135,7 @@ trait GraphQL[-R] { self => name => self.schema.subscription.map(m => m.copy(opType = m.opType.copy(name = Some(name)))) ) ) - override protected val queryAnalyzers: List[QueryAnalyzer[R]] = self.queryAnalyzers + override protected val wrappers: List[Wrapper[R]] = self.wrappers } } @@ -146,6 +157,6 @@ object GraphQL { resolver.mutationResolver.map(r => Operation(mutationSchema.toType(), mutationSchema.resolve(r))), resolver.subscriptionResolver.map(r => Operation(subscriptionSchema.toType(), subscriptionSchema.resolve(r))) ) - val queryAnalyzers: List[QueryAnalyzer[R]] = Nil + val wrappers: List[Wrapper[R]] = Nil } } diff --git a/core/src/main/scala/caliban/GraphQLInterpreter.scala b/core/src/main/scala/caliban/GraphQLInterpreter.scala index 456ab0445..ce50c6309 100644 --- a/core/src/main/scala/caliban/GraphQLInterpreter.scala +++ b/core/src/main/scala/caliban/GraphQLInterpreter.scala @@ -33,7 +33,7 @@ trait GraphQLInterpreter[-R, +E] { self => * @return a new GraphQL interpreter with error type `E2` */ final def mapError[E2](f: E => E2): GraphQLInterpreter[R, E2] = - wrapExecutionWith(_.map(res => GraphQLResponse(res.data, res.errors.map(f)))) + wrapExecutionWith(_.map(res => GraphQLResponse(res.data, res.errors.map(f), res.extensions))) /** * Eliminates the ZIO environment R requirement of the interpreter. diff --git a/core/src/main/scala/caliban/GraphQLResponse.scala b/core/src/main/scala/caliban/GraphQLResponse.scala index 0d74dd7b7..544759636 100644 --- a/core/src/main/scala/caliban/GraphQLResponse.scala +++ b/core/src/main/scala/caliban/GraphQLResponse.scala @@ -1,11 +1,13 @@ package caliban +import caliban.CalibanError.ExecutionError +import caliban.ResponseValue.ObjectValue import caliban.interop.circe._ /** * Represents the result of a GraphQL query, containing a data object and a list of errors. */ -case class GraphQLResponse[+E](data: ResponseValue, errors: List[E]) +case class GraphQLResponse[+E](data: ResponseValue, errors: List[E], extensions: Option[ObjectValue] = None) object GraphQLResponse { implicit def circeEncoder[F[_]: IsCirceEncoder, E]: F[GraphQLResponse[E]] = @@ -17,11 +19,30 @@ private object GraphQLResponseCirce { import io.circe.syntax._ val graphQLResponseEncoder: Encoder[GraphQLResponse[Any]] = Encoder .instance[GraphQLResponse[Any]] { - case GraphQLResponse(data, Nil) => Json.obj("data" -> data.asJson) - case GraphQLResponse(data, errors) => + case GraphQLResponse(data, Nil, None) => Json.obj("data" -> data.asJson) + case GraphQLResponse(data, Nil, Some(extensions)) => + Json.obj("data" -> data.asJson, "extensions" -> extensions.asInstanceOf[ResponseValue].asJson) + case GraphQLResponse(data, errors, None) => + Json.obj("data" -> data.asJson, "errors" -> Json.fromValues(errors.map(handleError))) + case GraphQLResponse(data, errors, Some(extensions)) => Json.obj( - "data" -> data.asJson, - "errors" -> Json.fromValues(errors.map(err => Json.obj("message" -> Json.fromString(err.toString)))) + "data" -> data.asJson, + "errors" -> Json.fromValues(errors.map(handleError)), + "extensions" -> extensions.asInstanceOf[ResponseValue].asJson ) } + + private def handleError(err: Any): Json = + err match { + case ExecutionError(_, path, _) if path.nonEmpty => + Json.obj( + "message" -> Json.fromString(err.toString), + "path" -> Json.fromValues(path.map { + case Left(value) => Json.fromString(value) + case Right(value) => Json.fromInt(value) + }) + ) + case _ => Json.obj("message" -> Json.fromString(err.toString)) + } + } diff --git a/core/src/main/scala/caliban/execution/ExecutionRequest.scala b/core/src/main/scala/caliban/execution/ExecutionRequest.scala new file mode 100644 index 000000000..df0cbbbc1 --- /dev/null +++ b/core/src/main/scala/caliban/execution/ExecutionRequest.scala @@ -0,0 +1,9 @@ +package caliban.execution + +import caliban.parsing.adt.{ OperationType, VariableDefinition } + +case class ExecutionRequest( + field: Field, + operationType: OperationType, + variableDefinitions: List[VariableDefinition] +) diff --git a/core/src/main/scala/caliban/execution/Executor.scala b/core/src/main/scala/caliban/execution/Executor.scala index 398e85415..5f3f89f49 100644 --- a/core/src/main/scala/caliban/execution/Executor.scala +++ b/core/src/main/scala/caliban/execution/Executor.scala @@ -1,15 +1,13 @@ package caliban.execution +import scala.annotation.tailrec import scala.collection.immutable.ListMap -import caliban.CalibanError.ExecutionError import caliban.ResponseValue._ import caliban.Value._ -import caliban.execution.QueryAnalyzer.QueryAnalyzer -import caliban.parsing.adt.ExecutableDefinition.{ FragmentDefinition, OperationDefinition } -import caliban.parsing.adt.OperationType.{ Mutation, Query, Subscription } import caliban.parsing.adt._ import caliban.schema.Step._ -import caliban.schema.{ GenericSchema, ReducedStep, RootSchema, Step } +import caliban.schema.{ GenericSchema, ReducedStep, Step } +import caliban.wrappers.Wrapper.FieldWrapper import caliban.{ CalibanError, GraphQLResponse, InputValue, ResponseValue } import zio._ import zquery.ZQuery @@ -18,57 +16,24 @@ object Executor { /** * Executes the given query against a schema. It returns either an [[caliban.CalibanError.ExecutionError]] or a [[ResponseValue]]. - * @param document the parsed query - * @param schema the schema to use to run the query - * @param operationName the operation to run in case the query contains multiple operations. - * @param variables a list of variables. + * @param request a request object containing all information needed + * @param plan an execution plan + * @param variables a list of variables + * @param fieldWrappers a list of field wrappers */ def executeRequest[R]( - document: Document, - schema: RootSchema[R], - operationName: Option[String] = None, + request: ExecutionRequest, + plan: Step[R], variables: Map[String, InputValue] = Map(), - queryAnalyzers: List[QueryAnalyzer[R]] = Nil + fieldWrappers: List[FieldWrapper[R]] = Nil ): URIO[R, GraphQLResponse[CalibanError]] = { - val fragments = document.definitions.collect { - case fragment: FragmentDefinition => fragment.name -> fragment - }.toMap - val operation = operationName match { - case Some(name) => - document.definitions.collectFirst { case op: OperationDefinition if op.name.contains(name) => op } - .toRight(s"Unknown operation $name.") - case None => - document.definitions.collect { case op: OperationDefinition => op } match { - case head :: Nil => Right(head) - case _ => Left("Operation name is required.") - } + val allowParallelism = request.operationType match { + case OperationType.Query => true + case OperationType.Mutation => false + case OperationType.Subscription => false } - operation match { - case Left(error) => fail(ExecutionError(error)) - case Right(op) => - val getOperationType = op.operationType match { - case Query => IO.succeed((schema.query, true)) - case Mutation => - schema.mutation match { - case Some(m) => IO.succeed((m, false)) - case None => IO.fail(ExecutionError("Mutations are not supported on this schema")) - } - case Subscription => - schema.subscription match { - case Some(m) => IO.succeed((m, false)) - case None => IO.fail(ExecutionError("Subscriptions are not supported on this schema")) - } - } - (for { - (operationType, allowParallelism) <- getOperationType - root <- ZIO - .foldLeft(queryAnalyzers)(Field(op.selectionSet, fragments, variables, operationType.opType)) { - case (field, analyzer) => analyzer(field) - } - result <- executePlan(operationType.plan, root, op.variableDefinitions, variables, allowParallelism) - } yield result).catchAll(fail) - } + executePlan(plan, request.field, request.variableDefinitions, variables, allowParallelism, fieldWrappers) } private[caliban] def fail(error: CalibanError): UIO[GraphQLResponse[CalibanError]] = @@ -79,13 +44,15 @@ object Executor { root: Field, variableDefinitions: List[VariableDefinition], variableValues: Map[String, InputValue], - allowParallelism: Boolean + allowParallelism: Boolean, + fieldWrappers: List[FieldWrapper[R]] ): URIO[R, GraphQLResponse[CalibanError]] = { def reduceStep( step: Step[R], currentField: Field, - arguments: Map[String, InputValue] + arguments: Map[String, InputValue], + path: List[Either[String, Int]] ): ReducedStep[R] = step match { case s @ PureStep(value) => @@ -101,38 +68,55 @@ object Executor { obj.fold(s)(PureStep(_)) case _ => s } - case FunctionStep(step) => reduceStep(step(arguments), currentField, Map()) - case ListStep(steps) => reduceList(steps.map(reduceStep(_, currentField, arguments))) + case FunctionStep(step) => reduceStep(step(arguments), currentField, Map(), path) + case ListStep(steps) => + reduceList(steps.zipWithIndex.map { + case (step, i) => reduceStep(step, currentField, arguments, Right(i) :: path) + }) case ObjectStep(objectName, fields) => val mergedFields = mergeFields(currentField, objectName) val items = mergedFields.map { - case Field(name @ "__typename", _, _, alias, _, _, _) => - alias.getOrElse(name) -> PureStep(StringValue(objectName)) + case f @ Field(name @ "__typename", _, _, alias, _, _, _) => + (alias.getOrElse(name), PureStep(StringValue(objectName)), fieldInfo(f, path)) case f @ Field(name, _, _, alias, _, _, args) => val arguments = resolveVariables(args, variableDefinitions, variableValues) - alias.getOrElse(name) -> + ( + alias.getOrElse(name), fields .get(name) - .fold(NullStep: ReducedStep[R])(reduceStep(_, f, arguments)) + .fold(NullStep: ReducedStep[R])(reduceStep(_, f, arguments, Left(alias.getOrElse(name)) :: path)), + fieldInfo(f, path) + ) } - reduceObject(items) + reduceObject(items, fieldWrappers) case QueryStep(inner) => ReducedStep.QueryStep( - inner.bimap( - GenericSchema.effectfulExecutionError(currentField.name, _), - reduceStep(_, currentField, arguments) - ) + inner.bimap(GenericSchema.effectfulExecutionError(path, _), reduceStep(_, currentField, arguments, path)) ) case StreamStep(stream) => ReducedStep.StreamStep( - stream.bimap( - GenericSchema.effectfulExecutionError(currentField.name, _), - reduceStep(_, currentField, arguments) - ) + stream.bimap(GenericSchema.effectfulExecutionError(path, _), reduceStep(_, currentField, arguments, path)) ) } def makeQuery(step: ReducedStep[R], errors: Ref[List[CalibanError]]): ZQuery[R, Nothing, ResponseValue] = { + + @tailrec + def wrap(query: ZQuery[R, Nothing, ResponseValue])( + wrappers: List[FieldWrapper[R]], + fieldInfo: FieldInfo + ): ZQuery[R, Nothing, ResponseValue] = + wrappers match { + case Nil => query + case wrapper :: tail => + wrap( + wrapper + .f(query, fieldInfo) + .foldM(error => ZQuery.fromEffect(errors.update(error :: _)).map(_ => NullValue), ZQuery.succeed) + )(tail, fieldInfo) + + } + def loop(step: ReducedStep[R]): ZQuery[R, Nothing, ResponseValue] = step match { case PureStep(value) => ZQuery.succeed(value) @@ -140,7 +124,7 @@ object Executor { val queries = steps.map(loop) (if (allowParallelism) ZQuery.collectAllPar(queries) else ZQuery.collectAll(queries)).map(ListValue) case ReducedStep.ObjectStep(steps) => - val queries = steps.map { case (name, field) => loop(field).map(name -> _) } + val queries = steps.map { case (name, step, info) => wrap(loop(step))(fieldWrappers, info).map(name -> _) } (if (allowParallelism) ZQuery.collectAllPar(queries) else ZQuery.collectAll(queries)).map(ObjectValue) case ReducedStep.QueryStep(step) => step.foldM( @@ -157,7 +141,7 @@ object Executor { for { errors <- Ref.make(List.empty[CalibanError]) - reduced = reduceStep(plan, root, Map()) + reduced = reduceStep(plan, root, Map(), Nil) query = makeQuery(reduced, errors) result <- query.run resultErrors <- errors.get @@ -199,15 +183,21 @@ object Executor { .toList } + private def fieldInfo(field: Field, path: List[Either[String, Int]]): FieldInfo = + FieldInfo(field.alias.getOrElse(field.name), path, field.parentType, field.fieldType) + private def reduceList[R](list: List[ReducedStep[R]]): ReducedStep[R] = if (list.forall(_.isInstanceOf[PureStep])) PureStep(ListValue(list.asInstanceOf[List[PureStep]].map(_.value))) else ReducedStep.ListStep(list) - private def reduceObject[R](items: List[(String, ReducedStep[R])]): ReducedStep[R] = - if (items.map(_._2).forall(_.isInstanceOf[PureStep])) - PureStep(ObjectValue(items.asInstanceOf[List[(String, PureStep)]].map { - case (k, v) => k -> v.value + private def reduceObject[R]( + items: List[(String, ReducedStep[R], FieldInfo)], + fieldWrappers: List[FieldWrapper[R]] + ): ReducedStep[R] = + if (!fieldWrappers.exists(_.wrapPureValues) && items.map(_._2).forall(_.isInstanceOf[PureStep])) + PureStep(ObjectValue(items.asInstanceOf[List[(String, PureStep, FieldInfo)]].map { + case (k, v, _) => (k, v.value) })) else ReducedStep.ObjectStep(items) diff --git a/core/src/main/scala/caliban/execution/Field.scala b/core/src/main/scala/caliban/execution/Field.scala index bef53d886..4b853e1c7 100644 --- a/core/src/main/scala/caliban/execution/Field.scala +++ b/core/src/main/scala/caliban/execution/Field.scala @@ -27,15 +27,16 @@ object Field { ): Field = { def loop(selectionSet: List[Selection], fieldType: __Type): Field = { + val innerType = Types.innerType(fieldType) val (fields, cFields) = selectionSet.map { case f @ F(alias, name, arguments, _, selectionSet) if checkDirectives(f.directives, variableValues) => - val t = fieldType + val t = innerType .fields(__DeprecatedArgs(Some(true))) .flatMap(_.find(_.name == name)) - .fold(Types.string)(f => Types.innerType(f.`type`())) // default only case where it's not found is __typename + .fold(Types.string)(_.`type`()) // default only case where it's not found is __typename val field = loop(selectionSet, t) ( - List(Field(name, t, Some(fieldType), alias, field.fields, field.conditionalFields, arguments)), + List(Field(name, t, Some(innerType), alias, field.fields, field.conditionalFields, arguments)), Map.empty[String, List[Field]] ) case FragmentSpread(name, directives) if checkDirectives(directives, variableValues) => @@ -44,12 +45,12 @@ object Field { .get(name) .fold(default) { f => val t = - fieldType.possibleTypes.flatMap(_.find(_.name.contains(f.typeCondition.name))).getOrElse(fieldType) + innerType.possibleTypes.flatMap(_.find(_.name.contains(f.typeCondition.name))).getOrElse(fieldType) val field = loop(f.selectionSet, t) (Nil, combineMaps(List(field.conditionalFields, Map(f.typeCondition.name -> field.fields)))) } case InlineFragment(typeCondition, directives, selectionSet) if checkDirectives(directives, variableValues) => - val t = fieldType.possibleTypes + val t = innerType.possibleTypes .flatMap(_.find(_.name.exists(typeCondition.map(_.name).contains))) .getOrElse(fieldType) val field = loop(selectionSet, t) diff --git a/core/src/main/scala/caliban/execution/FieldInfo.scala b/core/src/main/scala/caliban/execution/FieldInfo.scala new file mode 100644 index 000000000..5a04bba2d --- /dev/null +++ b/core/src/main/scala/caliban/execution/FieldInfo.scala @@ -0,0 +1,5 @@ +package caliban.execution + +import caliban.introspection.adt.__Type + +case class FieldInfo(fieldName: String, path: List[Either[String, Int]], parentType: Option[__Type], returnType: __Type) diff --git a/core/src/main/scala/caliban/execution/QueryAnalyzer.scala b/core/src/main/scala/caliban/execution/QueryAnalyzer.scala deleted file mode 100644 index bf3542ba3..000000000 --- a/core/src/main/scala/caliban/execution/QueryAnalyzer.scala +++ /dev/null @@ -1,66 +0,0 @@ -package caliban.execution - -import caliban.{ CalibanError, GraphQL } -import caliban.CalibanError.ValidationError -import zio.{ IO, ZIO } - -object QueryAnalyzer { - - /** - * A query analyzer is a function that takes a root [[Field]] and returns a new root [[Field]] or fails with a [[CalibanError]]. - * In case of failure, the query will be rejected before execution. - * The environment `R` can be used to "inject" some data that will be used by the resolvers (e.g. query cost). - */ - type QueryAnalyzer[-R] = Field => ZIO[R, CalibanError, Field] - - /** - * Attaches to the given GraphQL API definition a function that checks that each query depth is under a given max. - * @param maxDepth the max allowed depth for a query - * @param api a GraphQL API definition - * @return a new GraphQL API definition - */ - def maxDepth[R, E](maxDepth: Int)(api: GraphQL[R]): GraphQL[R] = - api.withQueryAnalyzer(checkMaxDepth(maxDepth)) - - /** - * Checks that the given field's depth is under a given max - * @param maxDepth the max allowed depth for the field - */ - def checkMaxDepth(maxDepth: Int): QueryAnalyzer[Any] = { field => - val depth = calculateDepth(field) - if (depth > maxDepth) IO.fail(ValidationError(s"Query is too deep: $depth. Max depth: $maxDepth.", "")) - else IO.succeed(field) - } - - def calculateDepth(field: Field): Int = { - val children = field.fields ++ field.conditionalFields.values.flatten - val childrenDepth = if (children.isEmpty) 0 else children.map(calculateDepth).max - childrenDepth + (if (field.name.nonEmpty) 1 else 0) - } - - /** - * Attaches to the given GraphQL API definition a function that checks that each query has a limited number of fields. - * @param maxFields the max allowed number of fields for a query - * @param api a GraphQL API definition - * @return a new GraphQL API definition - */ - def maxFields[R, E](maxFields: Int)(api: GraphQL[R]): GraphQL[R] = - api.withQueryAnalyzer(checkMaxFields(maxFields)) - - /** - * Checks that the given field has a limited number of fields - * @param maxFields the max allowed number of fields inside the given field - */ - def checkMaxFields(maxFields: Int): QueryAnalyzer[Any] = { field => - val fields = countFields(field) - if (fields > maxFields) IO.fail(ValidationError(s"Query has too many fields: $fields. Max fields: $maxFields.", "")) - else IO.succeed(field) - } - - def countFields(field: Field): Int = - innerFields(field.fields) + (if (field.conditionalFields.isEmpty) 0 - else field.conditionalFields.values.map(innerFields).max) - - private def innerFields(fields: List[Field]): Int = fields.length + fields.map(countFields).sum - -} diff --git a/core/src/main/scala/caliban/schema/Schema.scala b/core/src/main/scala/caliban/schema/Schema.scala index 739676b1c..c4008320f 100644 --- a/core/src/main/scala/caliban/schema/Schema.scala +++ b/core/src/main/scala/caliban/schema/Schema.scala @@ -418,8 +418,9 @@ trait DerivationSchema[R] { object GenericSchema { - def effectfulExecutionError(fieldName: String, e: Throwable): ExecutionError = e match { - case e: ExecutionError => e - case other => ExecutionError("Effect failure", Some(fieldName), Some(other)) - } + def effectfulExecutionError(path: List[Either[String, Int]], e: Throwable): ExecutionError = + e match { + case e: ExecutionError => e + case other => ExecutionError("Effect failure", path.reverse, Some(other)) + } } diff --git a/core/src/main/scala/caliban/schema/Step.scala b/core/src/main/scala/caliban/schema/Step.scala index 6c2fb6579..c875c8254 100644 --- a/core/src/main/scala/caliban/schema/Step.scala +++ b/core/src/main/scala/caliban/schema/Step.scala @@ -3,6 +3,7 @@ package caliban.schema import caliban.CalibanError.ExecutionError import caliban.{ InputValue, ResponseValue } import caliban.Value.NullValue +import caliban.execution.FieldInfo import zio.stream.ZStream import zquery.ZQuery @@ -36,7 +37,7 @@ sealed trait ReducedStep[-R] object ReducedStep { case class ListStep[-R](steps: List[ReducedStep[R]]) extends ReducedStep[R] - case class ObjectStep[-R](fields: List[(String, ReducedStep[R])]) extends ReducedStep[R] + case class ObjectStep[-R](fields: List[(String, ReducedStep[R], FieldInfo)]) extends ReducedStep[R] case class QueryStep[-R](query: ZQuery[R, ExecutionError, ReducedStep[R]]) extends ReducedStep[R] case class StreamStep[-R](inner: ZStream[R, ExecutionError, ReducedStep[R]]) extends ReducedStep[R] diff --git a/core/src/main/scala/caliban/validation/Validator.scala b/core/src/main/scala/caliban/validation/Validator.scala index 369e9aeab..a78186495 100644 --- a/core/src/main/scala/caliban/validation/Validator.scala +++ b/core/src/main/scala/caliban/validation/Validator.scala @@ -1,17 +1,18 @@ package caliban.validation import caliban.CalibanError.ValidationError -import caliban.{ InputValue, Rendering } +import caliban.InputValue.VariableValue +import caliban.Value.NullValue +import caliban.execution.{ ExecutionRequest, Field => F } import caliban.introspection.Introspector import caliban.introspection.adt._ import caliban.parsing.adt.ExecutableDefinition.{ FragmentDefinition, OperationDefinition } +import caliban.parsing.adt.OperationType._ import caliban.parsing.adt.Selection.{ Field, FragmentSpread, InlineFragment } import caliban.parsing.adt.Type.NamedType -import caliban.InputValue.VariableValue -import caliban.Value.NullValue -import caliban.execution.{ Field => F } import caliban.parsing.adt.{ Directive, Document, OperationType, Selection, Type } -import caliban.schema.{ RootType, Types } +import caliban.schema.{ RootSchema, RootType, Types } +import caliban.{ InputValue, Rendering } import zio.IO object Validator { @@ -27,7 +28,67 @@ object Validator { /** * Verifies that the given document is valid for this type. Fails with a [[caliban.CalibanError.ValidationError]] otherwise. */ - def validate(document: Document, rootType: RootType): IO[ValidationError, Unit] = { + def validate(document: Document, rootType: RootType): IO[ValidationError, Unit] = + check(document, rootType).unit + + /** + * Prepare the request for execution. + * Fails with a [[caliban.CalibanError.ValidationError]] otherwise. + */ + def prepare[R]( + document: Document, + rootType: RootType, + rootSchema: RootSchema[R], + operationName: Option[String], + variables: Map[String, InputValue], + skipValidation: Boolean + ): IO[ValidationError, ExecutionRequest] = { + val fragments = if (skipValidation) { + IO.succeed(collectOperationsAndFragments(document)._2.foldLeft(Map.empty[String, FragmentDefinition]) { + case (m, f) => m.updated(f.name, f) + }) + } else check(document, rootType) + + fragments.flatMap { fragments => + val operation = operationName match { + case Some(name) => + document.definitions.collectFirst { case op: OperationDefinition if op.name.contains(name) => op } + .toRight(s"Unknown operation $name.") + case None => + document.definitions.collect { case op: OperationDefinition => op } match { + case head :: Nil => Right(head) + case _ => Left("Operation name is required.") + } + } + + operation match { + case Left(error) => IO.fail(ValidationError(error, "")) + case Right(op) => + (op.operationType match { + case Query => IO.succeed(rootSchema.query) + case Mutation => + rootSchema.mutation match { + case Some(m) => IO.succeed(m) + case None => IO.fail(ValidationError("Mutations are not supported on this schema", "")) + } + case Subscription => + rootSchema.subscription match { + case Some(m) => IO.succeed(m) + case None => IO.fail(ValidationError("Subscriptions are not supported on this schema", "")) + } + }).map( + operation => + ExecutionRequest( + F(op.selectionSet, fragments, variables, operation.opType), + op.operationType, + op.variableDefinitions + ) + ) + } + } + } + + private def check(document: Document, rootType: RootType): IO[ValidationError, Map[String, FragmentDefinition]] = { val (operations, fragments) = collectOperationsAndFragments(document) for { fragmentMap <- validateFragments(fragments) @@ -40,7 +101,7 @@ object Validator { _ <- validateVariables(context) _ <- validateSubscriptionOperation(context) _ <- validateDocumentFields(context) - } yield () + } yield fragmentMap } private def collectOperationsAndFragments(document: Document): (List[OperationDefinition], List[FragmentDefinition]) = diff --git a/core/src/main/scala/caliban/wrappers/ApolloTracing.scala b/core/src/main/scala/caliban/wrappers/ApolloTracing.scala new file mode 100644 index 000000000..4f568d412 --- /dev/null +++ b/core/src/main/scala/caliban/wrappers/ApolloTracing.scala @@ -0,0 +1,190 @@ +package caliban.wrappers + +import java.time.format.DateTimeFormatter +import java.time.{ Instant, ZoneId } +import java.util.concurrent.TimeUnit +import caliban.ResponseValue.{ ListValue, ObjectValue } +import caliban.Value.{ IntValue, StringValue } +import caliban.wrappers.Wrapper.{ EffectfulWrapper, FieldWrapper, OverallWrapper, ParsingWrapper, ValidationWrapper } +import caliban.{ Rendering, ResponseValue } +import zio.{ clock, FiberRef } +import zio.clock.Clock +import zio.duration.Duration +import zquery.ZQuery + +object ApolloTracing { + + /** + * Returns a wrapper that adds tracing information to every response + * following Apollo Tracing format: https://github.com/apollographql/apollo-tracing. + */ + val apolloTracing: EffectfulWrapper[Clock] = + EffectfulWrapper( + FiberRef + .make(Tracing()) + .map( + ref => + apolloTracingOverall(ref) |+| + apolloTracingParsing(ref) |+| + apolloTracingValidation(ref) |+| + apolloTracingField(ref) + ) + ) + + private val dateFormatter: DateTimeFormatter = DateTimeFormatter + .ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") + .withZone(ZoneId.of("UTC")) + + case class Parsing(startOffset: Long = 0, duration: Duration = Duration.Zero) { + def toResponseValue: ResponseValue = + ObjectValue(List("startOffset" -> IntValue(startOffset), "duration" -> IntValue(duration.toNanos))) + } + + case class Validation(startOffset: Long = 0, duration: Duration = Duration.Zero) { + def toResponseValue: ResponseValue = + ObjectValue(List("startOffset" -> IntValue(startOffset), "duration" -> IntValue(duration.toNanos))) + } + + case class Resolver( + path: List[Either[String, Int]] = Nil, + parentType: String = "", + fieldName: String = "", + returnType: String = "", + startOffset: Long = 0, + duration: Duration = Duration.Zero + ) { + def toResponseValue: ResponseValue = + ObjectValue( + List( + "path" -> ListValue((Left(fieldName) :: path).reverse.map { + case Left(s) => StringValue(s) + case Right(i) => IntValue(i) + }), + "parentType" -> StringValue(parentType), + "fieldName" -> StringValue(fieldName), + "returnType" -> StringValue(returnType), + "startOffset" -> IntValue(startOffset), + "duration" -> IntValue(duration.toNanos) + ) + ) + } + + case class Execution(resolvers: List[Resolver] = Nil) { + def toResponseValue: ResponseValue = + ObjectValue(List("resolvers" -> ListValue(resolvers.sortBy(_.startOffset).map(_.toResponseValue)))) + } + + case class Tracing( + version: Int = 1, + startTime: Long = 0, + endTime: Long = 0, + startTimeMonotonic: Long = 0, + duration: Duration = Duration.Zero, + parsing: Parsing = Parsing(), + validation: Validation = Validation(), + execution: Execution = Execution() + ) { + def toResponseValue: ResponseValue = + ObjectValue( + List( + "version" -> IntValue(version), + "startTime" -> StringValue(dateFormatter.format(Instant.ofEpochMilli(startTime))), + "endTime" -> StringValue(dateFormatter.format(Instant.ofEpochMilli(endTime))), + "duration" -> IntValue(duration.toNanos), + "parsing" -> parsing.toResponseValue, + "validation" -> validation.toResponseValue, + "execution" -> execution.toResponseValue + ) + ) + } + + private def apolloTracingOverall(ref: FiberRef[Tracing]): OverallWrapper[Clock] = + OverallWrapper { + case (io, _) => + for { + nanoTime <- clock.nanoTime + currentTime <- clock.currentTime(TimeUnit.MILLISECONDS) + _ <- ref.update(_.copy(startTime = currentTime, startTimeMonotonic = nanoTime)) + result <- io.timed.flatMap { + case (duration, result) => + for { + endTime <- clock.currentTime(TimeUnit.MILLISECONDS) + _ <- ref.update(_.copy(duration = duration, endTime = endTime)) + tracing <- ref.get + } yield result.copy( + extensions = Some( + ObjectValue( + ("tracing" -> tracing.toResponseValue) :: + result.extensions.fold(List.empty[(String, ResponseValue)])(_.fields) + ) + ) + ) + } + } yield result + } + + private def apolloTracingParsing(ref: FiberRef[Tracing]): ParsingWrapper[Clock] = + ParsingWrapper { + case (io, _) => + for { + start <- clock.nanoTime + (duration, result) <- io.timed + _ <- ref.update( + state => + state.copy( + parsing = state.parsing.copy(startOffset = start - state.startTimeMonotonic, duration = duration) + ) + ) + } yield result + } + + private def apolloTracingValidation(ref: FiberRef[Tracing]): ValidationWrapper[Clock] = + ValidationWrapper { + case (io, _) => + for { + start <- clock.nanoTime + (duration, result) <- io.timed + _ <- ref.update( + state => + state.copy( + validation = + state.validation.copy(startOffset = start - state.startTimeMonotonic, duration = duration) + ) + ) + } yield result + } + + private def apolloTracingField(ref: FiberRef[Tracing]): FieldWrapper[Clock] = + FieldWrapper( + { + case (query, fieldInfo) => + for { + start <- ZQuery.fromEffect(clock.nanoTime) + result <- query + end <- ZQuery.fromEffect(clock.nanoTime) + duration = Duration.fromNanos(end - start) + _ <- ZQuery.fromEffect( + ref + .update( + state => + state.copy( + execution = state.execution.copy( + resolvers = + Resolver( + path = fieldInfo.path, + parentType = fieldInfo.parentType.fold("")(Rendering.renderTypeName), + fieldName = fieldInfo.fieldName, + returnType = Rendering.renderTypeName(fieldInfo.returnType), + startOffset = start - state.startTimeMonotonic, + duration = duration + ) :: state.execution.resolvers + ) + ) + ) + ) + } yield result + }, + wrapPureValues = true + ) + +} diff --git a/core/src/main/scala/caliban/wrappers/Wrapper.scala b/core/src/main/scala/caliban/wrappers/Wrapper.scala new file mode 100644 index 000000000..ceb9803b7 --- /dev/null +++ b/core/src/main/scala/caliban/wrappers/Wrapper.scala @@ -0,0 +1,133 @@ +package caliban.wrappers + +import scala.annotation.tailrec +import caliban.CalibanError.{ ParsingError, ValidationError } +import caliban.execution.{ ExecutionRequest, FieldInfo } +import caliban.parsing.adt.Document +import caliban.wrappers.Wrapper.CombinedWrapper +import caliban.{ CalibanError, GraphQLResponse, ResponseValue } +import zio.{ UIO, ZIO } +import zquery.ZQuery + +/** + * A `Wrapper[-R]` represents an extra layer of computation that can be applied on top of Caliban's query handling. + * There are different base types of wrappers: + * - `OverallWrapper` to wrap the whole query processing + * - `ParsingWrapper` to wrap the query parsing only + * - `ValidationWrapper` to wrap the query validation only + * - `ExecutionWrapper` to wrap the query execution only + * - `FieldWrapper` to wrap each field execution + * + * It is also possible to combine wrappers using `|+|` and to build a wrapper effectfully with `EffectfulWrapper`. + */ +sealed trait Wrapper[-R] { self => + def |+|[R1 <: R](that: Wrapper[R1]): Wrapper[R1] = CombinedWrapper(List(self, that)) +} + +object Wrapper { + + /** + * `WrappingFunction[R, E, A, Info]` is an alias for a function that takes an `ZIO[R, E, A]` and some extra `Info` + * and returns a `ZIO[R, E, A]`. + */ + type WrappingFunction[R, E, A, Info] = (ZIO[R, E, A], Info) => ZIO[R, E, A] + + /** + * Wrapper for the whole query processing. + * Takes a function from a `UIO[GraphQLResponse[CalibanError]]` and a query `String` and that returns a + * `URIO[R, GraphQLResponse[CalibanError]]`. + */ + case class OverallWrapper[R](f: WrappingFunction[R, Nothing, GraphQLResponse[CalibanError], String]) + extends Wrapper[R] + + /** + * Wrapper for the query parsing stage. + * Takes a function from an `IO[ParsingError, Document]` and a query `String` and that returns a + * `ZIO[R, ParsingError, Document]`. + */ + case class ParsingWrapper[R](f: WrappingFunction[R, ParsingError, Document, String]) extends Wrapper[R] + + /** + * Wrapper for the query validation stage. + * Takes a function from an `IO[ValidationError, ExecutionRequest]` and a `Document` and that returns a + * `ZIO[R, ValidationError, ExecutionRequest]`. + */ + case class ValidationWrapper[R](f: WrappingFunction[R, ValidationError, ExecutionRequest, Document]) + extends Wrapper[R] + + /** + * Wrapper for the query execution stage. + * Takes a function from a `UIO[GraphQLResponse[CalibanError]]` and an `ExecutionRequest` and that returns a + * `URIO[R, GraphQLResponse[CalibanError]]`. + */ + case class ExecutionWrapper[R](f: WrappingFunction[R, Nothing, GraphQLResponse[CalibanError], ExecutionRequest]) + extends Wrapper[R] + + /** + * Wrapper for each individual field. + * Takes a function from a `ZQuery[Any, Nothing, ResponseValue]` and a `FieldInfo` and that returns a + * `ZQuery[R, CalibanError, ResponseValue]`. + * If `wrapPureValues` is true, every single field will be wrapped, which could have an impact on performances. + * If false, simple pure values will be ignored. + */ + case class FieldWrapper[R]( + f: (ZQuery[R, Nothing, ResponseValue], FieldInfo) => ZQuery[R, CalibanError, ResponseValue], + wrapPureValues: Boolean = false + ) extends Wrapper[R] + + /** + * Wrapper that combines multiple wrappers. + * @param wrappers a list of wrappers + */ + case class CombinedWrapper[-R](wrappers: List[Wrapper[R]]) extends Wrapper[R] + + /** + * A wrapper that requires an effect to be built. The effect will be run for each query. + * @param wrapper an effect that builds a wrapper + */ + case class EffectfulWrapper[-R](wrapper: UIO[Wrapper[R]]) extends Wrapper[R] + + private[caliban] def wrap[R1 >: R, R, E, A, Info]( + zio: ZIO[R1, E, A] + )(wrappers: List[WrappingFunction[R, E, A, Info]], info: Info): ZIO[R, E, A] = { + @tailrec + def loop(zio: ZIO[R, E, A], wrappers: List[WrappingFunction[R, E, A, Info]]): ZIO[R, E, A] = + wrappers match { + case Nil => zio + case wrapper :: tail => loop(wrapper(zio, info), tail) + } + loop(zio, wrappers) + } + + private[caliban] def decompose[R](wrappers: List[Wrapper[R]]): UIO[ + ( + List[WrappingFunction[R, Nothing, GraphQLResponse[CalibanError], String]], + List[WrappingFunction[R, ParsingError, Document, String]], + List[WrappingFunction[R, ValidationError, ExecutionRequest, Document]], + List[WrappingFunction[R, Nothing, GraphQLResponse[CalibanError], ExecutionRequest]], + List[FieldWrapper[R]] + ) + ] = + ZIO.foldLeft(wrappers)( + ( + List.empty[WrappingFunction[R, Nothing, GraphQLResponse[CalibanError], String]], + List.empty[WrappingFunction[R, ParsingError, Document, String]], + List.empty[WrappingFunction[R, ValidationError, ExecutionRequest, Document]], + List.empty[WrappingFunction[R, Nothing, GraphQLResponse[CalibanError], ExecutionRequest]], + List.empty[FieldWrapper[R]] + ) + ) { + case ((o, p, v, e, f), wrapper: OverallWrapper[R]) => UIO.succeed((wrapper.f :: o, p, v, e, f)) + case ((o, p, v, e, f), wrapper: ParsingWrapper[R]) => UIO.succeed((o, wrapper.f :: p, v, e, f)) + case ((o, p, v, e, f), wrapper: ValidationWrapper[R]) => UIO.succeed((o, p, wrapper.f :: v, e, f)) + case ((o, p, v, e, f), wrapper: ExecutionWrapper[R]) => UIO.succeed((o, p, v, wrapper.f :: e, f)) + case ((o, p, v, e, f), wrapper: FieldWrapper[R]) => UIO.succeed((o, p, v, e, wrapper :: f)) + case ((o, p, v, e, f), CombinedWrapper(wrappers)) => + decompose(wrappers).map { case (o2, p2, v2, e2, f2) => (o2 ++ o, p2 ++ p, v2 ++ v, e2 ++ e, f2 ++ f) } + case ((o, p, v, e, f), EffectfulWrapper(wrapper)) => + wrapper.flatMap( + w => decompose(List(w)).map { case (o2, p2, v2, e2, f2) => (o2 ++ o, p2 ++ p, v2 ++ v, e2 ++ e, f2 ++ f) } + ) + } + +} diff --git a/core/src/main/scala/caliban/wrappers/Wrappers.scala b/core/src/main/scala/caliban/wrappers/Wrappers.scala new file mode 100644 index 000000000..738929817 --- /dev/null +++ b/core/src/main/scala/caliban/wrappers/Wrappers.scala @@ -0,0 +1,94 @@ +package caliban.wrappers + +import caliban.CalibanError.{ ExecutionError, ValidationError } +import caliban.GraphQLResponse +import caliban.Value.NullValue +import caliban.execution.Field +import caliban.wrappers.Wrapper.{ OverallWrapper, ValidationWrapper } +import zio.clock.Clock +import zio.console.{ putStrLn, Console } +import zio.duration.Duration +import zio.{ IO, URIO, ZIO } + +object Wrappers { + + /** + * Returns a wrapper that prints slow queries + * @param duration threshold above which queries are considered slow + */ + def printSlowQueries(duration: Duration): OverallWrapper[Console with Clock] = + onSlowQueries(duration) { case (time, query) => putStrLn(s"Slow query took ${time.render}:\n$query") } + + /** + * Returns a wrapper that runs a given function in case of slow queries + * @param duration threshold above which queries are considered slow + */ + def onSlowQueries[R](duration: Duration)(f: (Duration, String) => URIO[R, Any]): OverallWrapper[R with Clock] = + OverallWrapper { + case (io, query) => + io.timed.flatMap { + case (time, res) => + ZIO.when(time > duration)(f(time, query)).as(res) + } + } + + /** + * Returns a wrapper that times out queries taking more than a specified time. + * @param duration threshold above which queries should be timed out + */ + def timeout(duration: Duration): OverallWrapper[Clock] = + OverallWrapper { + case (io, query) => + io.timeout(duration) + .map( + _.getOrElse( + GraphQLResponse( + NullValue, + List(ExecutionError(s"Query was interrupted after timeout of ${duration.render}:\n$query")) + ) + ) + ) + } + + /** + * Returns a wrapper that checks that the query's depth is under a given max + * @param maxDepth the max allowed depth + */ + def maxDepth(maxDepth: Int): ValidationWrapper[Any] = + ValidationWrapper { + case (io, _) => + io.flatMap { req => + val depth = calculateDepth(req.field) + if (depth > maxDepth) IO.fail(ValidationError(s"Query is too deep: $depth. Max depth: $maxDepth.", "")) + else IO.succeed(req) + } + } + + private def calculateDepth(field: Field): Int = { + val children = field.fields ++ field.conditionalFields.values.flatten + val childrenDepth = if (children.isEmpty) 0 else children.map(calculateDepth).max + childrenDepth + (if (field.name.nonEmpty) 1 else 0) + } + + /** + * Returns a wrapper that checks that the query has a limited number of fields + * @param maxFields the max allowed number of fields + */ + def maxFields(maxFields: Int): ValidationWrapper[Any] = + ValidationWrapper { + case (io, _) => + io.flatMap { req => + val fields = countFields(req.field) + if (fields > maxFields) + IO.fail(ValidationError(s"Query has too many fields: $fields. Max fields: $maxFields.", "")) + else IO.succeed(req) + } + } + + private def countFields(field: Field): Int = + innerFields(field.fields) + (if (field.conditionalFields.isEmpty) 0 + else field.conditionalFields.values.map(innerFields).max) + + private def innerFields(fields: List[Field]): Int = fields.length + fields.map(countFields).sum + +} diff --git a/core/src/main/scala/zquery/ZQuery.scala b/core/src/main/scala/zquery/ZQuery.scala index 6b8692e2c..202b7761e 100644 --- a/core/src/main/scala/zquery/ZQuery.scala +++ b/core/src/main/scala/zquery/ZQuery.scala @@ -286,6 +286,11 @@ object ZQuery { def collectAllPar[R, E, A](as: Iterable[ZQuery[R, E, A]]): ZQuery[R, E, List[A]] = foreachPar(as)(identity) + /** + * Accesses the whole environment of the query. + */ + def environment[R]: ZQuery[R, Nothing, R] = ZQuery.fromEffect(ZIO.environment[R]) + /** * Constructs a query that fails with the specified error. */ diff --git a/core/src/test/scala/caliban/execution/ExecutionSpec.scala b/core/src/test/scala/caliban/execution/ExecutionSpec.scala index 8e3a88cbb..f14571529 100644 --- a/core/src/test/scala/caliban/execution/ExecutionSpec.scala +++ b/core/src/test/scala/caliban/execution/ExecutionSpec.scala @@ -1,12 +1,13 @@ package caliban.execution import java.util.UUID -import caliban.CalibanError.ValidationError +import caliban.CalibanError.ExecutionError import caliban.GraphQL._ import caliban.Macros.gqldoc import caliban.RootResolver import caliban.TestUtils._ import caliban.Value.{ BooleanValue, StringValue } +import zio.IO import zio.test.Assertion._ import zio.test._ @@ -248,52 +249,28 @@ object ExecutionSpec } yield assert(result.errors, equalTo(List("my custom error"))) && assert(result.asJson.noSpaces, equalTo("""{"data":null,"errors":[{"message":"my custom error"}]}""")) }, - testM("QueryAnalyzer > fields") { - case class A(b: B) - case class B(c: Int) - case class Test(a: A) - val interpreter = QueryAnalyzer.maxFields(2)(graphQL(RootResolver(Test(A(B(2)))))).interpreter - val query = gqldoc(""" - { - a { - b { - c - } - } - }""") - assertM( - interpreter.execute(query).map(_.errors), - equalTo(List(ValidationError("Query has too many fields: 3. Max fields: 2.", ""))) - ) - }, - testM("QueryAnalyzer > fields with fragment") { - case class A(b: B) - case class B(c: Int) - case class Test(a: A) - val interpreter = QueryAnalyzer.maxFields(2)(graphQL(RootResolver(Test(A(B(2)))))).interpreter - val query = gqldoc(""" - query test { - a { - ...f - } - } - - fragment f on A { - b { - c - } - } - """) + testM("merge 2 APIs") { + case class Test(name: String) + case class Test2(id: Int) + val api1 = graphQL(RootResolver(Test("name"))) + val api2 = graphQL(RootResolver(Test2(2))) + val interpreter = (api1 |+| api2).interpreter + val query = + """query{ + | name + | id + |}""".stripMargin assertM( - interpreter.execute(query).map(_.errors), - equalTo(List(ValidationError("Query has too many fields: 3. Max fields: 2.", ""))) + interpreter.execute(query).map(_.data.toString), + equalTo("""{"name":"name","id":2}""") ) }, - testM("QueryAnalyzer > depth") { + testM("error path") { case class A(b: B) - case class B(c: Int) + case class B(c: IO[Throwable, Int]) case class Test(a: A) - val interpreter = QueryAnalyzer.maxDepth(2)(graphQL(RootResolver(Test(A(B(2)))))).interpreter + val e = new Exception("boom") + val interpreter = graphQL(RootResolver(Test(A(B(IO.fail(e)))))).interpreter val query = gqldoc(""" { a { @@ -304,23 +281,7 @@ object ExecutionSpec }""") assertM( interpreter.execute(query).map(_.errors), - equalTo(List(ValidationError("Query is too deep: 3. Max depth: 2.", ""))) - ) - }, - testM("merge 2 APIs") { - case class Test(name: String) - case class Test2(id: Int) - val api1 = graphQL(RootResolver(Test("name"))) - val api2 = graphQL(RootResolver(Test2(2))) - val interpreter = (api1 |+| api2).interpreter - val query = - """query{ - | name - | id - |}""".stripMargin - assertM( - interpreter.execute(query).map(_.data.toString), - equalTo("""{"name":"name","id":2}""") + equalTo(List(ExecutionError("Effect failure", List(Left("a"), Left("b"), Left("c")), Some(e)))) ) } ) diff --git a/core/src/test/scala/caliban/wrappers/WrappersSpec.scala b/core/src/test/scala/caliban/wrappers/WrappersSpec.scala new file mode 100644 index 000000000..a48617012 --- /dev/null +++ b/core/src/test/scala/caliban/wrappers/WrappersSpec.scala @@ -0,0 +1,148 @@ +package caliban.wrappers + +import scala.language.postfixOps +import caliban.CalibanError.{ ExecutionError, ValidationError } +import caliban.GraphQL._ +import caliban.Macros.gqldoc +import caliban.schema.GenericSchema +import caliban.wrappers.Wrappers._ +import caliban.{ CalibanError, GraphQLInterpreter, RootResolver } +import zio.clock.Clock +import zio.duration._ +import zio.test.Assertion._ +import zio.test._ +import zio.test.environment.TestClock +import zio.{ clock, Promise, URIO, ZIO } + +object WrappersSpec + extends DefaultRunnableSpec( + suite("WrappersSpec")( + testM("Max fields") { + case class A(b: B) + case class B(c: Int) + case class Test(a: A) + val interpreter = (graphQL(RootResolver(Test(A(B(2))))) @@ maxFields(2)).interpreter + val query = gqldoc(""" + { + a { + b { + c + } + } + }""") + assertM( + interpreter.execute(query).map(_.errors), + equalTo(List(ValidationError("Query has too many fields: 3. Max fields: 2.", ""))) + ) + }, + testM("Max fields with fragment") { + case class A(b: B) + case class B(c: Int) + case class Test(a: A) + val interpreter = (graphQL(RootResolver(Test(A(B(2))))) @@ maxFields(2)).interpreter + val query = gqldoc(""" + query test { + a { + ...f + } + } + + fragment f on A { + b { + c + } + } + """) + assertM( + interpreter.execute(query).map(_.errors), + equalTo(List(ValidationError("Query has too many fields: 3. Max fields: 2.", ""))) + ) + }, + testM("Max depth") { + case class A(b: B) + case class B(c: Int) + case class Test(a: A) + val interpreter = (graphQL(RootResolver(Test(A(B(2))))) @@ maxDepth(2)).interpreter + val query = gqldoc(""" + { + a { + b { + c + } + } + }""") + assertM( + interpreter.execute(query).map(_.errors), + equalTo(List(ValidationError("Query is too deep: 3. Max depth: 2.", ""))) + ) + }, + testM("Timeout") { + case class Test(a: URIO[Clock, Int]) + + object schema extends GenericSchema[Clock] + import schema._ + + val interpreter = + (graphQL(RootResolver(Test(clock.sleep(2 minutes).as(0)))) @@ timeout(1 minute)).interpreter + val query = gqldoc(""" + { + a + }""") + assertM( + TestClock.adjust(1 minute) *> interpreter.execute(query).map(_.errors), + equalTo(List(ExecutionError("""Query was interrupted after timeout of 1 m: + + { + a + }""".stripMargin))) + ) + }, + testM("Apollo Tracing") { + case class Query(hero: Hero) + case class Hero(name: URIO[Clock, String], friends: List[Hero] = Nil) + + object schema extends GenericSchema[Clock] + import schema._ + + def interpreter(latch: Promise[Nothing, Unit]): GraphQLInterpreter[Clock, CalibanError] = + (graphQL( + RootResolver( + Query( + Hero( + latch.succeed(()) *> ZIO.sleep(1 second).as("R2-D2"), + List( + Hero(ZIO.succeed("Luke Skywalker")), + Hero(ZIO.succeed("Han Solo")), + Hero(ZIO.succeed("Leia Organa")) + ) + ) + ) + ) + ) @@ ApolloTracing.apolloTracing).interpreter + + val query = gqldoc(""" + { + hero { + name + friends { + name + } + } + }""") + assertM( + for { + latch <- Promise.make[Nothing, Unit] + fiber <- interpreter(latch).execute(query).map(_.extensions.map(_.toString)).fork + _ <- latch.await + _ <- TestClock.adjust(1 second) + result <- fiber.join + } yield result, + isSome( + equalTo( + """{"tracing":{"version":1,"startTime":"1970-01-01T00:00:00.000Z","endTime":"1970-01-01T00:00:01.000Z","duration":1000000000,"parsing":{"startOffset":0,"duration":0},"validation":{"startOffset":0,"duration":0},"execution":{"resolvers":[{"path":["hero"],"parentType":"Query","fieldName":"hero","returnType":"Hero!","startOffset":0,"duration":1000000000},{"path":["hero","name"],"parentType":"Hero","fieldName":"name","returnType":"String!","startOffset":0,"duration":1000000000},{"path":["hero","friends"],"parentType":"Hero","fieldName":"friends","returnType":"[Hero!]!","startOffset":1000000000,"duration":0},{"path":["hero","friends",2,"name"],"parentType":"Hero","fieldName":"name","returnType":"String!","startOffset":1000000000,"duration":0},{"path":["hero","friends",1,"name"],"parentType":"Hero","fieldName":"name","returnType":"String!","startOffset":1000000000,"duration":0},{"path":["hero","friends",0,"name"],"parentType":"Hero","fieldName":"name","returnType":"String!","startOffset":1000000000,"duration":0}]}}}""" + ) + ) + ) + } + ) + ) diff --git a/examples/src/main/scala/caliban/http4s/ExampleApp.scala b/examples/src/main/scala/caliban/http4s/ExampleApp.scala index 2593881a7..2c5a63572 100644 --- a/examples/src/main/scala/caliban/http4s/ExampleApp.scala +++ b/examples/src/main/scala/caliban/http4s/ExampleApp.scala @@ -1,10 +1,12 @@ package caliban.http4s +import scala.language.postfixOps import caliban.ExampleData._ import caliban.GraphQL._ -import caliban.execution.QueryAnalyzer._ import caliban.schema.Annotations.{ GQLDeprecated, GQLDescription } import caliban.schema.GenericSchema +import caliban.wrappers.ApolloTracing.apolloTracing +import caliban.wrappers.Wrappers._ import caliban.{ ExampleService, GraphQL, Http4sAdapter, RootResolver } import cats.data.Kleisli import cats.effect.Blocker @@ -17,6 +19,7 @@ import zio._ import zio.blocking.Blocking import zio.clock.Clock import zio.console.{ putStrLn, Console } +import zio.duration._ import zio.interop.catz._ import zio.stream.ZStream @@ -39,20 +42,21 @@ object ExampleApp extends CatsApp with GenericSchema[Console with Clock] { implicit val charactersArgsSchema = gen[CharactersArgs] def makeApi(service: ExampleService): GraphQL[Console with Clock] = - maxDepth(30)( - maxFields(200)( - graphQL( - RootResolver( - Queries( - args => service.getCharacters(args.origin), - args => service.findCharacter(args.name) - ), - Mutations(args => service.deleteCharacter(args.name)), - Subscriptions(service.deletedEvents) - ) - ) + graphQL( + RootResolver( + Queries( + args => service.getCharacters(args.origin), + args => service.findCharacter(args.name) + ), + Mutations(args => service.deleteCharacter(args.name)), + Subscriptions(service.deletedEvents) ) - ) + ) @@ + maxFields(200) @@ // query analyzer that limit query fields + maxDepth(30) @@ // query analyzer that limit query depth + timeout(3 seconds) @@ // wrapper that fails slow queries + printSlowQueries(500 millis) @@ // wrapper that logs slow queries + apolloTracing // wrapper for https://github.com/apollographql/apollo-tracing override def run(args: List[String]): ZIO[ZEnv, Nothing, Int] = (for { diff --git a/vuepress/docs/docs/middleware.md b/vuepress/docs/docs/middleware.md index 8736b5ea4..05ce74c2b 100644 --- a/vuepress/docs/docs/middleware.md +++ b/vuepress/docs/docs/middleware.md @@ -1,76 +1,93 @@ # Middleware -You might want to perform some actions on every query received. Caliban supports this in 2 different ways: +Caliban allows you to perform additional actions at various level of a query processing, via the concept of `Wrapper`. Using wrappers, you can: +- verify that a query doesn't reach some limit (e.g. depth, complexity) +- modify a query before it's executed +- add timeouts to queries or fields +- log each field execution time +- support [Apollo Tracing](https://github.com/apollographql/apollo-tracing) or anything similar +- etc. -- you can wrap the execution of each query with any arbitrary code -- you can analyze the requested fields before execution, and possibly modify or reject the query +## Wrapper types -## Wrapping query execution +There are 5 basic types of wrappers: + - `OverallWrapper` to wrap the whole query processing + - `ParsingWrapper` to wrap the query parsing only + - `ValidationWrapper` to wrap the query validation only + - `ExecutionWrapper` to wrap the query execution only + - `FieldWrapper` to wrap each field execution -Once you have a `GraphQL` interpreter, you can call `wrapExecutionWith` on it. This method takes in a function `f` and returns a new `GraphQL` interpreter that will wrap the `execute` method with this function `f`. +Each one requires a function that takes a `ZIO` or `ZQuery` computation together with some contextual information (e.g. the query string) and should return another computation. -It is used internally to implement `mapError` (customize errors) and `provide` (eliminate the environment), but you can use it for other purposes such as adding a general timeout, logging response times, etc. +Let's see how to implement a wrapper that times out the whole query if its processing takes longer that 1 minute. ```scala -// create an interpreter -val i: GraphQLInterpreter[MyEnv, CalibanError] = graphqQL(...).interpreter - -// change error type to String -val i2: GraphQLInterpreter[MyEnv, String] = i.mapError(_.toString) - -// provide the environment -val i3: GraphQLInterpreter[Any, CalibanError] = i.provide(myEnv) - -// add a timeout on every query execution -val i4: GraphQLInterpreter[MyEnv with Clock, CalibanError] = - i.wrapExecutionWith( - _.timeout(30 seconds).map( - _.getOrElse(GraphQLResponse(NullValue, List(ExecutionError("Timeout!")))) - ) - ) +val wrapper = OverallWrapper { + case (io, query) => + io.timeout(1 minute) + .map( + _.getOrElse( + GraphQLResponse( + NullValue, + List(ExecutionError(s"Query was interrupted after timeout of ${duration.render}:\n$query")) + ) + ) + ) +} ``` -## Query Analyzer - -You can also use `GraphQL#withQueryAnalyzer` to register a hook function that will be run before query execution. Such function is called a `QueryAnalyzer` and looks like this: +You can also combine wrappers using `|+|` and create a wrapper that requires an effect to be run at each query using `EffectfulWrapper`. +To use your wrapper, call `GraphQL#withWrapper` or its alias `@@`. ```scala -type QueryAnalyzer[-R] = Field => ZIO[R, CalibanError, Field] +val api = graphQL(...).withWrapper(wrapper) +// or +val api = graphQL(...) @@ wrapper ``` -As an input, you receive the root `Field` object that contains the whole query in a convenient format (the fragments are replaced by the actual fields). You can analyze this object and return a `ZIO[R, CalibanError, Field]` that allows you to: +## Pre-defined wrappers + -- modify the query (e.g. add or remove fields) -- return an error to prevent execution -- run some effect (e.g. write metrics somewhere) +Caliban comes with a few pre-made wrappers in `caliban.wrappers.Wrappers`: +- `maxDepth` returns a wrapper that fails queries whose depth is higher than a given value +- `maxFields` returns a wrapper that fails queries whose number of fields is higher than a given value +- `timeout` returns a wrapper that fails queries taking more than a specified time +- `printSlowQueries` returns a wrapper that prints slow queries +- `onSlowQueries` returns a wrapper that can run a given function on slow queries -A typical use case is to limit the number of fields or the depth of the query. Those are already implemented in Caliban and can be used like this: +In addition to those, `caliban.wrappers.ApolloTracing.apolloTracing` returns a wrapper that adds tracing data into the `extensions` field of each response following [Apollo Tracing](https://github.com/apollographql/apollo-tracing) format. +They can be used like this: ```scala -val interpreter = - maxDepth(30)( - maxFields(200)( - graphQL(...) - ) - ) +val api = + graphQL(...) @@ + maxDepth(50) @@ + timeout(3 seconds) @@ + printSlowQueries(500 millis) @@ + apolloTracing ``` -You can look at their implementation which is only a few lines long. +## Wrapping the interpreter -You can even use a Query Analyzer to "inject" some data into your ZIO environment and possibly use it later during execution. For example, you can have a `Context` as part of your environment: +All the wrappers mentioned above require that you don't modify the environment `R` and the error type which is always a `CalibanError`. It is also possible to wrap your `GraphQLInterpreter` by calling `wrapExecutionWith` on it. This method takes in a function `f` and returns a new `GraphQLInterpreter` that will wrap the `execute` method with this function `f`. + +It is used internally to implement `mapError` (customize errors) and `provide` (eliminate the environment), but you can use it for other purposes such as adding a general timeout, logging response times, etc. ```scala -case class ContextData(cost: Int) -trait Context { - def context: Ref[ContextData] -} -``` +// create an interpreter +val i: GraphQLInterpreter[MyEnv, CalibanError] = graphqQL(...).interpreter -Then you can register a Query Analyzer that calculates the cost of every query (in this example, I just used the number of fields) and save it into the context. that way you can support an API that returns the cost associated with each query. +// change error type to String +val i2: GraphQLInterpreter[MyEnv, String] = i.mapError(_.toString) -```scala -interpreter.withQueryAnalyzer { root => - val cost = QueryAnalyzer.countFields(root) - ZIO.accessM[Context](_.context.update(_.copy(cost = cost))).as(root) -} -``` +// provide the environment +val i3: GraphQLInterpreter[Any, CalibanError] = i.provide(myEnv) + +// add a timeout on every query execution +val i4: GraphQLInterpreter[MyEnv with Clock, CalibanError] = + i.wrapExecutionWith( + _.timeout(30 seconds).map( + _.getOrElse(GraphQLResponse(NullValue, List(ExecutionError("Timeout!")))) + ) + ) +``` \ No newline at end of file