Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transformer that excludes fields / inputs based on directives #2293

Merged
merged 11 commits into from
Jun 27, 2024
10 changes: 8 additions & 2 deletions core/src/main/scala-2/caliban/schema/SchemaDerivation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ trait CommonSchemaDerivation[R] {
p.annotations.collectFirst { case GQLDeprecated(_) => () }.isDefined,
p.annotations.collectFirst { case GQLDeprecated(reason) => reason },
Some(p.annotations.collect { case GQLDirective(dir) => dir }.toList).filter(_.nonEmpty),
() => Some(tpe)
() => Some(tpe),
getTags(p.annotations)
)
)
.toList,
Expand Down Expand Up @@ -122,7 +123,8 @@ trait CommonSchemaDerivation[R] {
Some(SchemaUtils.SemanticNonNull)
else None
}
).filter(_.nonEmpty)
).filter(_.nonEmpty),
getTags(p.annotations)
)
}
.toList,
Expand Down Expand Up @@ -286,6 +288,10 @@ trait CommonSchemaDerivation[R] {

private def getDescription[Typeclass[_], Type](ctx: ReadOnlyParam[Typeclass, Type]): Option[String] =
getDescription(ctx.annotations)

private def getTags(annotations: Seq[Any]): Set[String] =
annotations.collect { case GQLTag(tags @ _*) => tags }.flatten.toSet

}

