From 98e7f418f857547cfefe8cc648f46d75b43bfa85 Mon Sep 17 00:00:00 2001 From: Kyri Petrou Date: Sat, 16 Dec 2023 23:16:24 +1100 Subject: [PATCH] Optimize unnecessary .map when handling nullable fields --- .../scala/caliban/execution/Executor.scala | 76 +++++++++++-------- 1 file changed, 44 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/caliban/execution/Executor.scala b/core/src/main/scala/caliban/execution/Executor.scala index 2739756e41..1311fde84a 100644 --- a/core/src/main/scala/caliban/execution/Executor.scala +++ b/core/src/main/scala/caliban/execution/Executor.scala @@ -11,7 +11,7 @@ import caliban.schema.Step._ import caliban.schema.{ ReducedStep, Step, Types } import caliban.wrappers.Wrapper.FieldWrapper import zio._ -import zio.query.{ Cache, UQuery, URQuery, ZQuery } +import zio.query._ import zio.stream.ZStream import scala.annotation.tailrec @@ -46,15 +46,16 @@ object Executor { isTopLevelField: Boolean )( as: A => ZQuery[R, E, B] - )(implicit bf: BuildFrom[Coll[A], B, Coll[B]]): ZQuery[R, E, Coll[B]] = - if (in.sizeCompare(1) == 0) as(in.head).map(bf.newBuilder(in).+=(_).result()) - else if (isTopLevelField && isMutation) ZQuery.foreach(in)(as) - else - queryExecution match { - case QueryExecution.Batched => ZQuery.foreachBatched(in)(as) - case QueryExecution.Parallel => ZQuery.foreachPar(in)(as) - case QueryExecution.Sequential => ZQuery.foreach(in)(as) - } + )(implicit bf: BuildFrom[Coll[A], B, Coll[B]]): ZQuery[R, E, Coll[B]] = { + val sc = in.sizeCompare(1) + queryExecution match { + case _ if sc == 0 => as(in.head).map(bf.newBuilder(in).+=(_).result()) + case _ if isTopLevelField && isMutation => ZQuery.foreach(in)(as) + case QueryExecution.Batched => ZQuery.foreachBatched(in)(as) + case QueryExecution.Parallel => ZQuery.foreachPar(in)(as) + case QueryExecution.Sequential => ZQuery.foreach(in)(as) + } + } def reduceStep( step: Step[R], @@ -113,26 +114,30 @@ object Executor { reduceList(lb.result(), Types.listOf(currentField.fieldType).fold(false)(_.isNullable)) } + def reduceQuery(query: ZQuery[R, Throwable, Step[R]]) = + ReducedStep.QueryStep( + query.foldCauseQuery( + e => ZQuery.failCause(effectfulExecutionError(path, Some(currentField.locationInfo), e)), + a => ZQuery.succeed(reduceStep(a, currentField, arguments, path)) + ) + ) + step match { case s @ PureStep(EnumValue(v)) => // special case of an hybrid union containing case objects, those should return an object instead of a string - val obj = currentField.fields.view.filter(_._condition.forall(_.contains(v))).collectFirst { + currentField.fields.view.filter(_._condition.forall(_.contains(v))).collectFirst { case f if f.name == "__typename" => ObjectValue(List(f.aliasedName -> StringValue(v))) case f if f.name == "_" => NullValue + } match { + case Some(v) => PureStep(v) + case None => s } - obj.fold(s)(PureStep(_)) case s: PureStep => s case FunctionStep(step) => reduceStep(step(arguments), currentField, Map.empty, path) case MetadataFunctionStep(step) => reduceStep(step(currentField), currentField, arguments, path) - case QueryStep(inner) => - ReducedStep.QueryStep( - inner.foldCauseQuery( - e => ZQuery.failCause(effectfulExecutionError(path, Some(currentField.locationInfo), e)), - a => ZQuery.succeed(reduceStep(a, currentField, arguments, path)) - ) - ) + case QueryStep(inner) => reduceQuery(inner) case ObjectStep(objectName, fields) => reduceObjectStep(objectName, fields) case ListStep(steps) => reduceListStep(steps) case StreamStep(stream) => @@ -160,7 +165,7 @@ object Executor { ): URQuery[R, ResponseValue] = { def handleError(error: ExecutionError): UQuery[ResponseValue] = - ZQuery.fromZIO(errors.update(error :: _)).as(NullValue) + ZQuery.fromZIO(errors.update(error :: _).as(NullValue)) def wrap(query: ExecutionQuery[ResponseValue], isPure: Boolean, fieldInfo: FieldInfo) = { @tailrec @@ -177,14 +182,15 @@ object Executor { def objectFieldQuery(name: String, step: ReducedStep[R], info: FieldInfo) = { val q = wrap(loop(step), step.isPure, info) - if (info.details.fieldType.isNullable) q.catchAll(handleError).map((name, _)) - else q.map((name, _)) + if (info.details.fieldType.isNullable) { + q.foldQuery( + handleError(_).map((name, _)), + v => ZQuery.succeed((name, v)) + ) + } else q.map((name, _)) } - def makeObjectQuery( - steps: List[(String, ReducedStep[R], FieldInfo)], - isTopLevelField: Boolean - ): ExecutionQuery[ResponseValue] = { + def makeObjectQuery(steps: List[(String, ReducedStep[R], FieldInfo)], isTopLevelField: Boolean) = { def collectAllQueries() = collectAll(steps, isTopLevelField)((objectFieldQuery _).tupled).map(ObjectValue.apply) @@ -202,13 +208,19 @@ object Executor { remaining = remaining.tail } - collectAll(queries.result(), isTopLevelField)((objectFieldQuery _).tupled).map { results => - var i = -1 - ObjectValue(resolved.mapInPlace { - case null => i += 1; results(i) - case t => t - }.result()) + def combineResults(fromQueries: Vector[(String, ResponseValue)]) = { + val lb = List.newBuilder[(String, ResponseValue)] + var i = -1 + val iter = resolved.iterator + while (iter.hasNext) + lb += (iter.next() match { + case null => i += 1; fromQueries(i) + case t => t + }) + ObjectValue(lb.result()) } + + collectAll(queries.result(), isTopLevelField)((objectFieldQuery _).tupled).map(combineResults) } if (wrapPureValues || !steps.exists(_._2.isPure)) collectAllQueries()