Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize Scala3's codegen #1830

Merged
merged 2 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 58 additions & 49 deletions core/src/main/scala-3/caliban/schema/ArgBuilderDerivation.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package caliban.schema

import caliban.CalibanError.ExecutionError
import caliban.InputValue
import caliban.{ CalibanError, InputValue }
import caliban.Value.*
import caliban.schema.macros.Macros
import caliban.schema.Annotations.GQLDefault
Expand Down Expand Up @@ -30,60 +30,69 @@ trait CommonArgBuilderDerivation {
inline def derived[A]: ArgBuilder[A] =
inline summonInline[Mirror.Of[A]] match {
case m: Mirror.SumOf[A] =>
lazy val subTypes = recurse[m.MirroredElemLabels, m.MirroredElemTypes]()
lazy val traitLabel = constValue[m.MirroredLabel]
new ArgBuilder[A] {
Copy link
Collaborator Author

@kyri-petrou kyri-petrou Aug 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: Never call new ... on traits / abstract classes inside an inline method 😄

def build(input: InputValue): Either[ExecutionError, A] =
buildSum[A](subTypes, traitLabel)(input)
}
makeSumArgBuilder[A](
recurse[m.MirroredElemLabels, m.MirroredElemTypes](),
constValue[m.MirroredLabel]
)

case m: Mirror.ProductOf[A] =>
lazy val fields = recurse[m.MirroredElemLabels, m.MirroredElemTypes]()
lazy val annotations = Macros.paramAnnotations[A].to(Map)
new ArgBuilder[A] {
def build(input: InputValue): Either[ExecutionError, A] =
buildProduct(fields, annotations)(input).map(m.fromProduct)
}
makeProductArgBuilder(
recurse[m.MirroredElemLabels, m.MirroredElemTypes](),
Macros.paramAnnotations[A].to(Map)
)(m.fromProduct)
}

private def buildSum[A](
subTypes: => List[(String, List[Any], ArgBuilder[Any])],
traitLabel: => String
)(input: InputValue) =
(input match {
case EnumValue(value) => Some(value)
case StringValue(value) => Some(value)
case _ => None
}) match {
case Some(value) =>
subTypes.find { (label, annotations, _) =>
label == value || annotations.exists { case GQLName(name) => name == value }
} match {
case Some((_, _, builder)) => builder.asInstanceOf[ArgBuilder[A]].build(InputValue.ObjectValue(Map()))
case None => Left(ExecutionError(s"Invalid value $value for trait $traitLabel"))
}
case None => Left(ExecutionError(s"Can't build a trait from input $input"))
}
private def makeSumArgBuilder[A](
_subTypes: => List[(String, List[Any], ArgBuilder[Any])],
_traitLabel: => String
) = new ArgBuilder[A] {
private lazy val subTypes = _subTypes
private lazy val traitLabel = _traitLabel
private val emptyInput = InputValue.ObjectValue(Map())

private def buildProduct(
fields: => List[(String, List[Any], ArgBuilder[Any])],
annotations: => Map[String, List[Any]]
)(input: InputValue) =
fields.map { (label, _, builder) =>
input match {
case InputValue.ObjectValue(fields) =>
val finalLabel =
annotations.getOrElse(label, Nil).collectFirst { case GQLName(name) => name }.getOrElse(label)
val default = annotations.getOrElse(label, Nil).collectFirst { case GQLDefault(v) => v }
fields.get(finalLabel).fold(builder.buildMissing(default))(builder.build)
case value => builder.build(value)
}
}.foldRight[Either[ExecutionError, Tuple]](Right(EmptyTuple)) { case (item, acc) =>
item match {
case error: Left[ExecutionError, Any] => error.asInstanceOf[Left[ExecutionError, Tuple]]
case Right(value) => acc.map(value *: _)
def build(input: InputValue): Either[ExecutionError, A] =
input.match {
case EnumValue(value) => Right(value)
case StringValue(value) => Right(value)
case _ => Left(ExecutionError(s"Can't build a trait from input $input"))
}.flatMap { value =>
subTypes.collectFirst {
case (
label,
annotations,
builder: ArgBuilder[A @unchecked]
) if label == value || annotations.exists { case GQLName(name) => name == value } =>
builder
}
.toRight(ExecutionError(s"Invalid value $value for trait $traitLabel"))
.flatMap(_.build(emptyInput))
}
}
}

private def makeProductArgBuilder[A](
_fields: => List[(String, List[Any], ArgBuilder[Any])],
_annotations: => Map[String, List[Any]]
)(fromProduct: Product => A) = new ArgBuilder[A] {
private lazy val fields = _fields
private lazy val annotations = _annotations

def build(input: InputValue): Either[ExecutionError, A] =
fields.view.map { (label, _, builder) =>
input match {
case InputValue.ObjectValue(fields) =>
val labelList = annotations.get(label)
def default = labelList.flatMap(_.collectFirst { case GQLDefault(v) => v })
val finalLabel = labelList.flatMap(_.collectFirst { case GQLName(name) => name }).getOrElse(label)
fields.get(finalLabel).fold(builder.buildMissing(default))(builder.build)
case value => builder.build(value)
}
}.foldLeft[Either[ExecutionError, Tuple]](Right(EmptyTuple)) { case (acc, item) =>
item match {
case Right(value) => acc.map(_ :* value)
case Left(e) => Left(e)
}
}.map(fromProduct)
}
}

trait ArgBuilderDerivation extends CommonArgBuilderDerivation {
Expand Down
33 changes: 19 additions & 14 deletions core/src/main/scala-3/caliban/schema/SchemaDerivation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import caliban.schema.macros.{ Macros, TypeInfo }
import scala.compiletime.*
import scala.deriving.Mirror
import scala.util.NotGiven
import scala.quoted.*

object PrintDerived {
import scala.quoted.*
Expand Down Expand Up @@ -58,24 +57,28 @@ trait CommonSchemaDerivation {
inline def derived[R, A]: Schema[R, A] =
inline summonInline[Mirror.Of[A]] match {
case m: Mirror.SumOf[A] =>
lazy val members = recurse[R, A, m.MirroredElemLabels, m.MirroredElemTypes]()()
def info = Macros.typeInfo[A]
def annotations = Macros.annotations[A]
makeSumSchema[R, A](members, info, annotations)(m)
makeSumSchema[R, A](
recurse[R, A, m.MirroredElemLabels, m.MirroredElemTypes]()(),
Macros.typeInfo[A],
Macros.annotations[A]
)(m.ordinal)

case m: Mirror.ProductOf[A] =>
lazy val fields = recurse[R, A, m.MirroredElemLabels, m.MirroredElemTypes]()()
def annotations = Macros.annotations[A]
def info = Macros.typeInfo[A]
def paramAnnotations = Macros.paramAnnotations[A].toMap
makeProductSchema[R, A](fields, info, annotations, paramAnnotations)
makeProductSchema[R, A](
recurse[R, A, m.MirroredElemLabels, m.MirroredElemTypes]()(),
Macros.typeInfo[A],
Macros.annotations[A],
Macros.paramAnnotations[A].toMap
)
}

private def makeSumSchema[R, A](
members: => List[(String, List[Any], Schema[R, Any], Int)],
_members: => List[(String, List[Any], Schema[R, Any], Int)],
info: TypeInfo,
annotations: List[Any]
)(m: Mirror.SumOf[A]): Schema[R, A] = new Schema[R, A] {
)(ordinal: A => Int): Schema[R, A] = new Schema[R, A] {

private lazy val members = _members
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to move the lazy val inside the class. lazy vals produce a fair bit of code, and creating one per schema was adding some overhead when compiling


private lazy val subTypes = members.map { case (label, subTypeAnnotations, schema, _) =>
(label, schema.toType_(), subTypeAnnotations)
Expand Down Expand Up @@ -112,18 +115,20 @@ trait CommonSchemaDerivation {
}

def resolve(value: A): Step[R] = {
val (label, _, schema, _) = members(m.ordinal(value))
val (label, _, schema, _) = members(ordinal(value))
if (isEnum) PureStep(EnumValue(label)) else schema.resolve(value)
}
}

private def makeProductSchema[R, A](
fields: => List[(String, List[Any], Schema[R, Any], Int)],
_fields: => List[(String, List[Any], Schema[R, Any], Int)],
info: TypeInfo,
annotations: List[Any],
paramAnnotations: Map[String, List[Any]]
): Schema[R, A] = new Schema[R, A] {

private lazy val fields = _fields

private lazy val isValueType: Boolean =
annotations.exists {
case GQLValueType(_) => true
Expand Down