Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Scala 3] Field derivation from case class methods #2041

Merged
merged 7 commits into from
Dec 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package caliban.schema

trait AnnotationsVersionSpecific
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package caliban.schema

import scala.annotation.StaticAnnotation

trait AnnotationsVersionSpecific {

/**
* Annotation that can be used on a case class method to mark it as a GraphQL field.
* The method must be public, a `def` (does not work on `val`s / `lazy val`s) and must not take any arguments.
*
* NOTE: This annotation is not safe for use with ahead-of-time compilation (e.g., generating a GraalVM native-image executable)
*/
case class GQLField() extends StaticAnnotation

}
8 changes: 4 additions & 4 deletions core/src/main/scala-3/caliban/schema/DerivationUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,12 @@ private object DerivationUtils {

def mkInputObject[R](
annotations: List[Any],
fields: List[(String, List[Any], Schema[R, Any], Int)],
fields: List[(String, List[Any], Schema[R, Any])],
info: TypeInfo
)(isInput: Boolean, isSubscription: Boolean): __Type = makeInputObject(
Some(getInputName(annotations).getOrElse(customizeInputTypeName(getName(annotations, info)))),
getDescription(annotations),
fields.map { (name, fieldAnnotations, schema, _) =>
fields.map { (name, fieldAnnotations, schema) =>
__InputValue(
name,
getDescription(fieldAnnotations),
Expand All @@ -141,12 +141,12 @@ private object DerivationUtils {

def mkObject[R](
annotations: List[Any],
fields: List[(String, List[Any], Schema[R, Any], Int)],
fields: List[(String, List[Any], Schema[R, Any])],
info: TypeInfo
)(isInput: Boolean, isSubscription: Boolean): __Type = makeObject(
Some(getName(annotations, info)),
getDescription(annotations),
fields.map { (name, fieldAnnotations, schema, _) =>
fields.map { (name, fieldAnnotations, schema) =>
val deprecatedReason = getDeprecatedReason(fieldAnnotations)
Types.makeField(
name,
Expand Down
36 changes: 27 additions & 9 deletions core/src/main/scala-3/caliban/schema/ObjectSchema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,50 @@ import caliban.schema.DerivationUtils.*
import magnolia1.TypeInfo

import scala.annotation.threadUnsafe
import scala.reflect.ClassTag

final private class ObjectSchema[R, A](
_fields: => List[(String, Schema[R, Any], Int)],
_constructorFields: => List[(String, Schema[R, Any], Int)],
_methodFields: => List[(String, List[Any], Schema[R, ?])],
info: TypeInfo,
anns: List[Any],
paramAnnotations: Map[String, List[Any]]
) extends Schema[R, A] {
)(using ct: ClassTag[A])
extends Schema[R, A] {

@threadUnsafe
private lazy val fields = _fields.map { (label, schema, index) =>
val fieldAnnotations = paramAnnotations.getOrElse(label, Nil)
(getName(fieldAnnotations, label), fieldAnnotations, schema, index)
private lazy val fields = {
val fromConstructor = _constructorFields.view.map { (label, schema, index) =>
val fieldAnns = paramAnnotations.getOrElse(label, Nil)
((getName(fieldAnns, label), fieldAnns, schema), Left(index))
}
val fromMethods = _methodFields.view.map { (methodName, fieldAnns, schema) =>
((getName(fieldAnns, methodName), fieldAnns, schema.asInstanceOf[Schema[R, Any]]), Right(methodName))
}

(fromConstructor ++ fromMethods).toList
}

@threadUnsafe
private lazy val resolver = {
def fs = fields.map { (name, _, schema, i) =>
name -> { (v: A) => schema.resolve(v.asInstanceOf[Product].productElement(i)) }
val clazz = ct.runtimeClass
val fs = fields.map { case ((name, _, schema), idx) =>
name ->
idx.fold(
i => (v: A) => schema.resolve(v.asInstanceOf[Product].productElement(i)),
methodName => {
val method = clazz.getMethod(methodName)
(v: A) => schema.resolve(method.invoke(v))
}
)
}
ObjectFieldResolver(getName(anns, info), fs)
}

def toType(isInput: Boolean, isSubscription: Boolean): __Type = {
val _ = resolver // Init the lazy val
if (isInput) mkInputObject[R](anns, fields, info)(isInput, isSubscription)
else mkObject[R](anns, fields, info)(isInput, isSubscription)
if (isInput) mkInputObject[R](anns, fields.map(_._1), info)(isInput, isSubscription)
else mkObject[R](anns, fields.map(_._1), info)(isInput, isSubscription)
}

def resolve(value: A): Step[R] = resolver.resolve(value)
Expand Down
4 changes: 3 additions & 1 deletion core/src/main/scala-3/caliban/schema/SchemaDerivation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import magnolia1.Macro as MagnoliaMacro

import scala.compiletime.*
import scala.deriving.Mirror
import scala.reflect.ClassTag
import scala.util.NotGiven

object PrintDerived {
Expand Down Expand Up @@ -105,10 +106,11 @@ trait CommonSchemaDerivation {
case _ =>
new ObjectSchema[R, A](
recurseProduct[R, A, m.MirroredElemLabels, m.MirroredElemTypes]()(),
Macros.fieldsFromMethods[R, A],
MagnoliaMacro.typeInfo[A],
MagnoliaMacro.anns[A],
MagnoliaMacro.paramAnns[A].toMap
)
)(using summonInline[ClassTag[A]])
}

}
Expand Down
78 changes: 75 additions & 3 deletions core/src/main/scala-3/caliban/schema/macros/Macros.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package caliban.schema.macros

import caliban.schema.Annotations.GQLExcluded
import caliban.schema.Annotations.{ GQLExcluded, GQLField }
import caliban.schema.Schema

import scala.quoted.*

Expand All @@ -12,21 +13,25 @@ object Macros {
inline def implicitExists[T]: Boolean = ${ implicitExistsImpl[T] }
inline def hasAnnotation[T, Ann]: Boolean = ${ hasAnnotationImpl[T, Ann] }

inline def fieldsFromMethods[R, T]: List[(String, List[Any], Schema[R, ?])] = ${ fieldsFromMethodsImpl[R, T] }

/**
* Tests whether type argument [[FieldT]] in [[Parent]] is annotated with [[GQLExcluded]]
*/
private def isFieldExcludedImpl[Parent: Type, FieldT: Type](using qctx: Quotes): Expr[Boolean] = {
import qctx.reflect.*
val fieldName = Type.valueOfConstant[FieldT]
val annSymbol = TypeRepr.of[GQLExcluded].typeSymbol
Expr(TypeRepr.of[Parent].typeSymbol.primaryConstructor.paramSymss.flatten.exists { v =>
fieldName.map(_ == v.name).getOrElse(false)
&& v.annotations.exists(_.tpe =:= TypeRepr.of[GQLExcluded])
&& v.getAnnotation(annSymbol).isDefined
})
}

private def hasAnnotationImpl[T: Type, Ann: Type](using qctx: Quotes): Expr[Boolean] = {
import qctx.reflect.*
Expr(TypeRepr.of[T].typeSymbol.annotations.exists(_.tpe =:= TypeRepr.of[Ann]))
val annSymbol = TypeRepr.of[Ann].typeSymbol
Expr(TypeRepr.of[T].typeSymbol.getAnnotation(annSymbol).isDefined)
}

private def implicitExistsImpl[T: Type](using q: Quotes): Expr[Boolean] = {
Expand All @@ -42,4 +47,71 @@ object Macros {
Expr(TypeRepr.of[P].typeSymbol.flags.is(Flags.Enum) && TypeRepr.of[T].typeSymbol.flags.is(Flags.Enum))
}

private def fieldsFromMethodsImpl[R: Type, T: Type](using
q: Quotes
): Expr[List[(String, List[Any], Schema[R, ?])]] = {
import q.reflect.*
val targetSym = TypeTree.of[T].symbol
val targetType = TypeRepr.of[T]
val annType = TypeRepr.of[GQLField]
val annSym = annType.typeSymbol

def summonSchema(methodSym: Symbol): Expr[Schema[R, ?]] = {
val fieldType = targetType.memberType(methodSym)
val tpe = (fieldType match {
case MethodType(_, _, returnType) => returnType
case _ => fieldType
}).widen

tpe.asType match {
case '[f] =>
Expr
.summon[Schema[R, f]]
.getOrElse(report.errorAndAbort(schemaNotFound(tpe.show)))
}
}

def checkMethodNoArgs(methodSym: Symbol): Unit =
if (methodSym.signature.paramSigs.size > 0)
report.errorAndAbort(s"Method '${methodSym.name}' annotated with @GQLField must be parameterless")

// Unfortunately we can't reuse Magnolias filtering so we copy the implementation
def filterAnnotation(ann: Term): Boolean = {
val tpe = ann.tpe

tpe != annType && // No need to include the GQLField annotation
(tpe.typeSymbol.maybeOwner.isNoSymbol ||
(tpe.typeSymbol.owner.fullName != "scala.annotation.internal" &&
tpe.typeSymbol.owner.fullName != "jdk.internal"))
}

def extractAnnotations(methodSym: Symbol): List[Expr[Any]] =
methodSym.annotations.filter(filterAnnotation).map(_.asExpr.asInstanceOf[Expr[Any]])

Expr.ofList {
targetSym.declaredMethods
.filter(_.getAnnotation(annSym).isDefined)
.map { method =>
checkMethodNoArgs(method)
'{
(
${ Expr(method.name) },
${ Expr.ofList(extractAnnotations(method)) },
${ summonSchema(method) }
)
}
}
}
}

// Copied from Schema so that we have the same compiler error message
private inline def schemaNotFound(tpe: String) =
s"""Cannot find a Schema for type $tpe.

Caliban provides instances of Schema for the most common Scala types, and can derive it for your case classes and sealed traits.
Derivation requires that you have a Schema for any other type nested inside $tpe.
If you use a custom type as an argument, you also need to provide an implicit ArgBuilder for that type.
See https://ghostdogpr.github.io/caliban/docs/schema.html for more information.
"""

}
2 changes: 1 addition & 1 deletion core/src/main/scala/caliban/schema/Annotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import caliban.parsing.adt.Directive

import scala.annotation.StaticAnnotation

object Annotations {
object Annotations extends AnnotationsVersionSpecific {

/**
* Annotation used to indicate a type or a field is deprecated.
Expand Down
77 changes: 74 additions & 3 deletions core/src/test/scala-3/caliban/schema/Scala3DerivesSpec.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package caliban.schema

import caliban.*
import caliban.RootResolver
import caliban.schema.Annotations.GQLInterface
import caliban.schema.Annotations.{ GQLField, GQLInterface, GQLName }
import zio.test.{ assertTrue, ZIOSpecDefault }
import zio.{ RIO, Task, ZIO }

import java.time.Instant

Expand Down Expand Up @@ -174,7 +174,78 @@ object Scala3DerivesSpec extends ZIOSpecDefault {
|}""".stripMargin
)
}
)
),
suite("methods as fields") {
val expectedSchema =
"""schema {
| query: Bar
|}

|type Bar {
| foo: Foo!
|}

|type Foo {
| value: String!
| value2: String
|}""".stripMargin
List(
test("SemiAuto derivation of methods as fields") {
final case class Foo(value: String) derives Schema.SemiAuto {
def value1: String = value + 1
@GQLField def value2: Option[String] = Some(value + 2)
}
final case class Bar(foo: Foo) derives Schema.SemiAuto
val rendered = graphQL(RootResolver(Bar(Foo("foo")))).render

assertTrue(rendered == expectedSchema)
},
test("custom schema derivation") {
trait MyService
object MySchema extends SchemaDerivation[MyService]
final case class Foo(value: String) derives MySchema.SemiAuto {
@GQLField def value2: RIO[MyService, Option[String]] = ZIO.some(value + 2)
}
final case class Bar(foo: Foo) derives MySchema.SemiAuto
val rendered = graphQL(RootResolver(Bar(Foo("foo")))).render

assertTrue(rendered == expectedSchema)
},
test("method annotations") {
final case class Foo(value: String) derives Schema.SemiAuto {
@GQLField
@GQLName("value2")
def foo: Option[String] = Some(value + 2)
}
final case class Bar(foo: Foo) derives Schema.SemiAuto
val rendered = graphQL(RootResolver(Bar(Foo("foo")))).render

assertTrue(rendered == expectedSchema)
},
test("Auto derivation of methods as fields") {
final case class Foo(value: String) {
@GQLField def value2: Option[String] = Some(value + 2)
}
final case class Bar(foo: Foo) derives Schema.Auto
val rendered = graphQL(RootResolver(Bar(Foo("foo")))).render
assertTrue(rendered == expectedSchema)
},
test("execution of methods as fields") {
final case class Foo(value: String) derives Schema.SemiAuto {
@GQLField def value2: Task[String] = ZIO.succeed(value + 2)
}
final case class Bar(foo: Foo) derives Schema.SemiAuto
val gql = graphQL(RootResolver(Bar(Foo("foo"))))

gql.interpreter.flatMap { i =>
i.execute("{foo {value value2}}").map { v =>
val s = v.data.toString
assertTrue(s == """{"foo":{"value":"foo","value2":"foo2"}}""")
}
}
}
)
}
)
}
}