From 078c28de39cc9fd9cdc4eb03a9aed014c63ae4b5 Mon Sep 17 00:00:00 2001
From: kyri-petrou <67301607+kyri-petrou@users.noreply.github.com>
Date: Mon, 11 Dec 2023 08:21:48 +0100
Subject: [PATCH] Object query & Scala 3 optimizations (#2033)

---
 .../caliban/schema/ArgBuilderDerivation.scala | 12 ++-
 .../scala-3/caliban/schema/ObjectSchema.scala |  4 +
 .../scala-3/caliban/schema/SumSchema.scala    |  9 ++-
 .../caliban/schema/ValueTypeSchema.scala      | 10 ++-
 .../scala/caliban/execution/Executor.scala    | 78 +++++++++++--------
 5 files changed, 69 insertions(+), 44 deletions(-)

diff --git a/core/src/main/scala-3/caliban/schema/ArgBuilderDerivation.scala b/core/src/main/scala-3/caliban/schema/ArgBuilderDerivation.scala
index 96693bfb49..bcd943b927 100644
--- a/core/src/main/scala-3/caliban/schema/ArgBuilderDerivation.scala
+++ b/core/src/main/scala-3/caliban/schema/ArgBuilderDerivation.scala
@@ -67,11 +67,10 @@ trait CommonArgBuilderDerivation {
 
   private def makeSumArgBuilder[A](
     _subTypes: => List[(String, List[Any], ArgBuilder[Any])],
-    _traitLabel: => String
+    traitLabel: String
   ) = new ArgBuilder[A] {
-    private lazy val subTypes   = _subTypes
-    private lazy val traitLabel = _traitLabel
-    private val emptyInput      = InputValue.ObjectValue(Map.empty)
+    private lazy val subTypes = _subTypes
+    private val emptyInput    = InputValue.ObjectValue(Map.empty)
 
     def build(input: InputValue): Either[ExecutionError, A] =
       input.match {
@@ -94,10 +93,9 @@ trait CommonArgBuilderDerivation {
 
   private def makeProductArgBuilder[A](
     _fields: => List[(String, ArgBuilder[Any])],
-    _annotations: => Map[String, List[Any]]
+    annotations: Map[String, List[Any]]
   )(fromProduct: Product => A) = new ArgBuilder[A] {
-    private lazy val fields      = _fields
-    private lazy val annotations = _annotations
+    private lazy val fields = _fields
 
     def build(input: InputValue): Either[ExecutionError, A] =
       fields.view.map { (label, builder) =>
diff --git a/core/src/main/scala-3/caliban/schema/ObjectSchema.scala b/core/src/main/scala-3/caliban/schema/ObjectSchema.scala
index f23d249730..7af3383e42 100644
--- a/core/src/main/scala-3/caliban/schema/ObjectSchema.scala
+++ b/core/src/main/scala-3/caliban/schema/ObjectSchema.scala
@@ -4,6 +4,8 @@ import caliban.introspection.adt.__Type
 import caliban.schema.DerivationUtils.*
 import magnolia1.TypeInfo
 
+import scala.annotation.threadUnsafe
+
 final private class ObjectSchema[R, A](
   _fields: => List[(String, Schema[R, Any], Int)],
   info: TypeInfo,
@@ -11,11 +13,13 @@ final private class ObjectSchema[R, A](
   paramAnnotations: Map[String, List[Any]]
 ) extends Schema[R, A] {
 
+  @threadUnsafe
   private lazy val fields = _fields.map { (label, schema, index) =>
     val fieldAnnotations = paramAnnotations.getOrElse(label, Nil)
     (getName(fieldAnnotations, label), fieldAnnotations, schema, index)
   }
 
+  @threadUnsafe
   private lazy val resolver = {
     def fs = fields.map { (name, _, schema, i) =>
       name -> { (v: A) => schema.resolve(v.asInstanceOf[Product].productElement(i)) }
diff --git a/core/src/main/scala-3/caliban/schema/SumSchema.scala b/core/src/main/scala-3/caliban/schema/SumSchema.scala
index 3a45f6ada7..d0ab09a1bc 100644
--- a/core/src/main/scala-3/caliban/schema/SumSchema.scala
+++ b/core/src/main/scala-3/caliban/schema/SumSchema.scala
@@ -6,6 +6,8 @@ import caliban.schema.DerivationUtils.*
 import caliban.schema.Types.makeUnion
 import magnolia1.TypeInfo
 
+import scala.annotation.threadUnsafe
+
 final private class SumSchema[R, A](
   _members: => (List[(String, __Type, List[Any])], List[Schema[R, Any]]),
   info: TypeInfo,
@@ -13,16 +15,20 @@ final private class SumSchema[R, A](
 )(ordinal: A => Int)
     extends Schema[R, A] {
 
+  @threadUnsafe
   private lazy val (subTypes, schemas) = {
     val (m, s) = _members
     (m.sortBy(_._1), s.toVector)
   }
 
+  @threadUnsafe
   private lazy val isEnum = subTypes.forall((_, t, _) => t.allFields.isEmpty && t.allInputFields.isEmpty)
+
   private val isInterface = annotations.exists(_.isInstanceOf[GQLInterface])
   private val isUnion     = annotations.contains(GQLUnion())
 
-  def toType(isInput: Boolean, isSubscription: Boolean): __Type =
+  def toType(isInput: Boolean, isSubscription: Boolean): __Type = {
+    val _ = schemas
     if (!isInterface && !isUnion && subTypes.nonEmpty && isEnum) mkEnum(annotations, info, subTypes)
     else if (!isInterface)
       makeUnion(
@@ -36,6 +42,7 @@ final private class SumSchema[R, A](
       val impl = subTypes.map(_._2.copy(interfaces = () => Some(List(toType(isInput, isSubscription)))))
       mkInterface(annotations, info, impl)
     }
+  }
 
   def resolve(value: A): Step[R] = schemas(ordinal(value)).resolve(value)
 }
diff --git a/core/src/main/scala-3/caliban/schema/ValueTypeSchema.scala b/core/src/main/scala-3/caliban/schema/ValueTypeSchema.scala
index 4bfdf65019..fe5b7d6b6d 100644
--- a/core/src/main/scala-3/caliban/schema/ValueTypeSchema.scala
+++ b/core/src/main/scala-3/caliban/schema/ValueTypeSchema.scala
@@ -6,17 +6,23 @@ import caliban.schema.DerivationUtils.*
 import caliban.schema.Types.makeScalar
 import magnolia1.TypeInfo
 
+import scala.annotation.threadUnsafe
+
 final private class ValueTypeSchema[R, A](
   _schema: => Schema[R, Any],
   info: TypeInfo,
   anns: List[Any]
 ) extends Schema[R, A] {
-  private val name        = getName(anns, info)
+  private val name = getName(anns, info)
+
+  @threadUnsafe
   private lazy val schema = _schema
 
-  def toType(isInput: Boolean, isSubscription: Boolean): __Type =
+  def toType(isInput: Boolean, isSubscription: Boolean): __Type = {
+    val _ = schema
     if (anns.contains(GQLValueType(true))) makeScalar(name, getDescription(anns))
     else schema.toType_(isInput, isSubscription)
+  }
 
   def resolve(value: A): Step[R] = schema.resolve(value.asInstanceOf[Product].productElement(0))
 }
diff --git a/core/src/main/scala/caliban/execution/Executor.scala b/core/src/main/scala/caliban/execution/Executor.scala
index 4856f7a0f5..a92b67a4dc 100644
--- a/core/src/main/scala/caliban/execution/Executor.scala
+++ b/core/src/main/scala/caliban/execution/Executor.scala
@@ -15,6 +15,8 @@ import zio.query.{ Cache, UQuery, URQuery, ZQuery }
 import zio.stream.ZStream
 
 import scala.annotation.tailrec
+import scala.collection.compat.{ BuildFrom => _, _ }
+import scala.collection.mutable.ListBuffer
 import scala.jdk.CollectionConverters._
 
 object Executor {
@@ -44,13 +46,18 @@
         case OperationType.Subscription => QueryExecution.Sequential
       }
 
-    def collectAll[In, E, A](in: List[In])(as: In => ZQuery[R, E, A]): ZQuery[R, E, List[A]] =
-      (in, execution) match {
-        case (head :: Nil, _)               => as(head).map(List(_))
-        case (_, QueryExecution.Sequential) => ZQuery.foreach(in)(as)
-        case (_, QueryExecution.Parallel)   => ZQuery.foreachPar(in)(as)
-        case (_, QueryExecution.Batched)    => ZQuery.foreachBatched(in)(as)
-      }
+    def collectAll[E, A, B, Coll[+V] <: Iterable[V]](
+      in: Coll[A]
+    )(
+      as: A => ZQuery[R, E, B]
+    )(implicit bf: BuildFrom[Coll[A], B, Coll[B]]): ZQuery[R, E, Coll[B]] =
+      if (in.sizeCompare(1) == 0) as(in.head).map(bf.newBuilder(in).+=(_).result())
+      else
+        execution match {
+          case QueryExecution.Sequential => ZQuery.foreach(in)(as)
+          case QueryExecution.Parallel   => ZQuery.foreachPar(in)(as)
+          case QueryExecution.Batched    => ZQuery.foreachBatched(in)(as)
+        }
 
     def reduceStep(
       step: Step[R],
@@ -95,7 +102,7 @@
           var i         = 0
           val lb        = List.newBuilder[ReducedStep[R]]
           var remaining = steps
-          while (!remaining.isEmpty) {
+          while (remaining ne Nil) {
             lb += reduceStep(remaining.head, currentField, arguments, Right(i) :: path)
             i += 1
             remaining = remaining.tail
@@ -172,36 +179,34 @@
     def makeObjectQuery(steps: List[(String, ReducedStep[R], FieldInfo)]) = {
-      def newMap() = new java.util.HashMap[String, ResponseValue](calculateMapCapacity(steps.size))
-
-      var pures: java.util.HashMap[String, ResponseValue] = null
-      val _steps =
-        if (wrapPureValues) steps
-        else {
-          val queries   = List.newBuilder[(String, ReducedStep[R], FieldInfo)]
-          var remaining = steps
-          while (!remaining.isEmpty) {
-            remaining.head match {
-              case (name, PureStep(value), _) =>
-                if (pures eq null) pures = newMap()
-                pures.putIfAbsent(name, value)
-              case step                       => queries += step
-            }
-            remaining = remaining.tail
+      def collectAllQueries() =
+        collectAll(steps)((objectFieldQuery _).tupled).map(ObjectValue.apply)
+
+      def collectMixed() = {
+        val resolved  = ListBuffer.empty[(String, ResponseValue)]
+        val queries   = Vector.newBuilder[(String, ReducedStep[R], FieldInfo)]
+        var remaining = steps
+        while (remaining ne Nil) {
+          remaining.head match {
+            case (name, PureStep(value), _) => resolved += ((name, value))
+            case step                       =>
+              resolved += null
+              queries += step
           }
-          queries.result()
+          remaining = remaining.tail
         }
-      // Avoids placing of var into Function1 which will convert it to ObjectRef by the Scala compiler
-      val resolved = pures
-      collectAll(_steps)((objectFieldQuery _).tupled).map { results =>
-        if (resolved eq null) ObjectValue(results)
-        else {
-          results.foreach(kv => resolved.put(kv._1, kv._2))
-          ObjectValue(steps.map { case (name, _, _) => name -> resolved.get(name) })
+        collectAll(queries.result())((objectFieldQuery _).tupled).map { results =>
+          var i = -1
+          ObjectValue(resolved.mapInPlace {
+            case null => i += 1; results(i)
+            case t    => t
+          }.result())
         }
       }
+      if (wrapPureValues || !steps.exists(_._2.isPure)) collectAllQueries()
+      else collectMixed()
     }
 
     def makeListQuery(steps: List[ReducedStep[R]], areItemsNullable: Boolean) =
@@ -310,7 +315,7 @@
     def haveSameCondition(head: Field, tail: List[Field]): Boolean = {
       val condition = head._condition
       var remaining = tail
-      while (!remaining.isEmpty) {
+      while (remaining ne Nil) {
        if (remaining.head._condition != condition) return false
        remaining = remaining.tail
      }
@@ -323,7 +328,7 @@
     def mergeFields(fields: List[Field]) = {
       val map       = new java.util.LinkedHashMap[String, Field](calculateMapCapacity(fields.size))
       var remaining = fields
-      while (!remaining.isEmpty) {
+      while (remaining ne Nil) {
        val h = remaining.head
        if (matchesTypename(h)) {
          map.compute(
@@ -388,6 +393,11 @@
     }
   }
 
+  private implicit class EnrichedListBufferOps[A](val lb: ListBuffer[A]) extends AnyVal {
+    // This method doesn't exist in Scala 2.12 so we just use `.map` for it instead
+    def mapInPlace[B](f: A => B): ListBuffer[B] = lb.map(f)
+  }
+
   /**
    * The behaviour of mutable Maps (both Java and Scala) is to resize once the number of entries exceeds
   * the capacity * loadFactor (default of 0.75d) threshold in order to prevent hash collisions.
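
For reference, a minimal Scala sketch of the capacity calculation that the doc comment above motivates. The body of `calculateMapCapacity` is not included in this excerpt, so the formula below is only an assumption derived from the stated default 0.75 load factor, not necessarily the exact implementation:

    // Assumed sketch, not taken from the patch: to hold `nMappings` entries without
    // a resize, a map needs an initial capacity of at least ceil(nMappings / loadFactor).
    private def calculateMapCapacity(nMappings: Int): Int =
      Math.ceil(nMappings / 0.75d).toInt

    // Example: calculateMapCapacity(12) == 16; a java.util.HashMap created with capacity 16
    // and the default 0.75 load factor (threshold 16 * 0.75 = 12) holds 12 entries without rehashing.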