Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Schema transformations #2218

Merged
merged 8 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions core/src/main/scala/caliban/GraphQL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import caliban.parsing.adt.{ Directive, Document, OperationType }
import caliban.parsing.{ Parser, SourceMapper, VariablesCoercer }
import caliban.rendering.DocumentRenderer
import caliban.schema._
import caliban.transformers.Transformer
import caliban.validation.Validator
import caliban.wrappers.Wrapper
import caliban.wrappers.Wrapper._
Expand All @@ -28,6 +29,7 @@ trait GraphQL[-R] { self =>
protected val wrappers: List[Wrapper[R]]
protected val additionalDirectives: List[__Directive]
protected val features: Set[Feature]
protected val transformer: Transformer[R]

private[caliban] def validateRootSchema(implicit trace: Trace): IO[ValidationError, RootSchema[R]] =
ZIO.fromEither(Validator.validateSchemaEither(schemaBuilder))
Expand Down Expand Up @@ -165,7 +167,8 @@ trait GraphQL[-R] { self =>
fieldWrappers,
config.queryExecution,
features,
config.queryCache
config.queryCache,
transformer
)
}
}
Expand Down Expand Up @@ -197,10 +200,11 @@ trait GraphQL[-R] { self =>
*/
final def withWrapper[R2 <: R](wrapper: Wrapper[R2]): GraphQL[R2] =
new GraphQL[R2] {
override val schemaBuilder: RootSchemaBuilder[R2] = self.schemaBuilder
override val wrappers: List[Wrapper[R2]] = wrapper :: self.wrappers
override val additionalDirectives: List[__Directive] = self.additionalDirectives
override val features: Set[Feature] = self.features
override protected val schemaBuilder: RootSchemaBuilder[R2] = self.schemaBuilder
override protected val wrappers: List[Wrapper[R2]] = wrapper :: self.wrappers
override protected val additionalDirectives: List[__Directive] = self.additionalDirectives
override protected val features: Set[Feature] = self.features
override protected val transformer: Transformer[R] = self.transformer
}

/**
Expand All @@ -221,11 +225,12 @@ trait GraphQL[-R] { self =>
*/
final def combine[R1 <: R](that: GraphQL[R1]): GraphQL[R1] =
new GraphQL[R1] {
override val schemaBuilder: RootSchemaBuilder[R1] = self.schemaBuilder |+| that.schemaBuilder
override protected val schemaBuilder: RootSchemaBuilder[R1] = self.schemaBuilder |+| that.schemaBuilder
override protected val wrappers: List[Wrapper[R1]] = self.wrappers ++ that.wrappers
override protected val additionalDirectives: List[__Directive] =
self.additionalDirectives ++ that.additionalDirectives
override protected val features: Set[Feature] = self.features ++ that.features
override protected val transformer: Transformer[R1] = self.transformer |+| that.transformer
}

/**
Expand Down Expand Up @@ -259,6 +264,7 @@ trait GraphQL[-R] { self =>
override protected val wrappers: List[Wrapper[R]] = self.wrappers
override protected val additionalDirectives: List[__Directive] = self.additionalDirectives
override protected val features: Set[Feature] = self.features
override protected val transformer: Transformer[R] = self.transformer
}

/**
Expand All @@ -273,6 +279,7 @@ trait GraphQL[-R] { self =>
override protected val wrappers: List[Wrapper[R]] = self.wrappers
override protected val additionalDirectives: List[__Directive] = self.additionalDirectives
override protected val features: Set[Feature] = self.features
override protected val transformer: Transformer[R] = self.transformer
}

final def withSchemaDirectives(directives: List[Directive]): GraphQL[R] = new GraphQL[R] {
Expand All @@ -281,19 +288,34 @@ trait GraphQL[-R] { self =>
override protected val wrappers: List[Wrapper[R]] = self.wrappers
override protected val additionalDirectives: List[__Directive] = self.additionalDirectives
override protected val features: Set[Feature] = self.features
override protected val transformer: Transformer[R] = self.transformer
}

final def withAdditionalDirectives(directives: List[__Directive]): GraphQL[R] = new GraphQL[R] {
override protected val schemaBuilder: RootSchemaBuilder[R] = self.schemaBuilder
override protected val wrappers: List[Wrapper[R]] = self.wrappers
override protected val additionalDirectives: List[__Directive] = self.additionalDirectives ++ directives
override protected val features: Set[Feature] = self.features
override protected val transformer: Transformer[R] = self.transformer
}

final def enable(feature: Feature): GraphQL[R] = new GraphQL[R] {
override protected val schemaBuilder: RootSchemaBuilder[R] = self.schemaBuilder
override protected val wrappers: List[Wrapper[R]] = self.wrappers
override protected val additionalDirectives: List[__Directive] = self.additionalDirectives
override protected val features: Set[Feature] = self.features + feature
override protected val transformer: Transformer[R] = self.transformer
}

/**
* Transforms the schema using the given transformer.
* This can be used to rename or filter types, fields and arguments.
*/
final def transform[R1 <: R](t: Transformer[R1]): GraphQL[R1] = new GraphQL[R1] {
override protected val schemaBuilder: RootSchemaBuilder[R1] = self.schemaBuilder.visit(t.typeVisitor)
override protected val wrappers: List[Wrapper[R1]] = self.wrappers
override protected val additionalDirectives: List[__Directive] = self.additionalDirectives
override protected val features: Set[Feature] = self.features
override protected val transformer: Transformer[R1] = self.transformer |+| t
}
}
20 changes: 12 additions & 8 deletions core/src/main/scala/caliban/execution/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import caliban.parsing.adt._
import caliban.schema.ReducedStep.DeferStep
import caliban.schema.Step.{ PureStep => _, _ }
import caliban.schema.{ PureStep, ReducedStep, Step, Types }
import caliban.transformers.Transformer
import caliban.wrappers.Wrapper.FieldWrapper
import zio._
import zio.query._
Expand Down Expand Up @@ -38,11 +39,13 @@ object Executor {
fieldWrappers: List[FieldWrapper[R]] = Nil,
queryExecution: QueryExecution = QueryExecution.Parallel,
featureSet: Set[Feature] = Set.empty,
makeCache: UIO[Cache] = Cache.empty(Trace.empty)
makeCache: UIO[Cache] = Cache.empty(Trace.empty),
transformer: Transformer[R] = Transformer.empty[R]
)(implicit trace: Trace): URIO[R, GraphQLResponse[CalibanError]] = {
val wrapPureValues = fieldWrappers.exists(_.wrapPureValues)
val stepReducer =
new StepReducer[R](
transformer,
request.operationType eq OperationType.Subscription,
featureSet(Feature.Defer),
wrapPureValues
Expand Down Expand Up @@ -123,6 +126,7 @@ object Executor {
ZIO.succeed(GraphQLResponse(NullValue, List(error)))

private final class StepReducer[R](
transformer: Transformer[R],
isSubscription: Boolean,
isDeferredEnabled: Boolean,
wrapPureValues: Boolean
Expand Down Expand Up @@ -259,13 +263,13 @@ object Executor {
catch { case NonFatal(e) => Step.fail(e) }

step match {
case s: PureStep => s
case QueryStep(inner) => reduceQuery(inner)
case ObjectStep(objectName, fields) => reduceObjectStep(objectName, fields)
case FunctionStep(step) => reduceStep(wrapFn(step, arguments), currentField, Map.empty, path)
case MetadataFunctionStep(step) => reduceStep(wrapFn(step, currentField), currentField, arguments, path)
case ListStep(steps) => reduceListStep(steps)
case StreamStep(stream) => reduceStream(stream)
case s: PureStep => s
case s: QueryStep[R] => reduceQuery(s.query)
case s: ObjectStep[R] => val t = transformer(s, currentField); reduceObjectStep(t.name, t.fields)
case s: FunctionStep[R] => reduceStep(wrapFn(s.step, arguments), currentField, Map.empty, path)
case s: MetadataFunctionStep[R] => reduceStep(wrapFn(s.step, currentField), currentField, arguments, path)
case s: ListStep[R] => reduceListStep(s.steps)
case s: StreamStep[R] => reduceStream(s.inner)
}
}

Expand Down
104 changes: 104 additions & 0 deletions core/src/main/scala/caliban/introspection/adt/__Type.scala
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,107 @@ case class __Type(
case _ => Set.empty
}
}

sealed trait TypeVisitor { self =>
import TypeVisitor._

def |+|(that: TypeVisitor): TypeVisitor = TypeVisitor.Combine(self, that)

def visit(t: __Type): __Type = {
def collect(visitor: TypeVisitor): __Type => __Type =
visitor match {
case Empty => identity
case Modify(f) => f
case Combine(v1, v2) => collect(v1) andThen collect(v2)
}

val f = collect(self)

def loop(t: __Type): __Type =
f(
t.copy(
fields = t.fields(_).map(_.map(field => field.copy(`type` = () => loop(field.`type`())))),
inputFields = t.inputFields(_).map(_.map(field => field.copy(`type` = () => loop(field.`type`())))),
interfaces = () => t.interfaces().map(_.map(loop)),
possibleTypes = t.possibleTypes.map(_.map(loop)),
ofType = t.ofType.map(loop)
)
)

loop(t)
}
}

object TypeVisitor {
private case object Empty extends TypeVisitor {
override def |+|(that: TypeVisitor): TypeVisitor = that
override def visit(t: __Type): __Type = t
}
private case class Modify(f: __Type => __Type) extends TypeVisitor
private case class Combine(v1: TypeVisitor, v2: TypeVisitor) extends TypeVisitor

val empty: TypeVisitor = Empty
def modify(f: __Type => __Type): TypeVisitor = Modify(f)
private[caliban] def modify[A](visitor: ListVisitor[A]): TypeVisitor = modify(t => visitor.visit(t))

object fields extends ListVisitorConstructors[__Field] {
val set: __Type => (List[__Field] => List[__Field]) => __Type =
t => f => t.copy(fields = args => t.fields(args).map(f))
}
object inputFields extends ListVisitorConstructors[__InputValue] {
val set: __Type => (List[__InputValue] => List[__InputValue]) => __Type =
t => f => t.copy(inputFields = args => t.inputFields(args).map(f))
}
object enumValues extends ListVisitorConstructors[__EnumValue] {
val set: __Type => (List[__EnumValue] => List[__EnumValue]) => __Type =
t => f => t.copy(enumValues = args => t.enumValues(args).map(f))
}
object directives extends ListVisitorConstructors[Directive] {
val set: __Type => (List[Directive] => List[Directive]) => __Type =
t => f => t.copy(directives = t.directives.map(f))
}
}

private[caliban] sealed abstract class ListVisitor[A](implicit val set: __Type => (List[A] => List[A]) => __Type) {
self =>
import ListVisitor._

def visit(t: __Type): __Type =
self match {
case Filter(predicate) => set(t)(_.filter(predicate(t)))
case Modify(f) => set(t)(_.map(f(t)))
case Add(f) => set(t)(f(t).foldLeft(_) { case (as, a) => a :: as })
}
}

private[caliban] object ListVisitor {
private case class Filter[A](predicate: __Type => A => Boolean)(implicit
set: __Type => (List[A] => List[A]) => __Type
) extends ListVisitor[A]
private case class Modify[A](f: __Type => A => A)(implicit
set: __Type => (List[A] => List[A]) => __Type
) extends ListVisitor[A]
private case class Add[A](f: __Type => List[A])(implicit
set: __Type => (List[A] => List[A]) => __Type
) extends ListVisitor[A]

def filter[A](predicate: (__Type, A) => Boolean)(implicit
set: __Type => (List[A] => List[A]) => __Type
): TypeVisitor =
TypeVisitor.modify(Filter[A](t => field => predicate(t, field)))
def modify[A](f: (__Type, A) => A)(implicit set: __Type => (List[A] => List[A]) => __Type): TypeVisitor =
TypeVisitor.modify(Modify[A](t => field => f(t, field)))
def add[A](f: __Type => List[A])(implicit set: __Type => (List[A] => List[A]) => __Type): TypeVisitor =
TypeVisitor.modify(Add(f))
}

private[caliban] trait ListVisitorConstructors[A] {
implicit val set: __Type => (List[A] => List[A]) => __Type

def filter(predicate: A => Boolean): TypeVisitor = filterWith((_, a) => predicate(a))
def filterWith(predicate: (__Type, A) => Boolean): TypeVisitor = ListVisitor.filter(predicate)
def modify(f: A => A): TypeVisitor = modifyWith((_, a) => f(a))
def modifyWith(f: (__Type, A) => A): TypeVisitor = ListVisitor.modify(f)
def add(list: List[A]): TypeVisitor = addWith(_ => list)
def addWith(f: __Type => List[A]): TypeVisitor = ListVisitor.add(f)
}
10 changes: 6 additions & 4 deletions core/src/main/scala/caliban/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import caliban.parsing.adt.{ Directive, Document }
import caliban.rendering.DocumentRenderer
import caliban.schema.Types.collectTypes
import caliban.schema._
import caliban.transformers.Transformer
import caliban.wrappers.Wrapper

package object caliban {
Expand All @@ -26,7 +27,7 @@ package object caliban {
mutationSchema: Schema[R, M],
subscriptionSchema: Schema[R, S]
): GraphQL[R] = new GraphQL[R] {
val schemaBuilder: RootSchemaBuilder[R] = RootSchemaBuilder(
override protected val schemaBuilder: RootSchemaBuilder[R] = RootSchemaBuilder(
resolver.queryResolver.map(r => Operation(querySchema.toType_(), querySchema.resolve(r))),
resolver.mutationResolver.map(r => Operation(mutationSchema.toType_(), mutationSchema.resolve(r))),
resolver.subscriptionResolver.map(r =>
Expand All @@ -35,9 +36,10 @@ package object caliban {
schemaDirectives = schemaDirectives,
schemaDescription = schemaDescription
)
val wrappers: List[Wrapper[R]] = Nil
val additionalDirectives: List[__Directive] = directives
val features: Set[Feature] = Set.empty
override protected val wrappers: List[Wrapper[R]] = Nil
override protected val additionalDirectives: List[__Directive] = directives
override protected val features: Set[Feature] = Set.empty
override protected val transformer: Transformer[R] = Transformer.empty
}

/**
Expand Down
10 changes: 9 additions & 1 deletion core/src/main/scala/caliban/schema/RootSchemaBuilder.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package caliban.schema

import caliban.introspection.adt.__Type
import caliban.introspection.adt.{ __Type, TypeVisitor }
import caliban.parsing.adt.Directive
import caliban.schema.Types.collectTypes

Expand Down Expand Up @@ -32,4 +32,12 @@ case class RootSchemaBuilder[-R](
.flatMap(_._2.headOption)
.toList
}

def visit(visitor: TypeVisitor): RootSchemaBuilder[R] =
copy(
query = query.map(query => query.copy(opType = visitor.visit(query.opType))),
mutation = mutation.map(mutation => mutation.copy(opType = visitor.visit(mutation.opType))),
subscription = subscription.map(subscription => subscription.copy(opType = visitor.visit(subscription.opType))),
additionalTypes = additionalTypes.map(visitor.visit)
)
}
Loading