diff --git a/core/src/main/scala/caliban/execution/Executor.scala b/core/src/main/scala/caliban/execution/Executor.scala index ee7754e08c..2eea4267d6 100644 --- a/core/src/main/scala/caliban/execution/Executor.scala +++ b/core/src/main/scala/caliban/execution/Executor.scala @@ -37,26 +37,23 @@ object Executor { )(implicit trace: Trace): URIO[R, GraphQLResponse[CalibanError]] = { val wrapPureValues = fieldWrappers.exists(_.wrapPureValues) val isDeferredEnabled = featureSet(Feature.Defer) + val isMutation = request.operationType == OperationType.Mutation type ExecutionQuery[+A] = ZQuery[R, ExecutionError, A] - val execution = request.operationType match { - case OperationType.Query => queryExecution - case OperationType.Mutation => QueryExecution.Sequential - case OperationType.Subscription => QueryExecution.Sequential - } - def collectAll[E, A, B, Coll[+V] <: Iterable[V]]( - in: Coll[A] + in: Coll[A], + isTopLevelField: Boolean )( as: A => ZQuery[R, E, B] )(implicit bf: BuildFrom[Coll[A], B, Coll[B]]): ZQuery[R, E, Coll[B]] = if (in.sizeCompare(1) == 0) as(in.head).map(bf.newBuilder(in).+=(_).result()) + else if (isTopLevelField && isMutation) ZQuery.foreach(in)(as) else - execution match { - case QueryExecution.Sequential => ZQuery.foreach(in)(as) - case QueryExecution.Parallel => ZQuery.foreachPar(in)(as) + queryExecution match { case QueryExecution.Batched => ZQuery.foreachBatched(in)(as) + case QueryExecution.Parallel => ZQuery.foreachPar(in)(as) + case QueryExecution.Sequential => ZQuery.foreach(in)(as) } def reduceStep( @@ -178,9 +175,12 @@ object Executor { else q.map((name, _)) } - def makeObjectQuery(steps: List[(String, ReducedStep[R], FieldInfo)]) = { + def makeObjectQuery( + steps: List[(String, ReducedStep[R], FieldInfo)], + isTopLevelField: Boolean + ): ExecutionQuery[ResponseValue] = { def collectAllQueries() = - collectAll(steps)((objectFieldQuery _).tupled).map(ObjectValue.apply) + collectAll(steps, isTopLevelField)((objectFieldQuery _).tupled).map(ObjectValue.apply) def collectMixed() = { val resolved = ListBuffer.empty[(String, ResponseValue)] @@ -196,7 +196,7 @@ object Executor { remaining = remaining.tail } - collectAll(queries.result())((objectFieldQuery _).tupled).map { results => + collectAll(queries.result(), isTopLevelField)((objectFieldQuery _).tupled).map { results => var i = -1 ObjectValue(resolved.mapInPlace { case null => i += 1; results(i) @@ -209,15 +209,15 @@ object Executor { else collectMixed() } - def makeListQuery(steps: List[ReducedStep[R]], areItemsNullable: Boolean) = - collectAll(steps)(if (areItemsNullable) loop(_).catchAll(handleError) else loop) + def makeListQuery(steps: List[ReducedStep[R]], areItemsNullable: Boolean): ExecutionQuery[ResponseValue] = + collectAll(steps, isTopLevelField = false)(if (areItemsNullable) loop(_).catchAll(handleError) else loop(_)) .map(ListValue.apply) - def loop(step: ReducedStep[R]): ExecutionQuery[ResponseValue] = + def loop(step: ReducedStep[R], isTopLevelField: Boolean = false): ExecutionQuery[ResponseValue] = step match { case PureStep(value) => ZQuery.succeed(value) - case ReducedStep.QueryStep(step) => step.flatMap(loop) - case ReducedStep.ObjectStep(steps) => makeObjectQuery(steps) + case ReducedStep.QueryStep(step) => step.flatMap(loop(_)) + case ReducedStep.ObjectStep(steps) => makeObjectQuery(steps, isTopLevelField) case ReducedStep.ListStep(steps, areItemsNullable) => makeListQuery(steps, areItemsNullable) case ReducedStep.StreamStep(stream) => ZQuery @@ -235,7 +235,7 @@ object Executor { } ZQuery.fromZIO(deferred.update(deferredSteps ::: _)) *> loop(obj) } - loop(step).catchAll(handleError) + loop(step, isTopLevelField = true).catchAll(handleError) } def runQuery(step: ReducedStep[R], cache: Cache) = diff --git a/core/src/test/scala/caliban/execution/ExecutionSpec.scala b/core/src/test/scala/caliban/execution/ExecutionSpec.scala index 7b069bb998..11ee7c7679 100644 --- a/core/src/test/scala/caliban/execution/ExecutionSpec.scala +++ b/core/src/test/scala/caliban/execution/ExecutionSpec.scala @@ -1,21 +1,22 @@ package caliban.execution -import java.util.UUID import caliban.CalibanError.ExecutionError import caliban.Macros.gqldoc import caliban.TestUtils._ import caliban.Value.{ BooleanValue, IntValue, NullValue, StringValue } +import caliban._ import caliban.introspection.adt.__Type import caliban.parsing.adt.LocationInfo import caliban.schema.Annotations.{ GQLInterface, GQLName, GQLValueType } -import caliban.schema._ -import caliban.schema.Schema.auto._ import caliban.schema.ArgBuilder.auto._ -import caliban._ -import zio.{ FiberRef, IO, Task, UIO, ZIO, ZLayer } +import caliban.schema.Schema.auto._ +import caliban.schema._ +import zio._ import zio.stream.ZStream import zio.test._ +import java.util.UUID + object ExecutionSpec extends ZIOSpecDefault { @GQLInterface @@ -1319,6 +1320,53 @@ object ExecutionSpec extends ZIOSpecDefault { ) ) + }, + test("top-level fields are executed sequentially for mutations") { + case class Foo(field1: UIO[Unit], field2: UIO[Unit]) + case class Mutations( + mutation1: CharacterArgs => UIO[Foo], + mutation2: CharacterArgs => UIO[Foo] + ) + + val ref = Unsafe.unsafe(implicit u => Ref.unsafe.make(List.empty[String])) + def add(name: String, d: Duration = 1.second) = ref.update(name :: _).delay(d) + def foo(prefix: String) = ZIO.succeed(Foo(add(s"$prefix-f1", 1500.millis), add(s"$prefix-f2", 2.seconds))) + + val interpreter = graphQL( + RootResolver( + resolverIO.queryResolver, + Mutations( + _ => add("m1") *> foo("m1"), + _ => add("m2") *> foo("m2") + ) + ) + ).interpreter + + def adjustAndGet(d: Duration = 1.second) = TestClock.adjust(d) *> ref.get + + for { + i <- interpreter + _ <- i.execute(gqldoc("""mutation { + mutation1(name: "foo") { field1 field2 } + mutation2(name: "bar") { field1 field2 } + }""")) + .fork + r1 <- ref.get + r2 <- adjustAndGet() + r3 <- adjustAndGet() + r4 <- adjustAndGet() + r5 <- adjustAndGet() + r6 <- adjustAndGet() + r7 <- adjustAndGet() + } yield assertTrue( + r1 == Nil, + r2 == List("m1"), + r3 == List("m1"), + r4 == List("m1-f2", "m1-f1", "m1"), + r5 == List("m2", "m1-f2", "m1-f1", "m1"), + r6 == List("m2", "m1-f2", "m1-f1", "m1"), + r7 == List("m2-f2", "m2-f1", "m2", "m1-f2", "m1-f1", "m1") + ) } ) }