forked from ghostdogpr/caliban
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add cost estimation wrapper (ghostdogpr#1180)
* Add cost estimation wrapper
- Loading branch information
1 parent
5fb931f
commit 637ed04
Showing
4 changed files
with
503 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
284 changes: 284 additions & 0 deletions
284
core/src/main/scala/caliban/wrappers/CostEstimation.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,284 @@ | ||
package caliban.wrappers | ||
|
||
import caliban.CalibanError.ValidationError | ||
import caliban.InputValue.ListValue | ||
import caliban.ResponseValue.ObjectValue | ||
import caliban.Value.{ FloatValue, IntValue, StringValue } | ||
import caliban.execution.{ ExecutionRequest, Field } | ||
import caliban.parsing.adt.{ Directive, Document } | ||
import caliban.schema.Annotations.GQLDirective | ||
import caliban.schema.Types | ||
import caliban.wrappers.Wrapper.{ EffectfulWrapper, OverallWrapper, ValidationWrapper } | ||
import caliban.{ CalibanError, GraphQLRequest, GraphQLResponse, ResponseValue } | ||
import zio.{ IO, Ref, UIO, URIO, ZIO } | ||
|
||
import scala.annotation.tailrec | ||
|
||
object CostEstimation { | ||
|
||
final val COST_DIRECTIVE_NAME = "cost" | ||
final val COST_EXTENSION_NAME = "queryCost" | ||
|
||
case class GQLCost(cost: Double, multipliers: List[String] = Nil) | ||
extends GQLDirective(CostDirective(cost, multipliers)) | ||
|
||
/** | ||
* A directive that can be applied to both fields and types to flag them as targets for cost analysis. | ||
* This allows a simple estimation to be applied which can be used to prevent overly expensive queries from being executed | ||
*/ | ||
object CostDirective { | ||
def apply(cost: Double): Directive = | ||
Directive(COST_DIRECTIVE_NAME, arguments = Map("weight" -> FloatValue(cost))) | ||
|
||
def apply(cost: Double, multipliers: List[String]): Directive = { | ||
val multi = multipliers match { | ||
case Nil => Map.empty | ||
case _ => Map("multipliers" -> ListValue(multipliers.map(StringValue.apply))) | ||
} | ||
Directive( | ||
COST_DIRECTIVE_NAME, | ||
arguments = Map("weight" -> FloatValue(cost)) ++ multi | ||
) | ||
} | ||
} | ||
|
||
/** | ||
* Computes field cost by examining the @cost directive. This can be used in conjunction with `queryCost` or [[maxCost]] | ||
* In order to compute the estimated cost of executing a query. | ||
* | ||
* @note This will be executed *before* the actual resolvers are called, which allows you to stop potentially expensive queries | ||
* from being run, but may also require more work on the developer part to determine the correct heuristic for estimating field cost. | ||
*/ | ||
val costDirective: Field => Double = (f: Field) => { | ||
def computeDirectiveCost(directives: List[Directive]) = | ||
directives.collectFirst { | ||
case d if d.name == COST_DIRECTIVE_NAME => | ||
val weight = d.arguments | ||
.get("weight") | ||
.collect { | ||
case i: IntValue => i.toInt.toDouble | ||
case f: FloatValue => f.toDouble | ||
} | ||
.getOrElse(1.0) | ||
|
||
val multipliers = d.arguments | ||
.get("multipliers") | ||
.toList | ||
.collect { case ListValue(values) => | ||
values.collect { | ||
case StringValue(name) => | ||
f.arguments.get(name).collectFirst { | ||
case i: IntValue => i.toInt.toDouble | ||
case d: FloatValue => d.toDouble | ||
} | ||
case _ => None | ||
}.flatten | ||
} | ||
.flatten | ||
.sum[Double] | ||
.max(1.0) | ||
|
||
weight * multipliers | ||
} | ||
|
||
val directiveCost = computeDirectiveCost(f.directives) | ||
|
||
val typeCost = Types | ||
.innerType(f.fieldType) | ||
.directives | ||
.flatMap(computeDirectiveCost) | ||
|
||
(directiveCost orElse typeCost).getOrElse(1.0) | ||
} | ||
|
||
/** | ||
* Computes the estimated cost of the query based on the default cost estimation (backed by the @GQLCost directive) and adds it as an extension to the GraphQLResponse. | ||
* This is useful for tracking the overall cost of a query either when you are trying to dial in a correct heuristic or when you | ||
* want to inform users of your graph how expensive their queries are. | ||
* | ||
* @see queryCostWith, queryCostZIO | ||
*/ | ||
lazy val queryCost: Wrapper[Any] = queryCost(costDirective) | ||
|
||
/** | ||
* Computes the estimated cost of the query based on the provided field cost function and adds it as an extension to the GraphQLResponse. | ||
* This is useful for tracking the overall cost of a query either when you are trying to dial in a correct heuristic or when you | ||
* want to inform users of your graph how expensive their queries are. | ||
* | ||
* @see queryCostWith, queryCostZIO | ||
*/ | ||
def queryCost(f: Field => Double): Wrapper[Any] = | ||
queryCostZIOWrapperState(costWrapper(_)(f))(addCostToExtensions) | ||
|
||
/** | ||
* Computes the estimated cost of the query based on the provided field cost function and passes it to the second function | ||
* which can run an arbitrary side effect with the result. | ||
* | ||
* @see queryCost | ||
*/ | ||
def queryCostWith[R](f: Field => Double)(p: Double => URIO[R, Any]): Wrapper[R] = | ||
queryCostZIOWrapperState[R](costWrapper(_)(f))((cost, r) => p(cost).as(r)) | ||
|
||
/** | ||
* A more powerful version of `queryCost` which allows the field cost computation to return an effect instead of a plain value | ||
* when computing the cost estimate. | ||
* @param f The field cost estimate function | ||
* | ||
* @see queryCostZIOWith for a more powerful version | ||
*/ | ||
def queryCostZIO[R](f: Field => URIO[R, Double]): Wrapper[R] = | ||
queryCostZIOWrapperState(costWrapperZIO(_)(f))(addCostToExtensions) | ||
|
||
/** | ||
* A more powerful version of [[queryCostZIO]] that allows the total result of the query to be pushed into a separate effect. | ||
* This is useful when you want to compute the cost of the query but you already have your own system for recording the cost. | ||
* @param f The field cost estimate function | ||
* @param p A function which receives the total estimated cost of executing the query and can | ||
*/ | ||
def queryCostZIOWith[R](f: Field => URIO[R, Double])(p: Double => URIO[R, Any]): Wrapper[R] = | ||
queryCostZIOWrapperState(costWrapperZIO(_)(f))((cost, r) => p(cost) as r) | ||
|
||
/** | ||
* The most powerful version of `queryCost` that allows maximum freedom in how the wrapper is defined. The function accepts | ||
* a function that will take a Ref that can be used for book keeping to track the current cost of the query, and returns a wrapper. | ||
* Normally this wrapper is a validation wrapper because it should be run before execution for the purposes of cost estimation. | ||
* However, the API provides the flexibility to redefine that. The second parameter of the function will receive the final cost estimate of the query | ||
* as well as the current graphql response, it returns an effect that produces a potentially modified response. | ||
* | ||
* @param costWrapper User provided function that will result in a wrapper that may be used when computing the cost of a query | ||
* @param p The response transforming function which receives the cost estimate as well as the response value. | ||
* | ||
* @see queryCostZIOWith, queryCostZIO, queryCost | ||
*/ | ||
def queryCostZIOWrapperState[R]( | ||
costWrapper: Ref[Double] => Wrapper[R] | ||
)( | ||
p: (Double, GraphQLResponse[CalibanError]) => URIO[R, GraphQLResponse[CalibanError]] | ||
): Wrapper[R] = EffectfulWrapper( | ||
Ref.make(0.0).map { cost => | ||
costWrapper(cost) |+| costOverall(resp => cost.get.flatMap(p(_, resp))) | ||
} | ||
) | ||
|
||
/** | ||
* Computes the estimated cost of executing the query using the provided function and compares it to | ||
* the `maxCost` parameter which determines the maximum allowable cost for a query. | ||
* | ||
* @param maxCost The maximum allowable cost for executing a query | ||
* @param f The field cost estimate function | ||
*/ | ||
def maxCost(maxCost: Double)(f: Field => Double): ValidationWrapper[Any] = | ||
maxCostOrError(maxCost)(f)(cost => ValidationError(s"Query costs too much: $cost. Max cost: $maxCost.", "")) | ||
|
||
/** | ||
* More powerful version of [[maxCost]] which allow you to also specify the error that is returned when the cost exceeds the maximum cost. | ||
* @param maxCost The total cost allowed for any one query | ||
* @param f The function used to evaluate the cost of a single field | ||
* @param error A function that will be provided the total estimated cost and must return the error that will be returned | ||
*/ | ||
def maxCostOrError(maxCost: Double)(f: Field => Double)(error: Double => ValidationError): ValidationWrapper[Any] = | ||
new ValidationWrapper[Any] { | ||
override def wrap[R1 <: Any]( | ||
process: Document => ZIO[R1, ValidationError, ExecutionRequest] | ||
): Document => ZIO[R1, ValidationError, ExecutionRequest] = | ||
(doc: Document) => | ||
for { | ||
req <- process(doc) | ||
cost = computeCost(req.field)(f) | ||
_ <- IO.when(cost > maxCost)(IO.fail(error(cost))) | ||
} yield req | ||
} | ||
|
||
/** | ||
* More powerful version of [[maxCost]] which allows the field computation function to specify a function which returns an effectful computation | ||
* for the cost of a field. | ||
* @param maxCost The total cost allowed for any one query | ||
* @param f The function used to evaluate the cost of a single field returning an effect which will result in the field cost as a double | ||
*/ | ||
def maxCostZIO[R](maxCost: Double)(f: Field => URIO[R, Double]): ValidationWrapper[R] = | ||
new ValidationWrapper[R] { | ||
override def wrap[R1 <: R]( | ||
process: Document => ZIO[R1, ValidationError, ExecutionRequest] | ||
): Document => ZIO[R1, ValidationError, ExecutionRequest] = | ||
(doc: Document) => | ||
for { | ||
req <- process(doc) | ||
cost <- computeCostZIO(req.field)(f) | ||
_ <- IO.when(cost > maxCost)( | ||
IO.fail(ValidationError(s"Query costs too much: $cost. Max cost: $maxCost.", "")) | ||
) | ||
} yield req | ||
} | ||
|
||
private def costWrapper(total: Ref[Double])(f: Field => Double): ValidationWrapper[Any] = | ||
new ValidationWrapper[Any] { | ||
override def wrap[R1 <: Any]( | ||
process: Document => ZIO[R1, ValidationError, ExecutionRequest] | ||
): Document => ZIO[R1, ValidationError, ExecutionRequest] = | ||
(doc: Document) => | ||
for { | ||
req <- process(doc) | ||
_ <- total.set(computeCost(req.field)(f)) | ||
} yield req | ||
} | ||
|
||
private def costWrapperZIO[R](total: Ref[Double])(f: Field => URIO[R, Double]): ValidationWrapper[R] = | ||
new ValidationWrapper[R] { | ||
override def wrap[R1 <: R]( | ||
process: Document => ZIO[R1, ValidationError, ExecutionRequest] | ||
): Document => ZIO[R1, ValidationError, ExecutionRequest] = | ||
(doc: Document) => | ||
for { | ||
req <- process(doc) | ||
_ <- computeCostZIO(req.field)(f) >>= total.set | ||
} yield req | ||
} | ||
|
||
private def addCostToExtensions( | ||
cost: Double, | ||
resp: GraphQLResponse[CalibanError] | ||
): UIO[GraphQLResponse[CalibanError]] = | ||
UIO.succeed( | ||
resp.copy( | ||
extensions = Some( | ||
ObjectValue( | ||
resp.extensions.foldLeft[List[(String, ResponseValue)]]( | ||
List(COST_EXTENSION_NAME -> FloatValue(cost)) | ||
)(_ ++ _.fields) | ||
) | ||
) | ||
) | ||
) | ||
|
||
private def costOverall[R]( | ||
f: GraphQLResponse[CalibanError] => URIO[R, GraphQLResponse[CalibanError]] | ||
): OverallWrapper[R] = | ||
new OverallWrapper[R] { | ||
override def wrap[R1 <: R]( | ||
process: GraphQLRequest => ZIO[R1, Nothing, GraphQLResponse[CalibanError]] | ||
): GraphQLRequest => ZIO[R1, Nothing, GraphQLResponse[CalibanError]] = | ||
(req: GraphQLRequest) => process(req).flatMap(f) | ||
} | ||
|
||
private def computeCost(field: Field)(f: Field => Double): Double = { | ||
@tailrec | ||
def go(fields: List[Field], total: Double): Double = fields match { | ||
case Nil => total | ||
case head :: tail => go(head.fields ++ tail, f(head) + total) | ||
} | ||
|
||
go(List(field), 0.0) | ||
} | ||
|
||
private def computeCostZIO[R](field: Field)(f: Field => URIO[R, Double]): URIO[R, Double] = { | ||
@tailrec | ||
def go(fields: List[Field], result: List[URIO[R, Double]]): List[URIO[R, Double]] = fields match { | ||
case Nil => result | ||
case head :: tail => | ||
go(head.fields ++ tail, f(head) :: result) | ||
} | ||
|
||
URIO.mergeAllPar(go(List(field), Nil))(0.0)(_ + _) | ||
} | ||
|
||
} |
Oops, something went wrong.