Skip to content

Commit

Permalink
Fix incorrect wrapping of pure values by some FieldWrappers (#2162)
Browse files Browse the repository at this point in the history
  • Loading branch information
kyri-petrou authored Mar 15, 2024
1 parent d231389 commit 47e92f9
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 33 deletions.
50 changes: 26 additions & 24 deletions core/src/main/scala/caliban/execution/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -159,30 +159,33 @@ object Executor {
} else (Nil, filteredFields.map(reduceField))
}

val eagerReduced = reduceObject(eager, wrapPureValues)
val eagerReduced = reduceObject(eager)
deferred match {
case Nil => eagerReduced
case d =>
DeferStep(
eagerReduced,
d.groupBy(_._1).toList.map { case (label, labelAndFields) =>
val (_, fields) = labelAndFields.unzip
reduceObject(fields, wrapPureValues) -> label
reduceObject(fields) -> label
},
path
)
}
}

def reduceListStep(steps: List[Step[R]]): ReducedStep[R] = {
def reduceMixed(head: ReducedStep[R], remaining0: List[Step[R]]): ReducedStep[R] = {

def reduceToListStep(head: ReducedStep[R], remaining0: List[Step[R]]): ReducedStep[R] = {
var i = 1
val nil = Nil
val lb = ListBuffer.empty[ReducedStep[R]]
var isPure = wrapPureValues && head.isPure
lb addOne head
var remaining = remaining0
while (remaining ne nil) {
val step = reduceStep(remaining.head, currentField, arguments, PathValue.Index(i) :: path)
if (isPure && !step.isPure) isPure = false
lb addOne step
i += 1
remaining = remaining.tail
Expand All @@ -192,11 +195,12 @@ object Executor {
Types.listOf(currentField.fieldType) match {
case Some(tpe) => tpe.isNullable
case None => false
}
},
isPure
)
}

def reducePures(head: PureStep, remaining0: List[Step[R]]): ReducedStep[R] = {
def reduceToPureStep(head: PureStep, remaining0: List[Step[R]]): ReducedStep[R] = {
var i = 1
val nil = Nil
val lb = ListBuffer.empty[ResponseValue]
Expand All @@ -213,13 +217,14 @@ object Executor {

if (steps.isEmpty) PureStep(ListValue(Nil))
else {
val step = reduceStep(steps.head, currentField, arguments, PathValue.Index(0) :: path)
if (step.isPure) {
reduceStep(steps.head, currentField, arguments, PathValue.Index(0) :: path) match {
// In 99.99% of the cases, if the head is pure, all the other elements will be pure as well but we catch that error just in case
// NOTE: Our entire test suite passes without catching the error
try reducePures(step.asInstanceOf[PureStep], steps.tail)
catch { case _: ClassCastException => reduceMixed(step, steps.tail) }
} else reduceMixed(step, steps.tail)
case step: PureStep =>
try reduceToPureStep(step, steps.tail)
catch { case _: ClassCastException => reduceToListStep(step, steps.tail) }
case step => reduceToListStep(step, steps.tail)
}
}
}

Expand Down Expand Up @@ -322,23 +327,20 @@ object Executor {
): FieldInfo =
FieldInfo(aliasedName, field, path, fieldDirectives, field.parentType)

private def reduceObject(
items: List[(String, ReducedStep[R], FieldInfo)],
wrapPureValues: Boolean
): ReducedStep[R] = {
private def reduceObject(items: List[(String, ReducedStep[R], FieldInfo)]): ReducedStep[R] = {
var hasPures = false
var hasQueries = false
val nil = Nil
var hasQueries = wrapPureValues
var remaining = items
while ((remaining ne nil) && !(hasPures && hasQueries)) {
val isPure = remaining.head._2.isPure
if (!hasPures && isPure) hasPures = true
else if (!hasQueries && !isPure) hasQueries = true
if (isPure && !hasPures) hasPures = true
else if (!isPure && !hasQueries) hasQueries = true
else ()
remaining = remaining.tail
}

if (hasQueries) ReducedStep.ObjectStep(items, hasPures)
if (hasQueries || wrapPureValues) ReducedStep.ObjectStep(items, hasPures, !hasQueries)
else
PureStep(
ObjectValue(items.asInstanceOf[List[(String, PureStep, FieldInfo)]].map { case (k, v, _) => (k, v.value) })
Expand Down Expand Up @@ -475,11 +477,11 @@ object Executor {

def loop(step: ReducedStep[R], isTopLevelField: Boolean = false): ExecutionQuery[ResponseValue] =
step match {
case PureStep(value) => ZQuery.succeed(value)
case ReducedStep.QueryStep(step) => step.flatMap(loop(_))
case ReducedStep.ObjectStep(steps, hasPureFields) => makeObjectQuery(steps, hasPureFields, isTopLevelField)
case ReducedStep.ListStep(steps, areItemsNullable) => makeListQuery(steps, areItemsNullable)
case ReducedStep.StreamStep(stream) =>
case PureStep(value) => ZQuery.succeed(value)
case ReducedStep.QueryStep(step) => step.flatMap(loop(_))
case ReducedStep.ObjectStep(steps, hasPureFields, _) => makeObjectQuery(steps, hasPureFields, isTopLevelField)
case ReducedStep.ListStep(steps, areItemsNullable, _) => makeListQuery(steps, areItemsNullable)
case ReducedStep.StreamStep(stream) =>
ZQuery
.environmentWith[R](env =>
ResponseValue.StreamValue(
Expand All @@ -488,7 +490,7 @@ object Executor {
}.provideEnvironment(env)
)
)
case ReducedStep.DeferStep(obj, nextSteps, path) =>
case ReducedStep.DeferStep(obj, nextSteps, path) =>
val deferredSteps = nextSteps.map { case (step, label) =>
Deferred(path, step, label)
}
Expand Down
37 changes: 28 additions & 9 deletions core/src/main/scala/caliban/schema/Step.scala
Original file line number Diff line number Diff line change
Expand Up @@ -78,20 +78,37 @@ object Step {
}

sealed abstract class ReducedStep[-R] { self =>
final def isPure: Boolean = self.isInstanceOf[PureStep]
def isPure: Boolean
}

object ReducedStep {
case class ListStep[-R](steps: List[ReducedStep[R]], areItemsNullable: Boolean) extends ReducedStep[R]
case class ObjectStep[-R](fields: List[(String, ReducedStep[R], FieldInfo)], hasPureFields: Boolean)
extends ReducedStep[R]
case class QueryStep[-R](query: ZQuery[R, ExecutionError, ReducedStep[R]]) extends ReducedStep[R]
case class StreamStep[-R](inner: ZStream[R, ExecutionError, ReducedStep[R]]) extends ReducedStep[R]
case class DeferStep[-R](
final case class ListStep[-R](
steps: List[ReducedStep[R]],
areItemsNullable: Boolean,
isPure: Boolean
) extends ReducedStep[R]

final case class ObjectStep[-R](
fields: List[(String, ReducedStep[R], FieldInfo)],
hasPureFields: Boolean,
isPure: Boolean
) extends ReducedStep[R]

final case class QueryStep[-R](query: ZQuery[R, ExecutionError, ReducedStep[R]]) extends ReducedStep[R] {
final val isPure = false
}

final case class StreamStep[-R](inner: ZStream[R, ExecutionError, ReducedStep[R]]) extends ReducedStep[R] {
final val isPure = false
}

final case class DeferStep[-R](
obj: ReducedStep[R],
deferred: List[(ReducedStep[R], Option[String])],
path: List[PathValue]
) extends ReducedStep[R]
) extends ReducedStep[R] {
final val isPure = false
}

// PureStep is both a Step and a ReducedStep so it is defined outside this object
// This is to avoid boxing/unboxing pure values during step reduction
Expand All @@ -105,4 +122,6 @@ object ReducedStep {
*
* @param value the response value to return for that step
*/
case class PureStep(value: ResponseValue) extends ReducedStep[Any] with Step[Any]
final case class PureStep(value: ResponseValue) extends ReducedStep[Any] with Step[Any] {
final val isPure = true
}
30 changes: 30 additions & 0 deletions core/src/test/scala/caliban/wrappers/WrappersSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,36 @@ object WrappersSpec extends ZIOSpecDefault {
counter <- ref.get
} yield assertTrue(counter == 2)
},
// i2161
test("wrapPureValues true and false") {
case class Obj1(a1: List[Obj2])
case class Obj2(a2: Int)
case class Test(a0: Obj1, b: UIO[Int])
for {
ref1 <- Ref.make[Int](0)
wrapper1 = new FieldWrapper[Any](true) {
def wrap[R1 <: Any](
query: ZQuery[R1, ExecutionError, ResponseValue],
info: FieldInfo
): ZQuery[R1, ExecutionError, ResponseValue] =
ZQuery.fromZIO(ref1.update(_ + 1)) *> query
}
ref2 <- Ref.make[Int](0)
wrapper2 = new FieldWrapper[Any](false) {
def wrap[R1 <: Any](
query: ZQuery[R1, ExecutionError, ResponseValue],
info: FieldInfo
): ZQuery[R1, ExecutionError, ResponseValue] =
ZQuery.fromZIO(ref2.update(_ + 1)) *> query
}
interpreter <-
(graphQL(RootResolver(Test(Obj1(List(Obj2(1))), ZIO.succeed(2)))) @@ wrapper1 @@ wrapper2).interpreter.orDie
query = gqldoc("""{ a0 { a1 { a2 } } b }""")
_ <- interpreter.execute(query)
counter1 <- ref1.get
counter2 <- ref2.get
} yield assertTrue(counter1 == 4, counter2 == 1)
},
test("Max fields") {
case class A(b: B)
case class B(c: Int)
Expand Down

0 comments on commit 47e92f9

Please sign in to comment.