Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow parallel execution of mutation non-top level fields #2040

Merged
merged 2 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions core/src/main/scala/caliban/execution/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,26 +37,23 @@ object Executor {
)(implicit trace: Trace): URIO[R, GraphQLResponse[CalibanError]] = {
val wrapPureValues = fieldWrappers.exists(_.wrapPureValues)
val isDeferredEnabled = featureSet(Feature.Defer)
val isMutation = request.operationType == OperationType.Mutation

type ExecutionQuery[+A] = ZQuery[R, ExecutionError, A]

val execution = request.operationType match {
case OperationType.Query => queryExecution
case OperationType.Mutation => QueryExecution.Sequential
case OperationType.Subscription => QueryExecution.Sequential
}

def collectAll[E, A, B, Coll[+V] <: Iterable[V]](
in: Coll[A]
in: Coll[A],
isTopLevelField: Boolean
)(
as: A => ZQuery[R, E, B]
)(implicit bf: BuildFrom[Coll[A], B, Coll[B]]): ZQuery[R, E, Coll[B]] =
if (in.sizeCompare(1) == 0) as(in.head).map(bf.newBuilder(in).+=(_).result())
else if (isTopLevelField && isMutation) ZQuery.foreach(in)(as)
else
execution match {
case QueryExecution.Sequential => ZQuery.foreach(in)(as)
case QueryExecution.Parallel => ZQuery.foreachPar(in)(as)
queryExecution match {
case QueryExecution.Batched => ZQuery.foreachBatched(in)(as)
case QueryExecution.Parallel => ZQuery.foreachPar(in)(as)
case QueryExecution.Sequential => ZQuery.foreach(in)(as)
}

def reduceStep(
Expand Down Expand Up @@ -178,9 +175,12 @@ object Executor {
else q.map((name, _))
}

def makeObjectQuery(steps: List[(String, ReducedStep[R], FieldInfo)]) = {
def makeObjectQuery(
steps: List[(String, ReducedStep[R], FieldInfo)],
isTopLevelField: Boolean
): ExecutionQuery[ResponseValue] = {
def collectAllQueries() =
collectAll(steps)((objectFieldQuery _).tupled).map(ObjectValue.apply)
collectAll(steps, isTopLevelField)((objectFieldQuery _).tupled).map(ObjectValue.apply)

def collectMixed() = {
val resolved = ListBuffer.empty[(String, ResponseValue)]
Expand All @@ -196,7 +196,7 @@ object Executor {
remaining = remaining.tail
}

collectAll(queries.result())((objectFieldQuery _).tupled).map { results =>
collectAll(queries.result(), isTopLevelField)((objectFieldQuery _).tupled).map { results =>
var i = -1
ObjectValue(resolved.mapInPlace {
case null => i += 1; results(i)
Expand All @@ -209,15 +209,15 @@ object Executor {
else collectMixed()
}

def makeListQuery(steps: List[ReducedStep[R]], areItemsNullable: Boolean) =
collectAll(steps)(if (areItemsNullable) loop(_).catchAll(handleError) else loop)
def makeListQuery(steps: List[ReducedStep[R]], areItemsNullable: Boolean): ExecutionQuery[ResponseValue] =
collectAll(steps, isTopLevelField = false)(if (areItemsNullable) loop(_).catchAll(handleError) else loop(_))
.map(ListValue.apply)

def loop(step: ReducedStep[R]): ExecutionQuery[ResponseValue] =
def loop(step: ReducedStep[R], isTopLevelField: Boolean = false): ExecutionQuery[ResponseValue] =
step match {
case PureStep(value) => ZQuery.succeed(value)
case ReducedStep.QueryStep(step) => step.flatMap(loop)
case ReducedStep.ObjectStep(steps) => makeObjectQuery(steps)
case ReducedStep.QueryStep(step) => step.flatMap(loop(_))
case ReducedStep.ObjectStep(steps) => makeObjectQuery(steps, isTopLevelField)
case ReducedStep.ListStep(steps, areItemsNullable) => makeListQuery(steps, areItemsNullable)
case ReducedStep.StreamStep(stream) =>
ZQuery
Expand All @@ -235,7 +235,7 @@ object Executor {
}
ZQuery.fromZIO(deferred.update(deferredSteps ::: _)) *> loop(obj)
}
loop(step).catchAll(handleError)
loop(step, isTopLevelField = true).catchAll(handleError)
}

def runQuery(step: ReducedStep[R], cache: Cache) =
Expand Down
58 changes: 53 additions & 5 deletions core/src/test/scala/caliban/execution/ExecutionSpec.scala
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
package caliban.execution

import java.util.UUID
import caliban.CalibanError.ExecutionError
import caliban.Macros.gqldoc
import caliban.TestUtils._
import caliban.Value.{ BooleanValue, IntValue, NullValue, StringValue }
import caliban._
import caliban.introspection.adt.__Type
import caliban.parsing.adt.LocationInfo
import caliban.schema.Annotations.{ GQLInterface, GQLName, GQLValueType }
import caliban.schema._
import caliban.schema.Schema.auto._
import caliban.schema.ArgBuilder.auto._
import caliban._
import zio.{ FiberRef, IO, Task, UIO, ZIO, ZLayer }
import caliban.schema.Schema.auto._
import caliban.schema._
import zio._
import zio.stream.ZStream
import zio.test._

import java.util.UUID

object ExecutionSpec extends ZIOSpecDefault {

@GQLInterface
Expand Down Expand Up @@ -1319,6 +1320,53 @@ object ExecutionSpec extends ZIOSpecDefault {
)
)

},
test("top-level fields are executed sequentially for mutations") {
case class Foo(field1: UIO[Unit], field2: UIO[Unit])
case class Mutations(
mutation1: CharacterArgs => UIO[Foo],
mutation2: CharacterArgs => UIO[Foo]
)

val ref = Unsafe.unsafe(implicit u => Ref.unsafe.make(List.empty[String]))
def add(name: String, d: Duration = 1.second) = ref.update(name :: _).delay(d)
def foo(prefix: String) = ZIO.succeed(Foo(add(s"$prefix-f1", 1500.millis), add(s"$prefix-f2", 2.seconds)))

val interpreter = graphQL(
RootResolver(
resolverIO.queryResolver,
Mutations(
_ => add("m1") *> foo("m1"),
_ => add("m2") *> foo("m2")
)
)
).interpreter

def adjustAndGet(d: Duration = 1.second) = TestClock.adjust(d) *> ref.get

for {
i <- interpreter
_ <- i.execute(gqldoc("""mutation {
mutation1(name: "foo") { field1 field2 }
mutation2(name: "bar") { field1 field2 }
}"""))
.fork
r1 <- ref.get
r2 <- adjustAndGet()
r3 <- adjustAndGet()
r4 <- adjustAndGet()
r5 <- adjustAndGet()
r6 <- adjustAndGet()
r7 <- adjustAndGet()
} yield assertTrue(
r1 == Nil,
r2 == List("m1"),
r3 == List("m1"),
r4 == List("m1-f2", "m1-f1", "m1"),
r5 == List("m2", "m1-f2", "m1-f1", "m1"),
r6 == List("m2", "m1-f2", "m1-f1", "m1"),
r7 == List("m2-f2", "m2-f1", "m2", "m1-f2", "m1-f1", "m1")
)
}
)
}