Skip to content

Commit

Permalink
Optimize unnecessary .map when handling nullable fields (#2044)
Browse files Browse the repository at this point in the history
  • Loading branch information
kyri-petrou authored Dec 17, 2023
1 parent 6c1dac4 commit 81d9b9d
Showing 1 changed file with 44 additions and 32 deletions.
76 changes: 44 additions & 32 deletions core/src/main/scala/caliban/execution/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import caliban.schema.Step._
import caliban.schema.{ ReducedStep, Step, Types }
import caliban.wrappers.Wrapper.FieldWrapper
import zio._
import zio.query.{ Cache, UQuery, URQuery, ZQuery }
import zio.query._
import zio.stream.ZStream

import scala.annotation.tailrec
Expand Down Expand Up @@ -46,15 +46,16 @@ object Executor {
isTopLevelField: Boolean
)(
as: A => ZQuery[R, E, B]
)(implicit bf: BuildFrom[Coll[A], B, Coll[B]]): ZQuery[R, E, Coll[B]] =
if (in.sizeCompare(1) == 0) as(in.head).map(bf.newBuilder(in).+=(_).result())
else if (isTopLevelField && isMutation) ZQuery.foreach(in)(as)
else
queryExecution match {
case QueryExecution.Batched => ZQuery.foreachBatched(in)(as)
case QueryExecution.Parallel => ZQuery.foreachPar(in)(as)
case QueryExecution.Sequential => ZQuery.foreach(in)(as)
}
)(implicit bf: BuildFrom[Coll[A], B, Coll[B]]): ZQuery[R, E, Coll[B]] = {
val sc = in.sizeCompare(1)
queryExecution match {
case _ if sc == 0 => as(in.head).map(bf.newBuilder(in).+=(_).result())
case _ if isTopLevelField && isMutation => ZQuery.foreach(in)(as)
case QueryExecution.Batched => ZQuery.foreachBatched(in)(as)
case QueryExecution.Parallel => ZQuery.foreachPar(in)(as)
case QueryExecution.Sequential => ZQuery.foreach(in)(as)
}
}

def reduceStep(
step: Step[R],
Expand Down Expand Up @@ -113,26 +114,30 @@ object Executor {
reduceList(lb.result(), Types.listOf(currentField.fieldType).fold(false)(_.isNullable))
}

def reduceQuery(query: ZQuery[R, Throwable, Step[R]]) =
ReducedStep.QueryStep(
query.foldCauseQuery(
e => ZQuery.failCause(effectfulExecutionError(path, Some(currentField.locationInfo), e)),
a => ZQuery.succeed(reduceStep(a, currentField, arguments, path))
)
)

step match {
case s @ PureStep(EnumValue(v)) =>
// special case of an hybrid union containing case objects, those should return an object instead of a string
val obj = currentField.fields.view.filter(_._condition.forall(_.contains(v))).collectFirst {
currentField.fields.view.filter(_._condition.forall(_.contains(v))).collectFirst {
case f if f.name == "__typename" =>
ObjectValue(List(f.aliasedName -> StringValue(v)))
case f if f.name == "_" =>
NullValue
} match {
case Some(v) => PureStep(v)
case None => s
}
obj.fold(s)(PureStep(_))
case s: PureStep => s
case FunctionStep(step) => reduceStep(step(arguments), currentField, Map.empty, path)
case MetadataFunctionStep(step) => reduceStep(step(currentField), currentField, arguments, path)
case QueryStep(inner) =>
ReducedStep.QueryStep(
inner.foldCauseQuery(
e => ZQuery.failCause(effectfulExecutionError(path, Some(currentField.locationInfo), e)),
a => ZQuery.succeed(reduceStep(a, currentField, arguments, path))
)
)
case QueryStep(inner) => reduceQuery(inner)
case ObjectStep(objectName, fields) => reduceObjectStep(objectName, fields)
case ListStep(steps) => reduceListStep(steps)
case StreamStep(stream) =>
Expand Down Expand Up @@ -160,7 +165,7 @@ object Executor {
): URQuery[R, ResponseValue] = {

def handleError(error: ExecutionError): UQuery[ResponseValue] =
ZQuery.fromZIO(errors.update(error :: _)).as(NullValue)
ZQuery.fromZIO(errors.update(error :: _).as(NullValue))

def wrap(query: ExecutionQuery[ResponseValue], isPure: Boolean, fieldInfo: FieldInfo) = {
@tailrec
Expand All @@ -177,14 +182,15 @@ object Executor {
def objectFieldQuery(name: String, step: ReducedStep[R], info: FieldInfo) = {
val q = wrap(loop(step), step.isPure, info)

if (info.details.fieldType.isNullable) q.catchAll(handleError).map((name, _))
else q.map((name, _))
if (info.details.fieldType.isNullable) {
q.foldQuery(
handleError(_).map((name, _)),
v => ZQuery.succeed((name, v))
)
} else q.map((name, _))
}

def makeObjectQuery(
steps: List[(String, ReducedStep[R], FieldInfo)],
isTopLevelField: Boolean
): ExecutionQuery[ResponseValue] = {
def makeObjectQuery(steps: List[(String, ReducedStep[R], FieldInfo)], isTopLevelField: Boolean) = {
def collectAllQueries() =
collectAll(steps, isTopLevelField)((objectFieldQuery _).tupled).map(ObjectValue.apply)

Expand All @@ -202,13 +208,19 @@ object Executor {
remaining = remaining.tail
}

collectAll(queries.result(), isTopLevelField)((objectFieldQuery _).tupled).map { results =>
var i = -1
ObjectValue(resolved.mapInPlace {
case null => i += 1; results(i)
case t => t
}.result())
def combineResults(fromQueries: Vector[(String, ResponseValue)]) = {
val lb = List.newBuilder[(String, ResponseValue)]
var i = -1
val iter = resolved.iterator
while (iter.hasNext)
lb += (iter.next() match {
case null => i += 1; fromQueries(i)
case t => t
})
ObjectValue(lb.result())
}

collectAll(queries.result(), isTopLevelField)((objectFieldQuery _).tupled).map(combineResults)
}

if (wrapPureValues || !steps.exists(_._2.isPure)) collectAllQueries()
Expand Down

0 comments on commit 81d9b9d

Please sign in to comment.