Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow parallel execution of mutation non-top level fields #2040

Merged
merged 2 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 24 additions & 18 deletions core/src/main/scala/caliban/execution/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,26 +37,29 @@ object Executor {
)(implicit trace: Trace): URIO[R, GraphQLResponse[CalibanError]] = {
val wrapPureValues = fieldWrappers.exists(_.wrapPureValues)
val isDeferredEnabled = featureSet(Feature.Defer)
val isMutation = request.operationType == OperationType.Mutation

type ExecutionQuery[+A] = ZQuery[R, ExecutionError, A]

val execution = request.operationType match {
case OperationType.Query => queryExecution
case OperationType.Mutation => QueryExecution.Sequential
case OperationType.Subscription => QueryExecution.Sequential
val executionMode = queryExecution match {
case QueryExecution.Sequential => 0
case QueryExecution.Parallel => 1
case QueryExecution.Batched => 2
}

def collectAll[E, A, B, Coll[+V] <: Iterable[V]](
in: Coll[A]
in: Coll[A],
isTopLevelField: Boolean
)(
as: A => ZQuery[R, E, B]
)(implicit bf: BuildFrom[Coll[A], B, Coll[B]]): ZQuery[R, E, Coll[B]] =
if (in.sizeCompare(1) == 0) as(in.head).map(bf.newBuilder(in).+=(_).result())
else if (isTopLevelField && isMutation) ZQuery.foreach(in)(as)
else
execution match {
case QueryExecution.Sequential => ZQuery.foreach(in)(as)
case QueryExecution.Parallel => ZQuery.foreachPar(in)(as)
case QueryExecution.Batched => ZQuery.foreachBatched(in)(as)
executionMode match {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does that really make a difference to use ints? Code would be cleaner without 😄

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pattern matching on Ints in increments compiles to a tableswitch, which has O(1) performance regardless of the number of elements in the pattern matching. Plus I do think that equality on ints is less complex than on objects since there's no need for type-checking.

Having said that, we only have 3 elements, and this code is probably a millionth of the overall execution complexity so I'll just go back to making it more readable 😅

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made the change. I placed Sequential last now though since I find it very hard to believe anyone would ever want to use that.. right?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks 😄

case 0 => ZQuery.foreach(in)(as)
case 1 => ZQuery.foreachPar(in)(as)
case 2 => ZQuery.foreachBatched(in)(as)
}

def reduceStep(
Expand Down Expand Up @@ -178,9 +181,12 @@ object Executor {
else q.map((name, _))
}

def makeObjectQuery(steps: List[(String, ReducedStep[R], FieldInfo)]) = {
def makeObjectQuery(
steps: List[(String, ReducedStep[R], FieldInfo)],
isTopLevelField: Boolean
): ExecutionQuery[ResponseValue] = {
def collectAllQueries() =
collectAll(steps)((objectFieldQuery _).tupled).map(ObjectValue.apply)
collectAll(steps, isTopLevelField)((objectFieldQuery _).tupled).map(ObjectValue.apply)

def collectMixed() = {
val resolved = ListBuffer.empty[(String, ResponseValue)]
Expand All @@ -196,7 +202,7 @@ object Executor {
remaining = remaining.tail
}

collectAll(queries.result())((objectFieldQuery _).tupled).map { results =>
collectAll(queries.result(), isTopLevelField)((objectFieldQuery _).tupled).map { results =>
var i = -1
ObjectValue(resolved.mapInPlace {
case null => i += 1; results(i)
Expand All @@ -209,15 +215,15 @@ object Executor {
else collectMixed()
}

def makeListQuery(steps: List[ReducedStep[R]], areItemsNullable: Boolean) =
collectAll(steps)(if (areItemsNullable) loop(_).catchAll(handleError) else loop)
def makeListQuery(steps: List[ReducedStep[R]], areItemsNullable: Boolean): ExecutionQuery[ResponseValue] =
collectAll(steps, isTopLevelField = false)(if (areItemsNullable) loop(_).catchAll(handleError) else loop(_))
.map(ListValue.apply)

def loop(step: ReducedStep[R]): ExecutionQuery[ResponseValue] =
def loop(step: ReducedStep[R], isTopLevelField: Boolean = false): ExecutionQuery[ResponseValue] =
step match {
case PureStep(value) => ZQuery.succeed(value)
case ReducedStep.QueryStep(step) => step.flatMap(loop)
case ReducedStep.ObjectStep(steps) => makeObjectQuery(steps)
case ReducedStep.QueryStep(step) => step.flatMap(loop(_))
case ReducedStep.ObjectStep(steps) => makeObjectQuery(steps, isTopLevelField)
case ReducedStep.ListStep(steps, areItemsNullable) => makeListQuery(steps, areItemsNullable)
case ReducedStep.StreamStep(stream) =>
ZQuery
Expand All @@ -235,7 +241,7 @@ object Executor {
}
ZQuery.fromZIO(deferred.update(deferredSteps ::: _)) *> loop(obj)
}
loop(step).catchAll(handleError)
loop(step, isTopLevelField = true).catchAll(handleError)
}

def runQuery(step: ReducedStep[R], cache: Cache) =
Expand Down
58 changes: 53 additions & 5 deletions core/src/test/scala/caliban/execution/ExecutionSpec.scala
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
package caliban.execution

import java.util.UUID
import caliban.CalibanError.ExecutionError
import caliban.Macros.gqldoc
import caliban.TestUtils._
import caliban.Value.{ BooleanValue, IntValue, NullValue, StringValue }
import caliban._
import caliban.introspection.adt.__Type
import caliban.parsing.adt.LocationInfo
import caliban.schema.Annotations.{ GQLInterface, GQLName, GQLValueType }
import caliban.schema._
import caliban.schema.Schema.auto._
import caliban.schema.ArgBuilder.auto._
import caliban._
import zio.{ FiberRef, IO, Task, UIO, ZIO, ZLayer }
import caliban.schema.Schema.auto._
import caliban.schema._
import zio._
import zio.stream.ZStream
import zio.test._

import java.util.UUID

object ExecutionSpec extends ZIOSpecDefault {

@GQLInterface
Expand Down Expand Up @@ -1319,6 +1320,53 @@ object ExecutionSpec extends ZIOSpecDefault {
)
)

},
test("top-level fields are executed sequentially for mutations") {
case class Foo(field1: UIO[Unit], field2: UIO[Unit])
case class Mutations(
mutation1: CharacterArgs => UIO[Foo],
mutation2: CharacterArgs => UIO[Foo]
)

val ref = Unsafe.unsafe(implicit u => Ref.unsafe.make(List.empty[String]))
def add(name: String, d: Duration = 1.second) = ref.update(name :: _).delay(d)
def foo(prefix: String) = ZIO.succeed(Foo(add(s"$prefix-f1", 1500.millis), add(s"$prefix-f2", 2.seconds)))

val interpreter = graphQL(
RootResolver(
resolverIO.queryResolver,
Mutations(
_ => add("m1") *> foo("m1"),
_ => add("m2") *> foo("m2")
)
)
).interpreter

def adjustAndGet(d: Duration = 1.second) = TestClock.adjust(d) *> ref.get

for {
i <- interpreter
_ <- i.execute(gqldoc("""mutation {
mutation1(name: "foo") { field1 field2 }
mutation2(name: "bar") { field1 field2 }
}"""))
.fork
r1 <- ref.get
r2 <- adjustAndGet()
r3 <- adjustAndGet()
r4 <- adjustAndGet()
r5 <- adjustAndGet()
r6 <- adjustAndGet()
r7 <- adjustAndGet()
} yield assertTrue(
r1 == Nil,
r2 == List("m1"),
r3 == List("m1"),
r4 == List("m1-f2", "m1-f1", "m1"),
r5 == List("m2", "m1-f2", "m1-f1", "m1"),
r6 == List("m2", "m1-f2", "m1-f1", "m1"),
r7 == List("m2-f2", "m2-f1", "m2", "m1-f2", "m1-f1", "m1")
)
}
)
}