From 47e92f9fedaddfbaf5e30c73e15f314d240a9ab2 Mon Sep 17 00:00:00 2001 From: kyri-petrou <67301607+kyri-petrou@users.noreply.github.com> Date: Fri, 15 Mar 2024 23:42:36 +0800 Subject: [PATCH] Fix incorrect wrapping of pure values by some `FieldWrapper`s (#2162) --- .../scala/caliban/execution/Executor.scala | 50 ++++++++++--------- core/src/main/scala/caliban/schema/Step.scala | 37 ++++++++++---- .../scala/caliban/wrappers/WrappersSpec.scala | 30 +++++++++++ 3 files changed, 84 insertions(+), 33 deletions(-) diff --git a/core/src/main/scala/caliban/execution/Executor.scala b/core/src/main/scala/caliban/execution/Executor.scala index b820253d6..51c8b2952 100644 --- a/core/src/main/scala/caliban/execution/Executor.scala +++ b/core/src/main/scala/caliban/execution/Executor.scala @@ -159,7 +159,7 @@ object Executor { } else (Nil, filteredFields.map(reduceField)) } - val eagerReduced = reduceObject(eager, wrapPureValues) + val eagerReduced = reduceObject(eager) deferred match { case Nil => eagerReduced case d => @@ -167,7 +167,7 @@ object Executor { eagerReduced, d.groupBy(_._1).toList.map { case (label, labelAndFields) => val (_, fields) = labelAndFields.unzip - reduceObject(fields, wrapPureValues) -> label + reduceObject(fields) -> label }, path ) @@ -175,14 +175,17 @@ object Executor { } def reduceListStep(steps: List[Step[R]]): ReducedStep[R] = { - def reduceMixed(head: ReducedStep[R], remaining0: List[Step[R]]): ReducedStep[R] = { + + def reduceToListStep(head: ReducedStep[R], remaining0: List[Step[R]]): ReducedStep[R] = { var i = 1 val nil = Nil val lb = ListBuffer.empty[ReducedStep[R]] + var isPure = wrapPureValues && head.isPure lb addOne head var remaining = remaining0 while (remaining ne nil) { val step = reduceStep(remaining.head, currentField, arguments, PathValue.Index(i) :: path) + if (isPure && !step.isPure) isPure = false lb addOne step i += 1 remaining = remaining.tail @@ -192,11 +195,12 @@ object Executor { Types.listOf(currentField.fieldType) match { case Some(tpe) => tpe.isNullable case None => false - } + }, + isPure ) } - def reducePures(head: PureStep, remaining0: List[Step[R]]): ReducedStep[R] = { + def reduceToPureStep(head: PureStep, remaining0: List[Step[R]]): ReducedStep[R] = { var i = 1 val nil = Nil val lb = ListBuffer.empty[ResponseValue] @@ -213,13 +217,14 @@ object Executor { if (steps.isEmpty) PureStep(ListValue(Nil)) else { - val step = reduceStep(steps.head, currentField, arguments, PathValue.Index(0) :: path) - if (step.isPure) { + reduceStep(steps.head, currentField, arguments, PathValue.Index(0) :: path) match { // In 99.99% of the cases, if the head is pure, all the other elements will be pure as well but we catch that error just in case // NOTE: Our entire test suite passes without catching the error - try reducePures(step.asInstanceOf[PureStep], steps.tail) - catch { case _: ClassCastException => reduceMixed(step, steps.tail) } - } else reduceMixed(step, steps.tail) + case step: PureStep => + try reduceToPureStep(step, steps.tail) + catch { case _: ClassCastException => reduceToListStep(step, steps.tail) } + case step => reduceToListStep(step, steps.tail) + } } } @@ -322,23 +327,20 @@ object Executor { ): FieldInfo = FieldInfo(aliasedName, field, path, fieldDirectives, field.parentType) - private def reduceObject( - items: List[(String, ReducedStep[R], FieldInfo)], - wrapPureValues: Boolean - ): ReducedStep[R] = { + private def reduceObject(items: List[(String, ReducedStep[R], FieldInfo)]): ReducedStep[R] = { var hasPures = false + var hasQueries = false val nil = Nil - var hasQueries = wrapPureValues var remaining = items while ((remaining ne nil) && !(hasPures && hasQueries)) { val isPure = remaining.head._2.isPure - if (!hasPures && isPure) hasPures = true - else if (!hasQueries && !isPure) hasQueries = true + if (isPure && !hasPures) hasPures = true + else if (!isPure && !hasQueries) hasQueries = true else () remaining = remaining.tail } - if (hasQueries) ReducedStep.ObjectStep(items, hasPures) + if (hasQueries || wrapPureValues) ReducedStep.ObjectStep(items, hasPures, !hasQueries) else PureStep( ObjectValue(items.asInstanceOf[List[(String, PureStep, FieldInfo)]].map { case (k, v, _) => (k, v.value) }) @@ -475,11 +477,11 @@ object Executor { def loop(step: ReducedStep[R], isTopLevelField: Boolean = false): ExecutionQuery[ResponseValue] = step match { - case PureStep(value) => ZQuery.succeed(value) - case ReducedStep.QueryStep(step) => step.flatMap(loop(_)) - case ReducedStep.ObjectStep(steps, hasPureFields) => makeObjectQuery(steps, hasPureFields, isTopLevelField) - case ReducedStep.ListStep(steps, areItemsNullable) => makeListQuery(steps, areItemsNullable) - case ReducedStep.StreamStep(stream) => + case PureStep(value) => ZQuery.succeed(value) + case ReducedStep.QueryStep(step) => step.flatMap(loop(_)) + case ReducedStep.ObjectStep(steps, hasPureFields, _) => makeObjectQuery(steps, hasPureFields, isTopLevelField) + case ReducedStep.ListStep(steps, areItemsNullable, _) => makeListQuery(steps, areItemsNullable) + case ReducedStep.StreamStep(stream) => ZQuery .environmentWith[R](env => ResponseValue.StreamValue( @@ -488,7 +490,7 @@ object Executor { }.provideEnvironment(env) ) ) - case ReducedStep.DeferStep(obj, nextSteps, path) => + case ReducedStep.DeferStep(obj, nextSteps, path) => val deferredSteps = nextSteps.map { case (step, label) => Deferred(path, step, label) } diff --git a/core/src/main/scala/caliban/schema/Step.scala b/core/src/main/scala/caliban/schema/Step.scala index 01a56e578..1e4412ddc 100644 --- a/core/src/main/scala/caliban/schema/Step.scala +++ b/core/src/main/scala/caliban/schema/Step.scala @@ -78,20 +78,37 @@ object Step { } sealed abstract class ReducedStep[-R] { self => - final def isPure: Boolean = self.isInstanceOf[PureStep] + def isPure: Boolean } object ReducedStep { - case class ListStep[-R](steps: List[ReducedStep[R]], areItemsNullable: Boolean) extends ReducedStep[R] - case class ObjectStep[-R](fields: List[(String, ReducedStep[R], FieldInfo)], hasPureFields: Boolean) - extends ReducedStep[R] - case class QueryStep[-R](query: ZQuery[R, ExecutionError, ReducedStep[R]]) extends ReducedStep[R] - case class StreamStep[-R](inner: ZStream[R, ExecutionError, ReducedStep[R]]) extends ReducedStep[R] - case class DeferStep[-R]( + final case class ListStep[-R]( + steps: List[ReducedStep[R]], + areItemsNullable: Boolean, + isPure: Boolean + ) extends ReducedStep[R] + + final case class ObjectStep[-R]( + fields: List[(String, ReducedStep[R], FieldInfo)], + hasPureFields: Boolean, + isPure: Boolean + ) extends ReducedStep[R] + + final case class QueryStep[-R](query: ZQuery[R, ExecutionError, ReducedStep[R]]) extends ReducedStep[R] { + final val isPure = false + } + + final case class StreamStep[-R](inner: ZStream[R, ExecutionError, ReducedStep[R]]) extends ReducedStep[R] { + final val isPure = false + } + + final case class DeferStep[-R]( obj: ReducedStep[R], deferred: List[(ReducedStep[R], Option[String])], path: List[PathValue] - ) extends ReducedStep[R] + ) extends ReducedStep[R] { + final val isPure = false + } // PureStep is both a Step and a ReducedStep so it is defined outside this object // This is to avoid boxing/unboxing pure values during step reduction @@ -105,4 +122,6 @@ object ReducedStep { * * @param value the response value to return for that step */ -case class PureStep(value: ResponseValue) extends ReducedStep[Any] with Step[Any] +final case class PureStep(value: ResponseValue) extends ReducedStep[Any] with Step[Any] { + final val isPure = true +} diff --git a/core/src/test/scala/caliban/wrappers/WrappersSpec.scala b/core/src/test/scala/caliban/wrappers/WrappersSpec.scala index 004e6eb11..1e555af5b 100644 --- a/core/src/test/scala/caliban/wrappers/WrappersSpec.scala +++ b/core/src/test/scala/caliban/wrappers/WrappersSpec.scala @@ -64,6 +64,36 @@ object WrappersSpec extends ZIOSpecDefault { counter <- ref.get } yield assertTrue(counter == 2) }, + // i2161 + test("wrapPureValues true and false") { + case class Obj1(a1: List[Obj2]) + case class Obj2(a2: Int) + case class Test(a0: Obj1, b: UIO[Int]) + for { + ref1 <- Ref.make[Int](0) + wrapper1 = new FieldWrapper[Any](true) { + def wrap[R1 <: Any]( + query: ZQuery[R1, ExecutionError, ResponseValue], + info: FieldInfo + ): ZQuery[R1, ExecutionError, ResponseValue] = + ZQuery.fromZIO(ref1.update(_ + 1)) *> query + } + ref2 <- Ref.make[Int](0) + wrapper2 = new FieldWrapper[Any](false) { + def wrap[R1 <: Any]( + query: ZQuery[R1, ExecutionError, ResponseValue], + info: FieldInfo + ): ZQuery[R1, ExecutionError, ResponseValue] = + ZQuery.fromZIO(ref2.update(_ + 1)) *> query + } + interpreter <- + (graphQL(RootResolver(Test(Obj1(List(Obj2(1))), ZIO.succeed(2)))) @@ wrapper1 @@ wrapper2).interpreter.orDie + query = gqldoc("""{ a0 { a1 { a2 } } b }""") + _ <- interpreter.execute(query) + counter1 <- ref1.get + counter2 <- ref2.get + } yield assertTrue(counter1 == 4, counter2 == 1) + }, test("Max fields") { case class A(b: B) case class B(c: Int)