trait SchemaDerivation[R] extends CommonSchemaDerivation[R] {
Expand Down
9 changes: 7 additions & 2 deletions core/src/main/scala-3/caliban/schema/DerivationUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ private object DerivationUtils {
def getDeprecatedReason(annotations: Seq[Any]): Option[String] =
annotations.collectFirst { case GQLDeprecated(reason) => reason }

def getTags(annotations: Seq[Any]): Set[String] =
annotations.collect { case GQLTag(dir) => dir }.toSet

def mkEnum(annotations: List[Any], info: TypeInfo, subTypes: List[(String, __Type, List[Any])]): __Type =
makeEnum(
Some(getName(annotations, info)),
Expand Down Expand Up @@ -113,7 +116,8 @@ private object DerivationUtils {
isDeprecated = deprecationReason.isDefined,
deprecationReason = deprecationReason,
directives = Some(getDirectives(fieldAnnotations)).filter(_.nonEmpty),
parentType = () => Some(tpe)
parentType = () => Some(tpe),
getTags(p.annotations)
)
},
Some(info.full),
Expand Down Expand Up @@ -168,7 +172,8 @@ private object DerivationUtils {
if (enableSemanticNonNull && isSemanticNonNull) Some(SchemaUtils.SemanticNonNull)
else None
}
).filter(_.nonEmpty)
).filter(_.nonEmpty),
getTags(fieldAnnotations)
)
},
getDirectives(annotations),
Expand Down
3 changes: 2 additions & 1 deletion core/src/main/scala/caliban/introspection/adt/__Field.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ case class __Field(
`type`: () => __Type,
isDeprecated: Boolean = false,
deprecationReason: Option[String] = None,
@GQLExcluded directives: Option[List[Directive]] = None
@GQLExcluded directives: Option[List[Directive]] = None,
@GQLExcluded tags: Set[String] = Set.empty
kyri-petrou marked this conversation as resolved.
Show resolved Hide resolved
) {
final override lazy val hashCode: Int = super.hashCode()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ case class __InputValue(
isDeprecated: Boolean = false,
deprecationReason: Option[String] = None,
@GQLExcluded directives: Option[List[Directive]] = None,
@GQLExcluded parentType: () => Option[__Type] = () => None
@GQLExcluded parentType: () => Option[__Type] = () => None,
@GQLExcluded tags: Set[String] = Set.empty
) {
def toInputValueDefinition: InputValueDefinition = {
val default = defaultValue.flatMap(v => Parser.parseInputValue(v).toOption)
Expand Down
5 changes: 5 additions & 0 deletions core/src/main/scala/caliban/schema/Annotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,9 @@ object Annotations extends AnnotationsVersionSpecific {
*/
case class GQLOneOfInput() extends StaticAnnotation

/**
* Compile-time annotation that can be used in conjunction with [[caliban.transformers.Transformer]] to
* customize schema generation.
*/
case class GQLTag(tags: String*) extends StaticAnnotation
}
6 changes: 4 additions & 2 deletions core/src/main/scala/caliban/schema/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ object Types {
`type`: () => __Type,
isDeprecated: Boolean = false,
deprecationReason: Option[String] = None,
directives: Option[List[Directive]] = None
directives: Option[List[Directive]] = None,
tags: Set[String] = Set.empty
): __Field =
__Field(
name,
Expand All @@ -79,7 +80,8 @@ object Types {
`type`,
isDeprecated,
deprecationReason,
directives
directives,
tags
)

def makeInputObject(
Expand Down
59 changes: 56 additions & 3 deletions core/src/main/scala/caliban/transformers/Transformer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ abstract class Transformer[-R] { self =>
* Set of type names that this transformer applies to.
* Needed for applying optimizations when combining transformers.
*/
protected val typeNames: collection.Set[String]
protected def typeNames: collection.Set[String]

protected def transformStep[R1 <: R](step: ObjectStep[R1], field: Field): ObjectStep[R1]

Expand Down Expand Up @@ -326,20 +326,73 @@ object Transformer {
}
}

object ExcludeTags {

/**
* A transformer that allows excluding tagged fields and input arguments.
*
* {{{
* ExcludeTags("TagA", "TagB")
* }}}
*
* @param f tuples in the format of `(TypeName -> fieldToBeExcluded)`
*/
def apply(f: String*): Transformer[Any] =
if (f.isEmpty) Empty else new ExcludeTags(f.toSet)
}

final private class ExcludeTags(tags: Set[String]) extends Transformer[Any] {
private val map: mutable.HashMap[String, Set[String]] = mutable.HashMap.empty

private def shouldKeep(tpe: __Type, field: __Field): Boolean = {
val keep = field.tags.intersect(tags).isEmpty
if (!keep) map.updateWith(tpe.name.getOrElse("")) {
case Some(set) => Some(set + field.name)
case None => Some(Set(field.name))
}
keep
}

val typeVisitor: TypeVisitor =
TypeVisitor.fields.filterWith((t, field) => shouldKeep(t, field)) |+|
TypeVisitor.fields.modify { field =>
def loop(arg: __InputValue): Option[__InputValue] =
if (arg._type.isNullable && arg.tags.intersect(tags).nonEmpty) None
else {
lazy val newType = arg._type.mapInnerType { t =>
t.copy(inputFields = t.inputFields(_).map(_.flatMap(loop)))
}
Some(arg.copy(`type` = () => newType))
}

field.copy(args = field.args(_).flatMap(loop))
}

protected def typeNames: collection.Set[String] = map.keySet

protected def transformStep[R](step: ObjectStep[R], field: Field): ObjectStep[R] =
map.getOrElse(step.name, null) match {
case null => step
case excl => step.copy(fields = name => if (!excl(name)) step.fields(name) else NullStep)
}
}

final private class Combined[-R](left: Transformer[R], right: Transformer[R]) extends Transformer[R] {
val typeVisitor: TypeVisitor = left.typeVisitor |+| right.typeVisitor

protected val typeNames: mutable.HashSet[String] = {
protected def typeNames: mutable.HashSet[String] = {
val set = mutable.HashSet.from(left.typeNames)
set ++= right.typeNames
set
}

private lazy val materializedTypeNames = typeNames

protected def transformStep[R1 <: R](step: ObjectStep[R1], field: Field): ObjectStep[R1] =
right.transformStep(left.transformStep(step, field), field)

override def apply[R1 <: R](step: ObjectStep[R1], field: Field): ObjectStep[R1] =
if (typeNames(step.name)) transformStep(step, field) else step
if (materializedTypeNames(step.name)) transformStep(step, field) else step
}

private def mapFunctionStep[R](step: Step[R])(f: Map[String, InputValue] => Map[String, InputValue]): Step[R] =
Expand Down
152 changes: 151 additions & 1 deletion core/src/test/scala/caliban/transformers/TransformerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package caliban.transformers

import caliban.Macros.gqldoc
import caliban._
import caliban.schema.Annotations.GQLTag
import caliban.schema.ArgBuilder.auto._
import caliban.schema.Schema.auto._
import zio.test._
Expand Down Expand Up @@ -218,6 +219,155 @@ object TransformerSpec extends ZIOSpecDefault {
| c(arg: String!): String!
|}""".stripMargin
)
}
},
suite("ExcludeTag")(
test("fields") {
case class Query(
a: String,
@GQLTag("schemaA")
b: Int,
@GQLTag("schemaB")
c: Double,
@GQLTag("schemaA", "schemaB")
d: Boolean
)
val api: GraphQL[Any] = graphQL(RootResolver(Query("a", 2, 3d, true)))
val apiA: GraphQL[Any] = api.transform(Transformer.ExcludeTags("schemaA"))
val apiB: GraphQL[Any] = api.transform(Transformer.ExcludeTags("schemaB"))
val apiC: GraphQL[Any] = api.transform(Transformer.ExcludeTags("schemaA", "schemaB"))

for {
_ <- Configurator.setSkipValidation(true)
res0 <- api.interpreterUnsafe.execute("""{ a b c d }""").map(_.data.toString)
resA <- apiA.interpreterUnsafe.execute("""{ a b c d }""").map(_.data.toString)
resB <- apiB.interpreterUnsafe.execute("""{ a b c d }""").map(_.data.toString)
resC <- apiC.interpreterUnsafe.execute("""{ a b c d }""").map(_.data.toString)
rendered = api.render
renderedA = apiA.render
renderedB = apiB.render
renderedC = apiC.render
} yield assertTrue(
res0 == """{"a":"a","b":2,"c":3.0,"d":true}""",
resA == """{"a":"a","b":null,"c":3.0,"d":null}""",
resB == """{"a":"a","b":2,"c":null,"d":null}""",
resC == """{"a":"a","b":null,"c":null,"d":null}""",
rendered ==
"""schema {
| query: Query
|}
|
|type Query {
| a: String!
| b: Int!
| c: Float!
| d: Boolean!
|}""".stripMargin,
renderedA ==
"""schema {
| query: Query
|}
|
|type Query {
| a: String!
| c: Float!
|}""".stripMargin,
renderedB ==
"""schema {
| query: Query
|}
|
|type Query {
| a: String!
| b: Int!
|}""".stripMargin,
renderedC ==
"""schema {
| query: Query
|}
|
|type Query {
| a: String!
|}""".stripMargin
)
},
test("input fields") {
case class Nested(
a: String,
@GQLTag("schemaA")
b: Option[Int],
@GQLTag("schemaB")
c: Option[Double],
@GQLTag("schemaA", "schemaB")
d: Option[Boolean]
)
case class Args(a: String, b: String, l: List[String], nested: Nested)
case class Query(foo: Args => String)
val api: GraphQL[Any] = graphQL(RootResolver(Query(_ => "value")))
val apiA: GraphQL[Any] = api.transform(Transformer.ExcludeTags("schemaA"))
val apiB: GraphQL[Any] = api.transform(Transformer.ExcludeTags("schemaB"))
val apiC: GraphQL[Any] = api.transform(Transformer.ExcludeTags("schemaA", "schemaB"))

val rendered = api.render
val renderedA = apiA.render
val renderedB = apiB.render
val renderedC = apiC.render

assertTrue(
rendered ==
"""schema {
| query: Query
|}
|
|input NestedInput {
| a: String!
| b: Int
| c: Float
| d: Boolean
|}
|
|type Query {
| foo(a: String!, b: String!, l: [String!]!, nested: NestedInput!): String!
|}""".stripMargin,
renderedA ==
"""schema {
| query: Query
|}
|
|input NestedInput {
| a: String!
| c: Float
|}
|
|type Query {
| foo(a: String!, b: String!, l: [String!]!, nested: NestedInput!): String!
|}""".stripMargin,
renderedB ==
"""schema {
| query: Query
|}
|
|input NestedInput {
| a: String!
| b: Int
|}
|
|type Query {
| foo(a: String!, b: String!, l: [String!]!, nested: NestedInput!): String!
|}""".stripMargin,
renderedC ==
"""schema {
| query: Query
|}
|
|input NestedInput {
| a: String!
|}
|
|type Query {
| foo(a: String!, b: String!, l: [String!]!, nested: NestedInput!): String!
|}""".stripMargin
)
}
)
)
}