From 244412a9847362d0b2e518c5673932b5603fd027 Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Mon, 19 Feb 2024 09:36:14 +0900 Subject: [PATCH 1/6] Add transformers --- core/src/main/scala/caliban/GraphQL.scala | 35 ++- .../scala/caliban/execution/Executor.scala | 19 +- .../caliban/introspection/adt/__Type.scala | 104 +++++++++ core/src/main/scala/caliban/package.scala | 10 +- .../caliban/schema/RootSchemaBuilder.scala | 10 +- .../caliban/transformers/Transformer.scala | 207 ++++++++++++++++++ .../transformers/TransformerSpec.scala | 154 +++++++++++++ .../scala/caliban/interop/tapir/package.scala | 2 + .../stitching/RemoteSchemaResolver.scala | 10 +- .../caliban/tools/RemoteSchemaSpec.scala | 10 +- 10 files changed, 536 insertions(+), 25 deletions(-) create mode 100644 core/src/main/scala/caliban/transformers/Transformer.scala create mode 100644 core/src/test/scala/caliban/transformers/TransformerSpec.scala diff --git a/core/src/main/scala/caliban/GraphQL.scala b/core/src/main/scala/caliban/GraphQL.scala index 9b7be2821..c55903921 100644 --- a/core/src/main/scala/caliban/GraphQL.scala +++ b/core/src/main/scala/caliban/GraphQL.scala @@ -10,6 +10,7 @@ import caliban.parsing.adt.{ Directive, Document, OperationType } import caliban.parsing.{ Parser, SourceMapper, VariablesCoercer } import caliban.rendering.DocumentRenderer import caliban.schema._ +import caliban.transformers.Transformer import caliban.validation.Validator import caliban.wrappers.Wrapper import caliban.wrappers.Wrapper._ @@ -28,6 +29,7 @@ trait GraphQL[-R] { self => protected val wrappers: List[Wrapper[R]] protected val additionalDirectives: List[__Directive] protected val features: Set[Feature] + protected val transformer: Transformer[R] private[caliban] def validateRootSchema(implicit trace: Trace): IO[ValidationError, RootSchema[R]] = ZIO.fromEither(Validator.validateSchemaEither(schemaBuilder)) @@ -142,7 +144,9 @@ trait GraphQL[-R] { self => execute = (req: ExecutionRequest) => for { queryExecution <- Configurator.configuration.map(_.queryExecution) - res <- Executor.executeRequest(req, op.plan, fieldWrappers, queryExecution, features) + res <- + Executor + .executeRequest(req, op.plan, fieldWrappers, queryExecution, features, transformer) } yield res result <- wrap(execute)(executionWrappers, executionRequest) } yield result).catchAll(Executor.fail) @@ -159,10 +163,11 @@ trait GraphQL[-R] { self => */ final def withWrapper[R2 <: R](wrapper: Wrapper[R2]): GraphQL[R2] = new GraphQL[R2] { - override val schemaBuilder: RootSchemaBuilder[R2] = self.schemaBuilder - override val wrappers: List[Wrapper[R2]] = wrapper :: self.wrappers - override val additionalDirectives: List[__Directive] = self.additionalDirectives - override val features: Set[Feature] = self.features + override protected val schemaBuilder: RootSchemaBuilder[R2] = self.schemaBuilder + override protected val wrappers: List[Wrapper[R2]] = wrapper :: self.wrappers + override protected val additionalDirectives: List[__Directive] = self.additionalDirectives + override protected val features: Set[Feature] = self.features + override protected val transformer: Transformer[R] = self.transformer } /** @@ -183,11 +188,12 @@ trait GraphQL[-R] { self => */ final def combine[R1 <: R](that: GraphQL[R1]): GraphQL[R1] = new GraphQL[R1] { - override val schemaBuilder: RootSchemaBuilder[R1] = self.schemaBuilder |+| that.schemaBuilder + override protected val schemaBuilder: RootSchemaBuilder[R1] = self.schemaBuilder |+| that.schemaBuilder override protected val wrappers: List[Wrapper[R1]] = self.wrappers ++ that.wrappers override protected val additionalDirectives: List[__Directive] = self.additionalDirectives ++ that.additionalDirectives override protected val features: Set[Feature] = self.features ++ that.features + override protected val transformer: Transformer[R1] = self.transformer |+| that.transformer } /** @@ -221,6 +227,7 @@ trait GraphQL[-R] { self => override protected val wrappers: List[Wrapper[R]] = self.wrappers override protected val additionalDirectives: List[__Directive] = self.additionalDirectives override protected val features: Set[Feature] = self.features + override protected val transformer: Transformer[R] = self.transformer } /** @@ -235,6 +242,7 @@ trait GraphQL[-R] { self => override protected val wrappers: List[Wrapper[R]] = self.wrappers override protected val additionalDirectives: List[__Directive] = self.additionalDirectives override protected val features: Set[Feature] = self.features + override protected val transformer: Transformer[R] = self.transformer } final def withSchemaDirectives(directives: List[Directive]): GraphQL[R] = new GraphQL[R] { @@ -243,6 +251,7 @@ trait GraphQL[-R] { self => override protected val wrappers: List[Wrapper[R]] = self.wrappers override protected val additionalDirectives: List[__Directive] = self.additionalDirectives override protected val features: Set[Feature] = self.features + override protected val transformer: Transformer[R] = self.transformer } final def withAdditionalDirectives(directives: List[__Directive]): GraphQL[R] = new GraphQL[R] { @@ -250,6 +259,7 @@ trait GraphQL[-R] { self => override protected val wrappers: List[Wrapper[R]] = self.wrappers override protected val additionalDirectives: List[__Directive] = self.additionalDirectives ++ directives override protected val features: Set[Feature] = self.features + override protected val transformer: Transformer[R] = self.transformer } final def enable(feature: Feature): GraphQL[R] = new GraphQL[R] { @@ -257,5 +267,18 @@ trait GraphQL[-R] { self => override protected val wrappers: List[Wrapper[R]] = self.wrappers override protected val additionalDirectives: List[__Directive] = self.additionalDirectives override protected val features: Set[Feature] = self.features + feature + override protected val transformer: Transformer[R] = self.transformer + } + + /** + * Transforms the schema using the given transformer. + * This can be used to rename or filter types, fields and arguments. + */ + final def transform[R1 <: R](t: Transformer[R1]): GraphQL[R1] = new GraphQL[R1] { + override protected val schemaBuilder: RootSchemaBuilder[R1] = self.schemaBuilder.visit(t.typeVisitor) + override protected val wrappers: List[Wrapper[R1]] = self.wrappers + override protected val additionalDirectives: List[__Directive] = self.additionalDirectives + override protected val features: Set[Feature] = self.features + override protected val transformer: Transformer[R1] = self.transformer |+| t } } diff --git a/core/src/main/scala/caliban/execution/Executor.scala b/core/src/main/scala/caliban/execution/Executor.scala index 8eb1482ee..985265b13 100644 --- a/core/src/main/scala/caliban/execution/Executor.scala +++ b/core/src/main/scala/caliban/execution/Executor.scala @@ -9,6 +9,7 @@ import caliban.parsing.adt._ import caliban.schema.ReducedStep.DeferStep import caliban.schema.Step._ import caliban.schema.{ ReducedStep, Step, Types } +import caliban.transformers.Transformer import caliban.wrappers.Wrapper.FieldWrapper import zio._ import zio.query._ @@ -36,12 +37,14 @@ object Executor { plan: Step[R], fieldWrappers: List[FieldWrapper[R]] = Nil, queryExecution: QueryExecution = QueryExecution.Parallel, - featureSet: Set[Feature] = Set.empty + featureSet: Set[Feature] = Set.empty, + transformer: Transformer[R] = Transformer.empty )(implicit trace: Trace): URIO[R, GraphQLResponse[CalibanError]] = { - val wrapPureValues = fieldWrappers.exists(_.wrapPureValues) - val isDeferredEnabled = featureSet(Feature.Defer) - val isMutation = request.operationType == OperationType.Mutation - val isSubscription = request.operationType == OperationType.Subscription + val wrapPureValues = fieldWrappers.exists(_.wrapPureValues) + val isDeferredEnabled = featureSet(Feature.Defer) + val isMutation = request.operationType == OperationType.Mutation + val isSubscription = request.operationType == OperationType.Subscription + val isEmptyTransformer = transformer.isEmpty type ExecutionQuery[+A] = ZQuery[R, ExecutionError, A] @@ -137,7 +140,11 @@ object Executor { try step catch { case NonFatal(e) => Step.fail(e) } - step match { + val step0 = + if (isEmptyTransformer) step + else transformer.transformStep.lift((step, currentField)).getOrElse(step) + + step0 match { case s @ PureStep(EnumValue(v)) => // special case of an hybrid union containing case objects, those should return an object instead of a string currentField.fields.view.filter(_._condition.forall(_.contains(v))).collectFirst { diff --git a/core/src/main/scala/caliban/introspection/adt/__Type.scala b/core/src/main/scala/caliban/introspection/adt/__Type.scala index 32a4b17bb..4764f1362 100644 --- a/core/src/main/scala/caliban/introspection/adt/__Type.scala +++ b/core/src/main/scala/caliban/introspection/adt/__Type.scala @@ -150,3 +150,107 @@ case class __Type( case _ => Set.empty } } + +sealed trait TypeVisitor { self => + import TypeVisitor._ + + def |+|(that: TypeVisitor): TypeVisitor = TypeVisitor.Combine(self, that) + + def visit(t: __Type): __Type = { + def collect(visitor: TypeVisitor): __Type => __Type = + visitor match { + case Empty => identity + case Modify(f) => f + case Combine(v1, v2) => collect(v1) andThen collect(v2) + } + + val f = collect(self) + + def loop(t: __Type): __Type = + f( + t.copy( + fields = t.fields(_).map(_.map(field => field.copy(`type` = () => loop(field.`type`())))), + inputFields = t.inputFields(_).map(_.map(field => field.copy(`type` = () => loop(field.`type`())))), + interfaces = () => t.interfaces().map(_.map(loop)), + possibleTypes = t.possibleTypes.map(_.map(loop)), + ofType = t.ofType.map(loop) + ) + ) + + loop(t) + } +} + +object TypeVisitor { + private case object Empty extends TypeVisitor { + override def |+|(that: TypeVisitor): TypeVisitor = that + override def visit(t: __Type): __Type = t + } + private case class Modify(f: __Type => __Type) extends TypeVisitor + private case class Combine(v1: TypeVisitor, v2: TypeVisitor) extends TypeVisitor + + val empty: TypeVisitor = Empty + def modify(f: __Type => __Type): TypeVisitor = Modify(f) + private[caliban] def modify[A](visitor: ListVisitor[A]): TypeVisitor = modify(t => visitor.visit(t)) + + object fields extends ListVisitorConstructors[__Field] { + val set: __Type => (List[__Field] => List[__Field]) => __Type = + t => f => t.copy(fields = args => t.fields(args).map(f)) + } + object inputFields extends ListVisitorConstructors[__InputValue] { + val set: __Type => (List[__InputValue] => List[__InputValue]) => __Type = + t => f => t.copy(inputFields = args => t.inputFields(args).map(f)) + } + object enumValues extends ListVisitorConstructors[__EnumValue] { + val set: __Type => (List[__EnumValue] => List[__EnumValue]) => __Type = + t => f => t.copy(enumValues = args => t.enumValues(args).map(f)) + } + object directives extends ListVisitorConstructors[Directive] { + val set: __Type => (List[Directive] => List[Directive]) => __Type = + t => f => t.copy(directives = t.directives.map(f)) + } +} + +private[caliban] sealed abstract class ListVisitor[A](implicit val set: __Type => (List[A] => List[A]) => __Type) { + self => + import ListVisitor._ + + def visit(t: __Type): __Type = + self match { + case Filter(predicate) => set(t)(_.filter(predicate(t))) + case Modify(f) => set(t)(_.map(f(t))) + case Add(f) => set(t)(f(t).foldLeft(_) { case (as, a) => a :: as }) + } +} + +private[caliban] object ListVisitor { + private case class Filter[A](predicate: __Type => A => Boolean)(implicit + set: __Type => (List[A] => List[A]) => __Type + ) extends ListVisitor[A] + private case class Modify[A](f: __Type => A => A)(implicit + set: __Type => (List[A] => List[A]) => __Type + ) extends ListVisitor[A] + private case class Add[A](f: __Type => List[A])(implicit + set: __Type => (List[A] => List[A]) => __Type + ) extends ListVisitor[A] + + def filter[A](predicate: (__Type, A) => Boolean)(implicit + set: __Type => (List[A] => List[A]) => __Type + ): TypeVisitor = + TypeVisitor.modify(Filter[A](t => field => predicate(t, field))) + def modify[A](f: (__Type, A) => A)(implicit set: __Type => (List[A] => List[A]) => __Type): TypeVisitor = + TypeVisitor.modify(Modify[A](t => field => f(t, field))) + def add[A](f: __Type => List[A])(implicit set: __Type => (List[A] => List[A]) => __Type): TypeVisitor = + TypeVisitor.modify(Add(f)) +} + +private[caliban] trait ListVisitorConstructors[A] { + implicit val set: __Type => (List[A] => List[A]) => __Type + + def filter(predicate: A => Boolean): TypeVisitor = filterWith((_, a) => predicate(a)) + def filterWith(predicate: (__Type, A) => Boolean): TypeVisitor = ListVisitor.filter(predicate) + def modify(f: A => A): TypeVisitor = modifyWith((_, a) => f(a)) + def modifyWith(f: (__Type, A) => A): TypeVisitor = ListVisitor.modify(f) + def add(list: List[A]): TypeVisitor = addWith(_ => list) + def addWith(f: __Type => List[A]): TypeVisitor = ListVisitor.add(f) +} diff --git a/core/src/main/scala/caliban/package.scala b/core/src/main/scala/caliban/package.scala index 75c0baaea..0109bab6b 100644 --- a/core/src/main/scala/caliban/package.scala +++ b/core/src/main/scala/caliban/package.scala @@ -6,6 +6,7 @@ import caliban.parsing.adt.{ Directive, Document } import caliban.rendering.DocumentRenderer import caliban.schema.Types.collectTypes import caliban.schema._ +import caliban.transformers.Transformer import caliban.wrappers.Wrapper package object caliban { @@ -26,7 +27,7 @@ package object caliban { mutationSchema: Schema[R, M], subscriptionSchema: Schema[R, S] ): GraphQL[R] = new GraphQL[R] { - val schemaBuilder: RootSchemaBuilder[R] = RootSchemaBuilder( + override protected val schemaBuilder: RootSchemaBuilder[R] = RootSchemaBuilder( resolver.queryResolver.map(r => Operation(querySchema.toType_(), querySchema.resolve(r))), resolver.mutationResolver.map(r => Operation(mutationSchema.toType_(), mutationSchema.resolve(r))), resolver.subscriptionResolver.map(r => @@ -35,9 +36,10 @@ package object caliban { schemaDirectives = schemaDirectives, schemaDescription = schemaDescription ) - val wrappers: List[Wrapper[R]] = Nil - val additionalDirectives: List[__Directive] = directives - val features: Set[Feature] = Set.empty + override protected val wrappers: List[Wrapper[R]] = Nil + override protected val additionalDirectives: List[__Directive] = directives + override protected val features: Set[Feature] = Set.empty + override protected val transformer: Transformer[R] = Transformer.empty } /** diff --git a/core/src/main/scala/caliban/schema/RootSchemaBuilder.scala b/core/src/main/scala/caliban/schema/RootSchemaBuilder.scala index 142be8c80..025d9885c 100644 --- a/core/src/main/scala/caliban/schema/RootSchemaBuilder.scala +++ b/core/src/main/scala/caliban/schema/RootSchemaBuilder.scala @@ -1,6 +1,6 @@ package caliban.schema -import caliban.introspection.adt.__Type +import caliban.introspection.adt.{ __Type, TypeVisitor } import caliban.parsing.adt.Directive import caliban.schema.Types.collectTypes @@ -32,4 +32,12 @@ case class RootSchemaBuilder[-R]( .flatMap(_._2.headOption) .toList } + + def visit(visitor: TypeVisitor): RootSchemaBuilder[R] = + copy( + query = query.map(query => query.copy(opType = visitor.visit(query.opType))), + mutation = mutation.map(mutation => mutation.copy(opType = visitor.visit(mutation.opType))), + subscription = subscription.map(subscription => subscription.copy(opType = visitor.visit(subscription.opType))), + additionalTypes = additionalTypes.map(visitor.visit) + ) } diff --git a/core/src/main/scala/caliban/transformers/Transformer.scala b/core/src/main/scala/caliban/transformers/Transformer.scala new file mode 100644 index 000000000..e85669296 --- /dev/null +++ b/core/src/main/scala/caliban/transformers/Transformer.scala @@ -0,0 +1,207 @@ +package caliban.transformers + +import caliban.InputValue +import caliban.execution.Field +import caliban.introspection.adt._ +import caliban.schema.Step +import caliban.schema.Step.{ FunctionStep, MetadataFunctionStep, NullStep, ObjectStep } + +import scala.collection.mutable + +/** + * A transformer is able to modify a type, modifying its schema and the way it is resolved. + */ +abstract class Transformer[-R] { self => + val typeVisitor: TypeVisitor + + def transformStep[R1 <: R]: PartialFunction[(Step[R1], Field), Step[R1]] + + def |+|[R0 <: R](that: Transformer[R0]): Transformer[R0] = new Transformer[R0] { + val typeVisitor: TypeVisitor = self.typeVisitor |+| that.typeVisitor + + def transformStep[R1 <: R0]: PartialFunction[(Step[R1], Field), Step[R1]] = { case (step, field) => + val modifiedStep = self.transformStep.lift((step, field)).getOrElse(step) + that.transformStep.lift((modifiedStep, field)).getOrElse(modifiedStep) + } + } + + final def isEmpty: Boolean = self eq Transformer.Empty +} + +object Transformer { + + /** + * A transformer that does nothing. + */ + def empty[R]: Transformer[R] = Empty + + /** + * A transformer that does nothing. + */ + case object Empty extends Transformer[Any] { + val typeVisitor: TypeVisitor = TypeVisitor.empty + + def transformStep[R1]: PartialFunction[(Step[R1], Field), Step[R1]] = PartialFunction.empty + + override def |+|[R0](that: Transformer[R0]): Transformer[R0] = new Transformer[R0] { + val typeVisitor: TypeVisitor = that.typeVisitor + def transformStep[R1 <: R0]: PartialFunction[(Step[R1], Field), Step[R1]] = that.transformStep + } + } + + /** + * A transformer that allows renaming types. + * @param f a partial function that takes a type name and returns a new name for that type + */ + case class RenameType(f: PartialFunction[String, String]) extends Transformer[Any] { + private def rename(name: String): String = f.lift(name).getOrElse(name) + + val typeVisitor: TypeVisitor = + TypeVisitor.modify(t => t.copy(name = t.name.map(rename))) |+| + TypeVisitor.enumValues.modify(v => v.copy(name = rename(v.name))) + + def transformStep[R]: PartialFunction[(Step[R], Field), Step[R]] = { case (step @ ObjectStep(name, fields), _) => + f.lift(name).map(ObjectStep(_, fields)).getOrElse(step) + } + } + + /** + * A transformer that allows renaming fields. + * @param f a partial function that takes a type name and a field name and returns a new name for that field + */ + case class RenameField(f: ((String, String), String)*) extends Transformer[Any] { + private val map = mutable.HashMap.from(f) + private val inverseMap = map.map { case ((tName, fName0), fName1) => (tName, fName1) -> fName0 } + + val typeVisitor: TypeVisitor = + TypeVisitor.fields.modifyWith((t, field) => + field.copy(name = map.getOrElse((t.name.getOrElse(""), field.name), field.name)) + ) |+| + TypeVisitor.inputFields.modifyWith((t, field) => + field.copy(name = map.getOrElse((t.name.getOrElse(""), field.name), field.name)) + ) + + def transformStep[R]: PartialFunction[(Step[R], Field), Step[R]] = { case (ObjectStep(typeName, fields), _) => + ObjectStep( + typeName, + fieldName => fields(inverseMap.getOrElse((typeName, fieldName), fieldName)) + ) + } + } + + /** + * A transformer that allows renaming arguments. + * @param f a partial function that takes a type name and a field name and returns another + * partial function from an argument name to a new name for that argument + */ + case class RenameArgument( + f: PartialFunction[(String, String), (PartialFunction[String, String], PartialFunction[String, String])] + ) extends Transformer[Any] { + val typeVisitor: TypeVisitor = + TypeVisitor.fields.modifyWith((t, field) => + f.lift((t.name.getOrElse(""), field.name)) match { + case Some((rename, _)) => + field.copy(args = + field + .args(_) + .map(arg => + rename + .lift(arg.name) + .fold(arg)(newName => arg.copy(name = newName)) + ) + ) + case None => field + } + ) + + def transformStep[R]: PartialFunction[(Step[R], Field), Step[R]] = { case (ObjectStep(typeName, fields), _) => + ObjectStep( + typeName, + fieldName => { + val step = fields(fieldName) + f.lift((typeName, fieldName)) match { + case Some((_, rename)) => + mapFunctionStep(step)(_.map { case (argName, input) => rename.lift(argName).getOrElse(argName) -> input }) + case _ => step + } + } + ) + } + } + + /** + * A transformer that allows filtering fields. + * @param f a partial function that takes a type name and a field name and + * returns a boolean (true means the field should be kept) + */ + case class FilterField(f: PartialFunction[(String, String), Boolean]) extends Transformer[Any] { + val typeVisitor: TypeVisitor = + TypeVisitor.fields.filterWith((t, field) => f.lift((t.name.getOrElse(""), field.name)).getOrElse(true)) |+| + TypeVisitor.inputFields.filterWith((t, field) => f.lift((t.name.getOrElse(""), field.name)).getOrElse(true)) + + private val fnTrue = (_: (String, String)) => true + + def transformStep[R]: PartialFunction[(Step[R], Field), Step[R]] = { case (ObjectStep(typeName, fields), _) => + ObjectStep( + typeName, + fieldName => + if (f.applyOrElse((typeName, fieldName), fnTrue)) fields(fieldName) + else NullStep + ) + } + } + + /** + * A transformer that allows filtering types. + * + * @param f a partial function that takes a type name and an interface name and + * returns a boolean (true means the type should be kept) + */ + case class FilterInterface(f: PartialFunction[(String, String), Boolean]) extends Transformer[Any] { + val typeVisitor: TypeVisitor = + TypeVisitor.modify(t => + t.copy(interfaces = + () => + t.interfaces() + .map(_.filter(interface => f.lift((t.name.getOrElse(""), interface.name.getOrElse(""))).getOrElse(true))) + ) + ) + + def transformStep[R]: PartialFunction[(Step[R], Field), Step[R]] = { case (step, _) => step } + } + + /** + * A transformer that allows filtering arguments. + * @param f a partial function that takes a type name, a field name and an argument name and + * returns a boolean (true means the argument should be kept) + */ + case class FilterArgument(f: PartialFunction[(String, String, String), Boolean]) extends Transformer[Any] { + val typeVisitor: TypeVisitor = + TypeVisitor.fields.modifyWith((t, field) => + field.copy(args = + field.args(_).filter(arg => f.lift((t.name.getOrElse(""), field.name, arg.name)).getOrElse(true)) + ) + ) + + def transformStep[R]: PartialFunction[(Step[R], Field), Step[R]] = { case (ObjectStep(typeName, fields), _) => + ObjectStep( + typeName, + fieldName => + mapFunctionStep(fields(fieldName))(_.filter { case (argName, _) => + f.lift((typeName, fieldName, argName)).getOrElse(true) + }) + ) + } + } + + private def mapFunctionStep[R](step: Step[R])(f: Map[String, InputValue] => Map[String, InputValue]): Step[R] = + step match { + case FunctionStep(mapToStep) => FunctionStep(args => mapToStep(f(args))) + case MetadataFunctionStep(m) => + MetadataFunctionStep(m(_) match { + case FunctionStep(mapToStep) => FunctionStep(args => mapToStep(f(args))) + case other => other + }) + case other => other + } +} diff --git a/core/src/test/scala/caliban/transformers/TransformerSpec.scala b/core/src/test/scala/caliban/transformers/TransformerSpec.scala new file mode 100644 index 000000000..3396b994e --- /dev/null +++ b/core/src/test/scala/caliban/transformers/TransformerSpec.scala @@ -0,0 +1,154 @@ +package caliban.transformers + +import caliban._ +import caliban.schema.ArgBuilder.auto._ +import caliban.schema.Schema.auto._ +import zio.test._ + +object TransformerSpec extends ZIOSpecDefault { + + case class Args(arg: String) + case class InnerObject(b: Args => String) + case class Query(a: InnerObject) + + val api: GraphQL[Any] = graphQL(RootResolver(Query(InnerObject(_.arg)))) + + override def spec = + suite("TransformerSpec")( + test("rename type") { + val transformed: GraphQL[Any] = api.transform(Transformer.RenameType { case "InnerObject" => "Renamed" }) + val rendered = transformed.render + for { + interpreter <- transformed.interpreter + result <- interpreter.execute("""{ a { b(arg: "hello") } }""").map(_.data.toString) + } yield assertTrue( + result == """{"a":{"b":"hello"}}""", + rendered == """schema { + | query: Query + |} + | + |type Query { + | a: Renamed! + |} + | + |type Renamed { + | b(arg: String!): String! + |}""".stripMargin + ) + }, + test("rename field") { + val transformed: GraphQL[Any] = api.transform(Transformer.RenameField(("InnerObject", "b") -> "c")) + val rendered = transformed.render + for { + interpreter <- transformed.interpreter + result <- interpreter.execute("""{ a { c(arg: "hello") } }""").map(_.data.toString) + } yield assertTrue( + result == """{"a":{"c":"hello"}}""", + rendered == + """schema { + | query: Query + |} + | + |type InnerObject { + | c(arg: String!): String! + |} + | + |type Query { + | a: InnerObject! + |}""".stripMargin + ) + }, + test("rename argument") { + val transformed: GraphQL[Any] = api.transform(Transformer.RenameArgument { case ("InnerObject", "b") => + ({ case "arg" => "arg2" }, { case "arg2" => "arg" }) + }) + val rendered = transformed.render + for { + interpreter <- transformed.interpreter + result <- interpreter.execute("""{ a { b(arg2: "hello") } }""").map(_.data.toString) + } yield assertTrue( + result == """{"a":{"b":"hello"}}""", + rendered == + """schema { + | query: Query + |} + | + |type InnerObject { + | b(arg2: String!): String! + |} + | + |type Query { + | a: InnerObject! + |}""".stripMargin + ) + }, + test("filter field") { + case class Query(a: String, b: Int) + val api: GraphQL[Any] = graphQL(RootResolver(Query("a", 2))) + + val transformed: GraphQL[Any] = api.transform(Transformer.FilterField { case ("Query", "b") => false }) + val rendered = transformed.render + for { + interpreter <- transformed.interpreter + result <- interpreter.execute("""{ a }""").map(_.data.toString) + } yield assertTrue( + result == """{"a":"a"}""", + rendered == + """schema { + | query: Query + |} + | + |type Query { + | a: String! + |}""".stripMargin + ) + }, + test("filter argument") { + case class Args(arg: Option[String]) + case class Query(a: Args => String) + val api: GraphQL[Any] = graphQL(RootResolver(Query(_.arg.getOrElse("missing")))) + + val transformed: GraphQL[Any] = + api.transform(Transformer.FilterArgument { case ("Query", "a", "arg") => false }) + val rendered = transformed.render + for { + interpreter <- transformed.interpreter + result <- interpreter.execute("""{ a }""").map(_.data.toString) + } yield assertTrue( + result == """{"a":"missing"}""", + rendered == + """schema { + | query: Query + |} + | + |type Query { + | a: String! + |}""".stripMargin + ) + }, + test("combine transformers") { + val transformed: GraphQL[Any] = api + .transform(Transformer.RenameType { case "InnerObject" => "Renamed" }) + .transform(Transformer.RenameField(("Renamed", "b") -> "c")) + val rendered = transformed.render + for { + interpreter <- transformed.interpreter + result <- interpreter.execute("""{ a { c(arg: "hello") } }""").map(_.data.toString) + } yield assertTrue( + result == """{"a":{"c":"hello"}}""", + rendered == + """schema { + | query: Query + |} + | + |type Query { + | a: Renamed! + |} + | + |type Renamed { + | c(arg: String!): String! + |}""".stripMargin + ) + } + ) +} diff --git a/interop/tapir/src/main/scala/caliban/interop/tapir/package.scala b/interop/tapir/src/main/scala/caliban/interop/tapir/package.scala index b7e92669b..39cfb621b 100644 --- a/interop/tapir/src/main/scala/caliban/interop/tapir/package.scala +++ b/interop/tapir/src/main/scala/caliban/interop/tapir/package.scala @@ -12,6 +12,7 @@ import sttp.tapir.server.ServerEndpoint import sttp.tapir.{ EndpointIO, EndpointInput, EndpointOutput, PublicEndpoint } import _root_.zio.query.{ URQuery, ZQuery } import _root_.zio.{ URIO, ZIO } +import caliban.transformers.Transformer package object tapir { @@ -148,6 +149,7 @@ package object tapir { override protected val wrappers: List[Wrapper[R]] = Nil override protected val additionalDirectives: List[__Directive] = Nil override protected val features = Set.empty + override protected val transformer: Transformer[R] = Transformer.empty } private def extractPath[I](endpointName: Option[String], input: EndpointInput[I]): String = diff --git a/tools/src/main/scala/caliban/tools/stitching/RemoteSchemaResolver.scala b/tools/src/main/scala/caliban/tools/stitching/RemoteSchemaResolver.scala index 3e35e335f..fd371ef52 100644 --- a/tools/src/main/scala/caliban/tools/stitching/RemoteSchemaResolver.scala +++ b/tools/src/main/scala/caliban/tools/stitching/RemoteSchemaResolver.scala @@ -4,6 +4,7 @@ import caliban.CalibanError.ExecutionError import caliban.execution.{ Feature, Field } import caliban.introspection.adt._ import caliban.schema._ +import caliban.transformers.Transformer import caliban.{ CalibanError, GraphQL, ResponseValue } import zio._ import zio.query._ @@ -50,10 +51,11 @@ case class RemoteSchemaResolver(schema: __Schema, typeMap: Map[String, __Type]) ) new GraphQL[R] { - protected val additionalDirectives: List[__Directive] = schema.directives - protected val schemaBuilder: caliban.schema.RootSchemaBuilder[R] = builder - protected val wrappers: List[caliban.wrappers.Wrapper[R]] = List() - protected val features: Set[Feature] = Set.empty + override protected val additionalDirectives: List[__Directive] = schema.directives + override protected val schemaBuilder: caliban.schema.RootSchemaBuilder[R] = builder + override protected val wrappers: List[caliban.wrappers.Wrapper[R]] = List() + override protected val features: Set[Feature] = Set.empty + override protected val transformer: Transformer[R] = Transformer.empty } } } diff --git a/tools/src/test/scala/caliban/tools/RemoteSchemaSpec.scala b/tools/src/test/scala/caliban/tools/RemoteSchemaSpec.scala index d58e2c2c1..47f21f5f5 100644 --- a/tools/src/test/scala/caliban/tools/RemoteSchemaSpec.scala +++ b/tools/src/test/scala/caliban/tools/RemoteSchemaSpec.scala @@ -10,6 +10,7 @@ import zio.test._ import schema.Annotations._ import caliban.Macros.gqldoc import caliban.execution.Feature +import caliban.transformers.Transformer object RemoteSchemaSpec extends ZIOSpecDefault { sealed trait EnumType extends Product with Serializable @@ -98,7 +99,7 @@ object RemoteSchemaSpec extends ZIOSpecDefault { def fromRemoteSchema(s: __Schema): GraphQL[Any] = new GraphQL[Any] { - protected val schemaBuilder = + override protected val schemaBuilder = RootSchemaBuilder( query = Some( Operation[Any]( @@ -109,9 +110,10 @@ object RemoteSchemaSpec extends ZIOSpecDefault { mutation = None, subscription = None ) - protected val additionalDirectives: List[__Directive] = List() - protected val wrappers: List[caliban.wrappers.Wrapper[Any]] = List() - override protected val features: Set[Feature] = Set.empty + override protected val additionalDirectives: List[__Directive] = List() + override protected val wrappers: List[caliban.wrappers.Wrapper[Any]] = List() + override protected val features: Set[Feature] = Set.empty + override protected val transformer: Transformer[Any] = Transformer.empty } } From c6f3f47993e55f8c1a2882f81fe2a86dbef9ef97 Mon Sep 17 00:00:00 2001 From: Kyri Petrou Date: Tue, 20 Feb 2024 09:12:08 +1100 Subject: [PATCH 2/6] Optimize transformers --- .../scala/caliban/execution/Executor.scala | 4 +- .../caliban/transformers/Transformer.scala | 197 +++++++++++------- .../transformers/TransformerSpec.scala | 17 +- 3 files changed, 136 insertions(+), 82 deletions(-) diff --git a/core/src/main/scala/caliban/execution/Executor.scala b/core/src/main/scala/caliban/execution/Executor.scala index 985265b13..5ad76d4c7 100644 --- a/core/src/main/scala/caliban/execution/Executor.scala +++ b/core/src/main/scala/caliban/execution/Executor.scala @@ -140,9 +140,7 @@ object Executor { try step catch { case NonFatal(e) => Step.fail(e) } - val step0 = - if (isEmptyTransformer) step - else transformer.transformStep.lift((step, currentField)).getOrElse(step) + val step0 = if (isEmptyTransformer) step else transformer.transformStep(step, currentField) step0 match { case s @ PureStep(EnumValue(v)) => diff --git a/core/src/main/scala/caliban/transformers/Transformer.scala b/core/src/main/scala/caliban/transformers/Transformer.scala index e85669296..727090581 100644 --- a/core/src/main/scala/caliban/transformers/Transformer.scala +++ b/core/src/main/scala/caliban/transformers/Transformer.scala @@ -14,15 +14,13 @@ import scala.collection.mutable abstract class Transformer[-R] { self => val typeVisitor: TypeVisitor - def transformStep[R1 <: R]: PartialFunction[(Step[R1], Field), Step[R1]] + def transformStep[R1 <: R](step: Step[R1], field: Field): Step[R1] def |+|[R0 <: R](that: Transformer[R0]): Transformer[R0] = new Transformer[R0] { val typeVisitor: TypeVisitor = self.typeVisitor |+| that.typeVisitor - def transformStep[R1 <: R0]: PartialFunction[(Step[R1], Field), Step[R1]] = { case (step, field) => - val modifiedStep = self.transformStep.lift((step, field)).getOrElse(step) - that.transformStep.lift((modifiedStep, field)).getOrElse(modifiedStep) - } + def transformStep[R1 <: R0](step: Step[R1], field: Field): Step[R1] = + that.transformStep(self.transformStep(step, field), field) } final def isEmpty: Boolean = self eq Transformer.Empty @@ -41,11 +39,11 @@ object Transformer { case object Empty extends Transformer[Any] { val typeVisitor: TypeVisitor = TypeVisitor.empty - def transformStep[R1]: PartialFunction[(Step[R1], Field), Step[R1]] = PartialFunction.empty + def transformStep[R1](step: Step[R1], field: Field): Step[R1] = step override def |+|[R0](that: Transformer[R0]): Transformer[R0] = new Transformer[R0] { - val typeVisitor: TypeVisitor = that.typeVisitor - def transformStep[R1 <: R0]: PartialFunction[(Step[R1], Field), Step[R1]] = that.transformStep + val typeVisitor: TypeVisitor = that.typeVisitor + def transformStep[R1 <: R0](step: Step[R1], field: Field): Step[R1] = that.transformStep(step, field) } } @@ -53,15 +51,20 @@ object Transformer { * A transformer that allows renaming types. * @param f a partial function that takes a type name and returns a new name for that type */ - case class RenameType(f: PartialFunction[String, String]) extends Transformer[Any] { - private def rename(name: String): String = f.lift(name).getOrElse(name) + case class RenameType(private val f: (String, String)*) extends Transformer[Any] { + private val map = f.toMap + + private def rename(name: String): String = map.getOrElse(name, name) val typeVisitor: TypeVisitor = TypeVisitor.modify(t => t.copy(name = t.name.map(rename))) |+| TypeVisitor.enumValues.modify(v => v.copy(name = rename(v.name))) - def transformStep[R]: PartialFunction[(Step[R], Field), Step[R]] = { case (step @ ObjectStep(name, fields), _) => - f.lift(name).map(ObjectStep(_, fields)).getOrElse(step) + def transformStep[R](step: Step[R], field: Field): Step[R] = step match { + case ObjectStep(typeName, fields) => + val res = map.get(typeName) + if (res.isEmpty) step else ObjectStep(res.get, fields) + case _ => step } } @@ -69,23 +72,24 @@ object Transformer { * A transformer that allows renaming fields. * @param f a partial function that takes a type name and a field name and returns a new name for that field */ - case class RenameField(f: ((String, String), String)*) extends Transformer[Any] { - private val map = mutable.HashMap.from(f) - private val inverseMap = map.map { case ((tName, fName0), fName1) => (tName, fName1) -> fName0 } + case class RenameField(private val f: (String, (String, String))*) extends Transformer[Any] { + private val visitorMap = toMap2(f) + private val transformMap = swapMap2(visitorMap) - val typeVisitor: TypeVisitor = - TypeVisitor.fields.modifyWith((t, field) => - field.copy(name = map.getOrElse((t.name.getOrElse(""), field.name), field.name)) - ) |+| - TypeVisitor.inputFields.modifyWith((t, field) => - field.copy(name = map.getOrElse((t.name.getOrElse(""), field.name), field.name)) - ) + val typeVisitor: TypeVisitor = { + def _get(t: __Type, name: String) = getFromMap2(visitorMap, name)(t.name.getOrElse(""), name) - def transformStep[R]: PartialFunction[(Step[R], Field), Step[R]] = { case (ObjectStep(typeName, fields), _) => - ObjectStep( - typeName, - fieldName => fields(inverseMap.getOrElse((typeName, fieldName), fieldName)) - ) + TypeVisitor.fields.modifyWith((t, field) => field.copy(name = _get(t, field.name))) |+| + TypeVisitor.inputFields.modifyWith((t, field) => field.copy(name = _get(t, field.name))) + } + + def transformStep[R](step: Step[R], field: Field): Step[R] = step match { + case ObjectStep(typeName, fields) => + ObjectStep( + typeName, + fieldName => fields(getFromMap2(transformMap, fieldName)(typeName, fieldName)) + ) + case _ => step } } @@ -94,38 +98,42 @@ object Transformer { * @param f a partial function that takes a type name and a field name and returns another * partial function from an argument name to a new name for that argument */ - case class RenameArgument( - f: PartialFunction[(String, String), (PartialFunction[String, String], PartialFunction[String, String])] - ) extends Transformer[Any] { + case class RenameArgument(private val f: (String, (String, (String, String)))*) extends Transformer[Any] { + private val visitorMap = toMap3(f) + private val transformMap = swapMap3(visitorMap) + val typeVisitor: TypeVisitor = TypeVisitor.fields.modifyWith((t, field) => - f.lift((t.name.getOrElse(""), field.name)) match { - case Some((rename, _)) => + visitorMap.get(t.name.getOrElse("")).flatMap(_.get(field.name)) match { + case Some(rename) => field.copy(args = field .args(_) .map(arg => rename - .lift(arg.name) + .get(arg.name) .fold(arg)(newName => arg.copy(name = newName)) ) ) - case None => field + case None => field } ) - def transformStep[R]: PartialFunction[(Step[R], Field), Step[R]] = { case (ObjectStep(typeName, fields), _) => - ObjectStep( - typeName, - fieldName => { - val step = fields(fieldName) - f.lift((typeName, fieldName)) match { - case Some((_, rename)) => - mapFunctionStep(step)(_.map { case (argName, input) => rename.lift(argName).getOrElse(argName) -> input }) - case _ => step + def transformStep[R](step: Step[R], field: Field): Step[R] = step match { + case ObjectStep(typeName, fields) => + ObjectStep( + typeName, + fieldName => { + val step = fields(fieldName) + val rename = transformMap.get(typeName).flatMap(_.get(fieldName)) + if (rename.isEmpty) step + else + mapFunctionStep(step)(_.map { case (argName, input) => + rename.get.getOrElse(argName, argName) -> input + }) } - } - ) + ) + case _ => step } } @@ -134,20 +142,22 @@ object Transformer { * @param f a partial function that takes a type name and a field name and * returns a boolean (true means the field should be kept) */ - case class FilterField(f: PartialFunction[(String, String), Boolean]) extends Transformer[Any] { - val typeVisitor: TypeVisitor = - TypeVisitor.fields.filterWith((t, field) => f.lift((t.name.getOrElse(""), field.name)).getOrElse(true)) |+| - TypeVisitor.inputFields.filterWith((t, field) => f.lift((t.name.getOrElse(""), field.name)).getOrElse(true)) + case class FilterField(private val f: (String, (String, Boolean))*) extends Transformer[Any] { + private val map = toMap2[Boolean](f) - private val fnTrue = (_: (String, String)) => true + val typeVisitor: TypeVisitor = { + val _get = getFromMap2(map) _ + TypeVisitor.fields.filterWith((t, field) => _get(t.name.getOrElse(""), field.name)) |+| + TypeVisitor.inputFields.filterWith((t, field) => _get(t.name.getOrElse(""), field.name)) + } - def transformStep[R]: PartialFunction[(Step[R], Field), Step[R]] = { case (ObjectStep(typeName, fields), _) => - ObjectStep( - typeName, - fieldName => - if (f.applyOrElse((typeName, fieldName), fnTrue)) fields(fieldName) - else NullStep - ) + def transformStep[R](step: Step[R], field: Field): Step[R] = step match { + case ObjectStep(typeName, fields) => + ObjectStep( + typeName, + fieldName => if (getFromMap2(map, default = true)(typeName, fieldName)) fields(fieldName) else NullStep + ) + case _ => step } } @@ -157,17 +167,21 @@ object Transformer { * @param f a partial function that takes a type name and an interface name and * returns a boolean (true means the type should be kept) */ - case class FilterInterface(f: PartialFunction[(String, String), Boolean]) extends Transformer[Any] { + case class FilterInterface(private val f: (String, (String, Boolean))*) extends Transformer[Any] { + private val map = toMap2(f) + val typeVisitor: TypeVisitor = TypeVisitor.modify(t => t.copy(interfaces = () => t.interfaces() - .map(_.filter(interface => f.lift((t.name.getOrElse(""), interface.name.getOrElse(""))).getOrElse(true))) + .map( + _.filter(interface => getFromMap2(map)(t.name.getOrElse(""), interface.name.getOrElse(""))) + ) ) ) - def transformStep[R]: PartialFunction[(Step[R], Field), Step[R]] = { case (step, _) => step } + def transformStep[R](step: Step[R], field: Field): Step[R] = step } /** @@ -175,22 +189,28 @@ object Transformer { * @param f a partial function that takes a type name, a field name and an argument name and * returns a boolean (true means the argument should be kept) */ - case class FilterArgument(f: PartialFunction[(String, String, String), Boolean]) extends Transformer[Any] { + case class FilterArgument(private val f: (String, (String, (String, Boolean)))*) extends Transformer[Any] { + private val map = toMap3(f) + val typeVisitor: TypeVisitor = TypeVisitor.fields.modifyWith((t, field) => field.copy(args = - field.args(_).filter(arg => f.lift((t.name.getOrElse(""), field.name, arg.name)).getOrElse(true)) + field + .args(_) + .filter(arg => getFromMap3(map)(t.name.getOrElse(""), field.name, arg.name)) ) ) - def transformStep[R]: PartialFunction[(Step[R], Field), Step[R]] = { case (ObjectStep(typeName, fields), _) => - ObjectStep( - typeName, - fieldName => - mapFunctionStep(fields(fieldName))(_.filter { case (argName, _) => - f.lift((typeName, fieldName, argName)).getOrElse(true) - }) - ) + def transformStep[R](step: Step[R], field: Field): Step[R] = step match { + case ObjectStep(typeName, fields) => + ObjectStep( + typeName, + fieldName => + mapFunctionStep(fields(fieldName))(_.filter { case (argName, _) => + getFromMap3(map)(typeName, fieldName, argName) + }) + ) + case _ => step } } @@ -204,4 +224,41 @@ object Transformer { }) case other => other } + + private def toMap2[V](t: Seq[(String, (String, V))]): Map[String, Map[String, V]] = + t.groupMap(_._1)(_._2).transform { case (_, l) => l.toMap } + + private def toMap3[V]( + t: Seq[(String, (String, (String, V)))] + ): Map[String, Map[String, Map[String, V]]] = + t.groupMap(_._1)(_._2).transform { case (_, l) => l.groupMap(_._1)(_._2).transform { case (_, l) => l.toMap } } + + private def swapMap2[V](m: Map[String, Map[String, V]]): Map[String, Map[V, String]] = + m.transform { case (_, m) => m.map(_.swap) } + + private def swapMap3[V](m: Map[String, Map[String, Map[String, V]]]): Map[String, Map[String, Map[V, String]]] = + m.transform { case (_, m) => m.transform { case (_, m) => m.map(_.swap) } } + + private def getFromMap2( + m: Map[String, Map[String, String]], + default: => String + )(k1: String, k2: String): String = + m.get(k1).flatMap(_.get(k2)).getOrElse(default) + + // Overloading to avoid boxing of Boolean + private def getFromMap2( + m: Map[String, Map[String, Boolean]], + default: Boolean = true + )(k1: String, k2: String): Boolean = { + val res = m.get(k1).flatMap(_.get(k2)) + if (res.isEmpty) default else res.get + } + + private def getFromMap3( + m: Map[String, Map[String, Map[String, Boolean]]], + default: Boolean = true + )(k1: String, k2: String, k3: String): Boolean = { + val res = m.get(k1).flatMap(_.get(k2)).flatMap(_.get(k3)) + if (res.isEmpty) default else res.get + } } diff --git a/core/src/test/scala/caliban/transformers/TransformerSpec.scala b/core/src/test/scala/caliban/transformers/TransformerSpec.scala index 3396b994e..07683cee0 100644 --- a/core/src/test/scala/caliban/transformers/TransformerSpec.scala +++ b/core/src/test/scala/caliban/transformers/TransformerSpec.scala @@ -16,7 +16,7 @@ object TransformerSpec extends ZIOSpecDefault { override def spec = suite("TransformerSpec")( test("rename type") { - val transformed: GraphQL[Any] = api.transform(Transformer.RenameType { case "InnerObject" => "Renamed" }) + val transformed: GraphQL[Any] = api.transform(Transformer.RenameType("InnerObject" -> "Renamed")) val rendered = transformed.render for { interpreter <- transformed.interpreter @@ -37,7 +37,7 @@ object TransformerSpec extends ZIOSpecDefault { ) }, test("rename field") { - val transformed: GraphQL[Any] = api.transform(Transformer.RenameField(("InnerObject", "b") -> "c")) + val transformed: GraphQL[Any] = api.transform(Transformer.RenameField("InnerObject" -> ("b" -> "c"))) val rendered = transformed.render for { interpreter <- transformed.interpreter @@ -59,8 +59,8 @@ object TransformerSpec extends ZIOSpecDefault { ) }, test("rename argument") { - val transformed: GraphQL[Any] = api.transform(Transformer.RenameArgument { case ("InnerObject", "b") => - ({ case "arg" => "arg2" }, { case "arg2" => "arg" }) + val transformed: GraphQL[Any] = api.transform(Transformer.RenameArgument { + "InnerObject" -> ("b" -> ("arg" -> "arg2")) }) val rendered = transformed.render for { @@ -86,7 +86,7 @@ object TransformerSpec extends ZIOSpecDefault { case class Query(a: String, b: Int) val api: GraphQL[Any] = graphQL(RootResolver(Query("a", 2))) - val transformed: GraphQL[Any] = api.transform(Transformer.FilterField { case ("Query", "b") => false }) + val transformed: GraphQL[Any] = api.transform(Transformer.FilterField("Query" -> ("b" -> false))) val rendered = transformed.render for { interpreter <- transformed.interpreter @@ -108,8 +108,7 @@ object TransformerSpec extends ZIOSpecDefault { case class Query(a: Args => String) val api: GraphQL[Any] = graphQL(RootResolver(Query(_.arg.getOrElse("missing")))) - val transformed: GraphQL[Any] = - api.transform(Transformer.FilterArgument { case ("Query", "a", "arg") => false }) + val transformed: GraphQL[Any] = api.transform(Transformer.FilterArgument("Query" -> ("a" -> ("arg" -> false)))) val rendered = transformed.render for { interpreter <- transformed.interpreter @@ -128,8 +127,8 @@ object TransformerSpec extends ZIOSpecDefault { }, test("combine transformers") { val transformed: GraphQL[Any] = api - .transform(Transformer.RenameType { case "InnerObject" => "Renamed" }) - .transform(Transformer.RenameField(("Renamed", "b") -> "c")) + .transform(Transformer.RenameType("InnerObject" -> "Renamed")) + .transform(Transformer.RenameField("Renamed" -> ("b" -> "c"))) val rendered = transformed.render for { interpreter <- transformed.interpreter From ede9f5dc805d0cf05d15b7f374ebb1b6ef4048e2 Mon Sep 17 00:00:00 2001 From: Kyri Petrou Date: Wed, 8 May 2024 15:39:08 +1000 Subject: [PATCH 3/6] Optimize and cleanup API --- .../caliban/transformers/Transformer.scala | 319 ++++++++++-------- .../transformers/TransformerSpec.scala | 10 +- 2 files changed, 182 insertions(+), 147 deletions(-) diff --git a/core/src/main/scala/caliban/transformers/Transformer.scala b/core/src/main/scala/caliban/transformers/Transformer.scala index 727090581..7a5b6bb67 100644 --- a/core/src/main/scala/caliban/transformers/Transformer.scala +++ b/core/src/main/scala/caliban/transformers/Transformer.scala @@ -6,8 +6,6 @@ import caliban.introspection.adt._ import caliban.schema.Step import caliban.schema.Step.{ FunctionStep, MetadataFunctionStep, NullStep, ObjectStep } -import scala.collection.mutable - /** * A transformer is able to modify a type, modifying its schema and the way it is resolved. */ @@ -33,74 +31,117 @@ object Transformer { */ def empty[R]: Transformer[R] = Empty - /** - * A transformer that does nothing. - */ - case object Empty extends Transformer[Any] { + private case object Empty extends Transformer[Any] { val typeVisitor: TypeVisitor = TypeVisitor.empty def transformStep[R1](step: Step[R1], field: Field): Step[R1] = step + override def |+|[R0](that: Transformer[R0]): Transformer[R0] = that + } - override def |+|[R0](that: Transformer[R0]): Transformer[R0] = new Transformer[R0] { - val typeVisitor: TypeVisitor = that.typeVisitor - def transformStep[R1 <: R0](step: Step[R1], field: Field): Step[R1] = that.transformStep(step, field) - } + object RenameType { + + /** + * A transformer that allows renaming types. + * {{{ + * RenameType( + * "Foo" -> "Bar", + * "Baz" -> "Qux" + * ) + * }}} + * @param f tuples in the format of `(OldName -> NewName)` + */ + def apply(f: (String, String)*): Transformer[Any] = + new RenameType(f.toMap) } - /** - * A transformer that allows renaming types. - * @param f a partial function that takes a type name and returns a new name for that type - */ - case class RenameType(private val f: (String, String)*) extends Transformer[Any] { - private val map = f.toMap + final private class RenameType(map: Map[String, String]) extends Transformer[Any] { + + private def renameType(t: __Type) = + t.name.flatMap(map.get).fold(t)(newName => t.copy(name = Some(newName))) - private def rename(name: String): String = map.getOrElse(name, name) + private def renameEnum(t: __EnumValue) = + map.get(t.name).fold(t)(newName => t.copy(name = newName)) val typeVisitor: TypeVisitor = - TypeVisitor.modify(t => t.copy(name = t.name.map(rename))) |+| - TypeVisitor.enumValues.modify(v => v.copy(name = rename(v.name))) + TypeVisitor.modify(renameType(_)) |+| TypeVisitor.enumValues.modify(renameEnum) def transformStep[R](step: Step[R], field: Field): Step[R] = step match { - case ObjectStep(typeName, fields) => - val res = map.get(typeName) - if (res.isEmpty) step else ObjectStep(res.get, fields) - case _ => step + case step @ ObjectStep(typeName, _) => + map.getOrElse(typeName, null) match { + case null => step + case newName => step.copy(name = newName) + } + case _ => step } } - /** - * A transformer that allows renaming fields. - * @param f a partial function that takes a type name and a field name and returns a new name for that field - */ - case class RenameField(private val f: (String, (String, String))*) extends Transformer[Any] { - private val visitorMap = toMap2(f) + object RenameField { + + /** + * A transformer that allows renaming fields on types + * + * {{{ + * RenameField( + * "TypeA" -> "foo" -> "bar", + * "TypeB" -> "baz" -> "qux", + * ) + * }}} + * + * @param f tuples in the format of `(TypeName -> oldName -> newName)` + */ + + def apply(f: ((String, String), String)*): Transformer[Any] = + new RenameField(tuplesToMap2(f: _*)) + } + + final private class RenameField(visitorMap: Map[String, Map[String, String]]) extends Transformer[Any] { private val transformMap = swapMap2(visitorMap) - val typeVisitor: TypeVisitor = { - def _get(t: __Type, name: String) = getFromMap2(visitorMap, name)(t.name.getOrElse(""), name) + private def renameField(t: __Type, field: __Field) = { + val newName = getFromMap2(visitorMap, null)(t.name.getOrElse(""), field.name) + if (newName eq null) field else field.copy(name = newName) + } - TypeVisitor.fields.modifyWith((t, field) => field.copy(name = _get(t, field.name))) |+| - TypeVisitor.inputFields.modifyWith((t, field) => field.copy(name = _get(t, field.name))) + private def renameInputField(t: __Type, input: __InputValue) = { + val newName = getFromMap2(visitorMap, null)(t.name.getOrElse(""), input.name) + if (newName eq null) input else input.copy(name = newName) } + val typeVisitor: TypeVisitor = + TypeVisitor.fields.modifyWith(renameField) |+| TypeVisitor.inputFields.modifyWith(renameInputField) + def transformStep[R](step: Step[R], field: Field): Step[R] = step match { - case ObjectStep(typeName, fields) => - ObjectStep( - typeName, - fieldName => fields(getFromMap2(transformMap, fieldName)(typeName, fieldName)) - ) - case _ => step + case step @ ObjectStep(typeName, fields) => + transformMap.getOrElse(typeName, null) match { + case null => step + case map => step.copy(fields = name => fields(map.getOrElse(name, name))) + } + case _ => step } } - /** - * A transformer that allows renaming arguments. - * @param f a partial function that takes a type name and a field name and returns another - * partial function from an argument name to a new name for that argument - */ - case class RenameArgument(private val f: (String, (String, (String, String)))*) extends Transformer[Any] { - private val visitorMap = toMap3(f) - private val transformMap = swapMap3(visitorMap) + object RenameArgument { + + /** + * A transformer that allows renaming arguments on fields + * + * {{{ + * RenameArgument( + * "TypeA" -> "fieldA" -> "foo" -> "bar", + * "TypeA" -> "fieldB" -> "baz" -> "qux", + * }}} + * + * @param f tuples in the format of `(TypeName -> fieldName -> oldArgumentName -> newArgumentName)` + */ + def apply(f: (((String, String), String), String)*): Transformer[Any] = + new RenameArgument(tuplesToMap3(f: _*)) + } + + final private class RenameArgument(visitorMap: Map[String, Map[String, Map[String, String]]]) + extends Transformer[Any] { + + private val transformMap: Map[String, Map[String, Map[String, String]]] = + swapMap3(visitorMap) val typeVisitor: TypeVisitor = TypeVisitor.fields.modifyWith((t, field) => @@ -120,97 +161,110 @@ object Transformer { ) def transformStep[R](step: Step[R], field: Field): Step[R] = step match { - case ObjectStep(typeName, fields) => - ObjectStep( - typeName, - fieldName => { - val step = fields(fieldName) - val rename = transformMap.get(typeName).flatMap(_.get(fieldName)) - if (rename.isEmpty) step - else - mapFunctionStep(step)(_.map { case (argName, input) => - rename.get.getOrElse(argName, argName) -> input - }) - } - ) - case _ => step + case step @ ObjectStep(typeName, fields) => + transformMap.getOrElse(typeName, null) match { + case null => step + case map0 => + step.copy(fields = + fieldName => + map0.getOrElse(fieldName, null) match { + case null => fields(fieldName) + case map1 => + mapFunctionStep(fields(fieldName))(_.map { case (argName, input) => + map1.getOrElse(argName, argName) -> input + }) + } + ) + } + case _ => step } } - /** - * A transformer that allows filtering fields. - * @param f a partial function that takes a type name and a field name and - * returns a boolean (true means the field should be kept) - */ - case class FilterField(private val f: (String, (String, Boolean))*) extends Transformer[Any] { - private val map = toMap2[Boolean](f) + object ExcludeField { + + /** + * A transformer that allows excluding fields from types. + * + * {{{ + * ExcludeField( + * "TypeA" -> "foo", + * "TypeB" -> "bar", + * ) + * }}} + * + * @param f tuples in the format of `(TypeName -> fieldToBeExcluded)` + */ + def apply(f: (String, String)*): Transformer[Any] = + new ExcludeField(f.groupMap(_._1)(_._2).transform((_, l) => l.toSet)) + } + + final private class ExcludeField(map: Map[String, Set[String]]) extends Transformer[Any] { + + private def shouldKeep(typeName: String, fieldName: String): Boolean = + !map.getOrElse(typeName, Set.empty).contains(fieldName) val typeVisitor: TypeVisitor = { - val _get = getFromMap2(map) _ - TypeVisitor.fields.filterWith((t, field) => _get(t.name.getOrElse(""), field.name)) |+| - TypeVisitor.inputFields.filterWith((t, field) => _get(t.name.getOrElse(""), field.name)) + TypeVisitor.fields.filterWith((t, field) => shouldKeep(t.name.getOrElse(""), field.name)) |+| + TypeVisitor.inputFields.filterWith((t, field) => shouldKeep(t.name.getOrElse(""), field.name)) } def transformStep[R](step: Step[R], field: Field): Step[R] = step match { - case ObjectStep(typeName, fields) => - ObjectStep( - typeName, - fieldName => if (getFromMap2(map, default = true)(typeName, fieldName)) fields(fieldName) else NullStep - ) - case _ => step + case step @ ObjectStep(typeName, fields) => + map.getOrElse(typeName, null) match { + case null => step + case exclude => step.copy(fields = name => if (!exclude(name)) fields(name) else NullStep) + } + case _ => step } } - /** - * A transformer that allows filtering types. - * - * @param f a partial function that takes a type name and an interface name and - * returns a boolean (true means the type should be kept) - */ - case class FilterInterface(private val f: (String, (String, Boolean))*) extends Transformer[Any] { - private val map = toMap2(f) - - val typeVisitor: TypeVisitor = - TypeVisitor.modify(t => - t.copy(interfaces = - () => - t.interfaces() - .map( - _.filter(interface => getFromMap2(map)(t.name.getOrElse(""), interface.name.getOrElse(""))) - ) - ) + object ExcludeArgument { + + /** + * A transformer that allows excluding arguments from fields + * + * {{{ + * ExcludeArgument( + * "TypeA" -> "fieldA" -> "arg", + * "TypeA" -> "fieldB" -> "arg2", + * ) + * }}} + * + * @param f tuples in the format of `(TypeName -> fieldName -> argumentToBeExcluded)` + */ + def apply(f: ((String, String), String)*): Transformer[Any] = + new ExcludeArgument( + f + .groupMap(_._1._1)(v => v._1._2 -> v._2) + .transform((_, v) => v.groupMap(_._1)(_._2).transform((_, v) => v.toSet)) ) - - def transformStep[R](step: Step[R], field: Field): Step[R] = step } + final private class ExcludeArgument(map: Map[String, Map[String, Set[String]]]) extends Transformer[Any] { - /** - * A transformer that allows filtering arguments. - * @param f a partial function that takes a type name, a field name and an argument name and - * returns a boolean (true means the argument should be kept) - */ - case class FilterArgument(private val f: (String, (String, (String, Boolean)))*) extends Transformer[Any] { - private val map = toMap3(f) + private def shouldKeep(typeName: String, fieldName: String, argName: String): Boolean = + !getFromMap2(map, Set.empty[String])(typeName, fieldName).contains(argName) val typeVisitor: TypeVisitor = TypeVisitor.fields.modifyWith((t, field) => field.copy(args = field .args(_) - .filter(arg => getFromMap3(map)(t.name.getOrElse(""), field.name, arg.name)) + .filter(arg => shouldKeep(t.name.getOrElse(""), field.name, arg.name)) ) ) def transformStep[R](step: Step[R], field: Field): Step[R] = step match { - case ObjectStep(typeName, fields) => - ObjectStep( - typeName, - fieldName => - mapFunctionStep(fields(fieldName))(_.filter { case (argName, _) => - getFromMap3(map)(typeName, fieldName, argName) + case step @ ObjectStep(typeName, fields) => + map.getOrElse(typeName, null) match { + case null => step + case map1 => + step.copy(fields = fieldName => { + val s = map1.getOrElse(fieldName, null) + if (s eq null) fields(fieldName) + else mapFunctionStep(fields(fieldName))(_.filterNot { case (argName, _) => s.contains(argName) }) }) - ) - case _ => step + } + case _ => step } } @@ -225,40 +279,21 @@ object Transformer { case other => other } - private def toMap2[V](t: Seq[(String, (String, V))]): Map[String, Map[String, V]] = - t.groupMap(_._1)(_._2).transform { case (_, l) => l.toMap } + private def tuplesToMap2(f: ((String, String), String)*): Map[String, Map[String, String]] = + f.groupMap(_._1._1)(v => v._1._2 -> v._2).transform((_, l) => l.toMap) - private def toMap3[V]( - t: Seq[(String, (String, (String, V)))] - ): Map[String, Map[String, Map[String, V]]] = - t.groupMap(_._1)(_._2).transform { case (_, l) => l.groupMap(_._1)(_._2).transform { case (_, l) => l.toMap } } + private def tuplesToMap3(f: (((String, String), String), String)*): Map[String, Map[String, Map[String, String]]] = + f.groupMap(_._1._1._1)(v => v._1._1._2 -> v._1._2 -> v._2).transform((_, l) => tuplesToMap2(l: _*)) private def swapMap2[V](m: Map[String, Map[String, V]]): Map[String, Map[V, String]] = - m.transform { case (_, m) => m.map(_.swap) } + m.transform((_, m) => m.map(_.swap)) private def swapMap3[V](m: Map[String, Map[String, Map[String, V]]]): Map[String, Map[String, Map[V, String]]] = - m.transform { case (_, m) => m.transform { case (_, m) => m.map(_.swap) } } + m.transform((_, m) => swapMap2(m)) - private def getFromMap2( - m: Map[String, Map[String, String]], - default: => String - )(k1: String, k2: String): String = + private def getFromMap2[V]( + m: Map[String, Map[String, V]], + default: => V + )(k1: String, k2: String): V = m.get(k1).flatMap(_.get(k2)).getOrElse(default) - - // Overloading to avoid boxing of Boolean - private def getFromMap2( - m: Map[String, Map[String, Boolean]], - default: Boolean = true - )(k1: String, k2: String): Boolean = { - val res = m.get(k1).flatMap(_.get(k2)) - if (res.isEmpty) default else res.get - } - - private def getFromMap3( - m: Map[String, Map[String, Map[String, Boolean]]], - default: Boolean = true - )(k1: String, k2: String, k3: String): Boolean = { - val res = m.get(k1).flatMap(_.get(k2)).flatMap(_.get(k3)) - if (res.isEmpty) default else res.get - } } diff --git a/core/src/test/scala/caliban/transformers/TransformerSpec.scala b/core/src/test/scala/caliban/transformers/TransformerSpec.scala index 07683cee0..3903c2c93 100644 --- a/core/src/test/scala/caliban/transformers/TransformerSpec.scala +++ b/core/src/test/scala/caliban/transformers/TransformerSpec.scala @@ -37,7 +37,7 @@ object TransformerSpec extends ZIOSpecDefault { ) }, test("rename field") { - val transformed: GraphQL[Any] = api.transform(Transformer.RenameField("InnerObject" -> ("b" -> "c"))) + val transformed: GraphQL[Any] = api.transform(Transformer.RenameField("InnerObject" -> "b" -> "c")) val rendered = transformed.render for { interpreter <- transformed.interpreter @@ -60,7 +60,7 @@ object TransformerSpec extends ZIOSpecDefault { }, test("rename argument") { val transformed: GraphQL[Any] = api.transform(Transformer.RenameArgument { - "InnerObject" -> ("b" -> ("arg" -> "arg2")) + "InnerObject" -> "b" -> "arg" -> "arg2" }) val rendered = transformed.render for { @@ -86,7 +86,7 @@ object TransformerSpec extends ZIOSpecDefault { case class Query(a: String, b: Int) val api: GraphQL[Any] = graphQL(RootResolver(Query("a", 2))) - val transformed: GraphQL[Any] = api.transform(Transformer.FilterField("Query" -> ("b" -> false))) + val transformed: GraphQL[Any] = api.transform(Transformer.ExcludeField("Query" -> "b")) val rendered = transformed.render for { interpreter <- transformed.interpreter @@ -108,7 +108,7 @@ object TransformerSpec extends ZIOSpecDefault { case class Query(a: Args => String) val api: GraphQL[Any] = graphQL(RootResolver(Query(_.arg.getOrElse("missing")))) - val transformed: GraphQL[Any] = api.transform(Transformer.FilterArgument("Query" -> ("a" -> ("arg" -> false)))) + val transformed: GraphQL[Any] = api.transform(Transformer.ExcludeArgument("Query" -> "a" -> "arg")) val rendered = transformed.render for { interpreter <- transformed.interpreter @@ -128,7 +128,7 @@ object TransformerSpec extends ZIOSpecDefault { test("combine transformers") { val transformed: GraphQL[Any] = api .transform(Transformer.RenameType("InnerObject" -> "Renamed")) - .transform(Transformer.RenameField("Renamed" -> ("b" -> "c"))) + .transform(Transformer.RenameField("Renamed" -> "b" -> "c")) val rendered = transformed.render for { interpreter <- transformed.interpreter From 7cdf857d5b6fd8692e3e0d40beb697a6838920d5 Mon Sep 17 00:00:00 2001 From: Kyri Petrou Date: Wed, 8 May 2024 15:47:29 +1000 Subject: [PATCH 4/6] Make Scala 2.12 happy --- core/src/main/scala/caliban/execution/Executor.scala | 5 +---- core/src/main/scala/caliban/transformers/Transformer.scala | 1 + 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/caliban/execution/Executor.scala b/core/src/main/scala/caliban/execution/Executor.scala index 02219dd30..5f05a7846 100644 --- a/core/src/main/scala/caliban/execution/Executor.scala +++ b/core/src/main/scala/caliban/execution/Executor.scala @@ -132,8 +132,6 @@ object Executor { wrapPureValues: Boolean )(implicit trace: Trace) { - private val isEmptyTransformer = transformer.isEmpty - def reduceStep( step: Step[R], currentField: Field, @@ -264,8 +262,7 @@ object Executor { try step(input) catch { case NonFatal(e) => Step.fail(e) } - val step0 = if (isEmptyTransformer) step else transformer.transformStep(step, currentField) - step0 match { + transformer.transformStep(step, currentField) match { case s: PureStep => s case QueryStep(inner) => reduceQuery(inner) case ObjectStep(objectName, fields) => reduceObjectStep(objectName, fields) diff --git a/core/src/main/scala/caliban/transformers/Transformer.scala b/core/src/main/scala/caliban/transformers/Transformer.scala index 7a5b6bb67..5a956fb4a 100644 --- a/core/src/main/scala/caliban/transformers/Transformer.scala +++ b/core/src/main/scala/caliban/transformers/Transformer.scala @@ -5,6 +5,7 @@ import caliban.execution.Field import caliban.introspection.adt._ import caliban.schema.Step import caliban.schema.Step.{ FunctionStep, MetadataFunctionStep, NullStep, ObjectStep } +import scala.collection.compat._ /** * A transformer is able to modify a type, modifying its schema and the way it is resolved. From 6f036126f285aa3ab1fb835da0226f8a6922f957 Mon Sep 17 00:00:00 2001 From: Kyri Petrou Date: Wed, 8 May 2024 19:18:01 +1000 Subject: [PATCH 5/6] Cleanups --- .../scala/caliban/execution/Executor.scala | 18 +- .../caliban/transformers/Transformer.scala | 184 +++++++++--------- 2 files changed, 98 insertions(+), 104 deletions(-) diff --git a/core/src/main/scala/caliban/execution/Executor.scala b/core/src/main/scala/caliban/execution/Executor.scala index 5f05a7846..feb003b5a 100644 --- a/core/src/main/scala/caliban/execution/Executor.scala +++ b/core/src/main/scala/caliban/execution/Executor.scala @@ -262,14 +262,16 @@ object Executor { try step(input) catch { case NonFatal(e) => Step.fail(e) } - transformer.transformStep(step, currentField) match { - case s: PureStep => s - case QueryStep(inner) => reduceQuery(inner) - case ObjectStep(objectName, fields) => reduceObjectStep(objectName, fields) - case FunctionStep(step) => reduceStep(wrapFn(step, arguments), currentField, Map.empty, path) - case MetadataFunctionStep(step) => reduceStep(wrapFn(step, currentField), currentField, arguments, path) - case ListStep(steps) => reduceListStep(steps) - case StreamStep(stream) => reduceStream(stream) + step match { + case s: PureStep => s + case s: QueryStep[R] => reduceQuery(s.query) + case s: ObjectStep[R] => + val obj = transformer.transformStep(s, currentField) + reduceObjectStep(obj.name, obj.fields) + case s: FunctionStep[R] => reduceStep(wrapFn(s.step, arguments), currentField, Map.empty, path) + case s: MetadataFunctionStep[R] => reduceStep(wrapFn(s.step, currentField), currentField, arguments, path) + case s: ListStep[R] => reduceListStep(s.steps) + case s: StreamStep[R] => reduceStream(s.inner) } } diff --git a/core/src/main/scala/caliban/transformers/Transformer.scala b/core/src/main/scala/caliban/transformers/Transformer.scala index 5a956fb4a..31cc44f06 100644 --- a/core/src/main/scala/caliban/transformers/Transformer.scala +++ b/core/src/main/scala/caliban/transformers/Transformer.scala @@ -13,16 +13,14 @@ import scala.collection.compat._ abstract class Transformer[-R] { self => val typeVisitor: TypeVisitor - def transformStep[R1 <: R](step: Step[R1], field: Field): Step[R1] + def transformStep[R1 <: R](step: ObjectStep[R1], field: Field): ObjectStep[R1] - def |+|[R0 <: R](that: Transformer[R0]): Transformer[R0] = new Transformer[R0] { - val typeVisitor: TypeVisitor = self.typeVisitor |+| that.typeVisitor - - def transformStep[R1 <: R0](step: Step[R1], field: Field): Step[R1] = - that.transformStep(self.transformStep(step, field), field) - } - - final def isEmpty: Boolean = self eq Transformer.Empty + def |+|[R0 <: R](that: Transformer[R0]): Transformer[R0] = + (self, that) match { + case (l, Transformer.Empty) => l + case (Transformer.Empty, r) => r + case _ => new Transformer.Combined[R0](self, that) + } } object Transformer { @@ -35,8 +33,7 @@ object Transformer { private case object Empty extends Transformer[Any] { val typeVisitor: TypeVisitor = TypeVisitor.empty - def transformStep[R1](step: Step[R1], field: Field): Step[R1] = step - override def |+|[R0](that: Transformer[R0]): Transformer[R0] = that + def transformStep[R1](step: ObjectStep[R1], field: Field): ObjectStep[R1] = step } object RenameType { @@ -57,23 +54,22 @@ object Transformer { final private class RenameType(map: Map[String, String]) extends Transformer[Any] { - private def renameType(t: __Type) = - t.name.flatMap(map.get).fold(t)(newName => t.copy(name = Some(newName))) - - private def renameEnum(t: __EnumValue) = - map.get(t.name).fold(t)(newName => t.copy(name = newName)) - - val typeVisitor: TypeVisitor = - TypeVisitor.modify(renameType(_)) |+| TypeVisitor.enumValues.modify(renameEnum) - - def transformStep[R](step: Step[R], field: Field): Step[R] = step match { - case step @ ObjectStep(typeName, _) => - map.getOrElse(typeName, null) match { - case null => step - case newName => step.copy(name = newName) - } - case _ => step + val typeVisitor: TypeVisitor = { + val renameType = { (t: __Type) => + t.name.flatMap(map.get).fold(t)(newName => t.copy(name = Some(newName))) + } + val renameEnum = { (t: __EnumValue) => + map.get(t.name).fold(t)(newName => t.copy(name = newName)) + } + + TypeVisitor.modify(renameType) |+| TypeVisitor.enumValues.modify(renameEnum) } + + def transformStep[R](step: ObjectStep[R], field: Field): ObjectStep[R] = + map.getOrElse(step.name, null) match { + case null => step + case newName => step.copy(name = newName) + } } object RenameField { @@ -98,27 +94,26 @@ object Transformer { final private class RenameField(visitorMap: Map[String, Map[String, String]]) extends Transformer[Any] { private val transformMap = swapMap2(visitorMap) - private def renameField(t: __Type, field: __Field) = { - val newName = getFromMap2(visitorMap, null)(t.name.getOrElse(""), field.name) - if (newName eq null) field else field.copy(name = newName) - } + val typeVisitor: TypeVisitor = { + def getName(t: __Type, name: String) = getFromMap2(visitorMap, null)(t.name.getOrElse(""), name) - private def renameInputField(t: __Type, input: __InputValue) = { - val newName = getFromMap2(visitorMap, null)(t.name.getOrElse(""), input.name) - if (newName eq null) input else input.copy(name = newName) - } + val renameField = { (t: __Type, field: __Field) => + val newName = getName(t, field.name) + if (newName eq null) field else field.copy(name = newName) + } - val typeVisitor: TypeVisitor = + val renameInputField = { (t: __Type, input: __InputValue) => + val newName = getName(t, input.name) + if (newName eq null) input else input.copy(name = newName) + } TypeVisitor.fields.modifyWith(renameField) |+| TypeVisitor.inputFields.modifyWith(renameInputField) - - def transformStep[R](step: Step[R], field: Field): Step[R] = step match { - case step @ ObjectStep(typeName, fields) => - transformMap.getOrElse(typeName, null) match { - case null => step - case map => step.copy(fields = name => fields(map.getOrElse(name, name))) - } - case _ => step } + + def transformStep[R](step: ObjectStep[R], field: Field): ObjectStep[R] = + transformMap.getOrElse(step.name, null) match { + case null => step + case map => step.copy(fields = name => step.fields(map.getOrElse(name, name))) + } } object RenameArgument { @@ -130,6 +125,7 @@ object Transformer { * RenameArgument( * "TypeA" -> "fieldA" -> "foo" -> "bar", * "TypeA" -> "fieldB" -> "baz" -> "qux", + * ) * }}} * * @param f tuples in the format of `(TypeName -> fieldName -> oldArgumentName -> newArgumentName)` @@ -141,44 +137,35 @@ object Transformer { final private class RenameArgument(visitorMap: Map[String, Map[String, Map[String, String]]]) extends Transformer[Any] { - private val transformMap: Map[String, Map[String, Map[String, String]]] = - swapMap3(visitorMap) + private val transformMap: Map[String, Map[String, Map[String, String]]] = swapMap3(visitorMap) val typeVisitor: TypeVisitor = TypeVisitor.fields.modifyWith((t, field) => visitorMap.get(t.name.getOrElse("")).flatMap(_.get(field.name)) match { - case Some(rename) => - field.copy(args = - field - .args(_) - .map(arg => - rename - .get(arg.name) - .fold(arg)(newName => arg.copy(name = newName)) - ) - ) - case None => field + case Some(renames) => + field.copy(args = field.args(_).map { arg => + renames.get(arg.name).fold(arg)(newName => arg.copy(name = newName)) + }) + case None => field } ) - def transformStep[R](step: Step[R], field: Field): Step[R] = step match { - case step @ ObjectStep(typeName, fields) => - transformMap.getOrElse(typeName, null) match { - case null => step - case map0 => - step.copy(fields = - fieldName => - map0.getOrElse(fieldName, null) match { - case null => fields(fieldName) - case map1 => - mapFunctionStep(fields(fieldName))(_.map { case (argName, input) => - map1.getOrElse(argName, argName) -> input - }) - } - ) - } - case _ => step - } + def transformStep[R](step: ObjectStep[R], field: Field): ObjectStep[R] = + transformMap.getOrElse(step.name, null) match { + case null => step + case map0 => + val fields = step.fields + step.copy(fields = + fieldName => + map0.getOrElse(fieldName, null) match { + case null => fields(fieldName) + case map1 => + mapFunctionStep(fields(fieldName))(_.map { case (argName, input) => + map1.getOrElse(argName, argName) -> input + }) + } + ) + } } object ExcludeField { @@ -209,14 +196,11 @@ object Transformer { TypeVisitor.inputFields.filterWith((t, field) => shouldKeep(t.name.getOrElse(""), field.name)) } - def transformStep[R](step: Step[R], field: Field): Step[R] = step match { - case step @ ObjectStep(typeName, fields) => - map.getOrElse(typeName, null) match { - case null => step - case exclude => step.copy(fields = name => if (!exclude(name)) fields(name) else NullStep) - } - case _ => step - } + def transformStep[R](step: ObjectStep[R], field: Field): ObjectStep[R] = + map.getOrElse(step.name, null) match { + case null => step + case excl => step.copy(fields = name => if (!excl(name)) step.fields(name) else NullStep) + } } object ExcludeArgument { @@ -240,6 +224,7 @@ object Transformer { .transform((_, v) => v.groupMap(_._1)(_._2).transform((_, v) => v.toSet)) ) } + final private class ExcludeArgument(map: Map[String, Map[String, Set[String]]]) extends Transformer[Any] { private def shouldKeep(typeName: String, fieldName: String, argName: String): Boolean = @@ -254,19 +239,26 @@ object Transformer { ) ) - def transformStep[R](step: Step[R], field: Field): Step[R] = step match { - case step @ ObjectStep(typeName, fields) => - map.getOrElse(typeName, null) match { - case null => step - case map1 => - step.copy(fields = fieldName => { - val s = map1.getOrElse(fieldName, null) - if (s eq null) fields(fieldName) - else mapFunctionStep(fields(fieldName))(_.filterNot { case (argName, _) => s.contains(argName) }) - }) - } - case _ => step - } + def transformStep[R](step: ObjectStep[R], field: Field): ObjectStep[R] = + map.getOrElse(step.name, null) match { + case null => step + case inner => + val fields = step.fields + step.copy(fields = + fieldName => + inner.getOrElse(fieldName, null) match { + case null => fields(fieldName) + case excl => mapFunctionStep(fields(fieldName))(_.filterNot { case (argName, _) => excl(argName) }) + } + ) + } + } + + final private class Combined[-R](left: Transformer[R], right: Transformer[R]) extends Transformer[R] { + val typeVisitor: TypeVisitor = left.typeVisitor |+| right.typeVisitor + + def transformStep[R1 <: R](step: ObjectStep[R1], field: Field): ObjectStep[R1] = + right.transformStep(left.transformStep(step, field), field) } private def mapFunctionStep[R](step: Step[R])(f: Map[String, InputValue] => Map[String, InputValue]): Step[R] = From 1f55ac16abd65ff04ce4490d33ef4313c09654b2 Mon Sep 17 00:00:00 2001 From: Kyri Petrou Date: Thu, 9 May 2024 09:21:20 +1000 Subject: [PATCH 6/6] Pre-check whether to apply transformers in Transformer.Combined --- .../scala/caliban/execution/Executor.scala | 4 +- .../caliban/transformers/Transformer.scala | 68 ++++++++++++++----- 2 files changed, 52 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/caliban/execution/Executor.scala b/core/src/main/scala/caliban/execution/Executor.scala index feb003b5a..1bb67cd8a 100644 --- a/core/src/main/scala/caliban/execution/Executor.scala +++ b/core/src/main/scala/caliban/execution/Executor.scala @@ -265,9 +265,7 @@ object Executor { step match { case s: PureStep => s case s: QueryStep[R] => reduceQuery(s.query) - case s: ObjectStep[R] => - val obj = transformer.transformStep(s, currentField) - reduceObjectStep(obj.name, obj.fields) + case s: ObjectStep[R] => val t = transformer(s, currentField); reduceObjectStep(t.name, t.fields) case s: FunctionStep[R] => reduceStep(wrapFn(s.step, arguments), currentField, Map.empty, path) case s: MetadataFunctionStep[R] => reduceStep(wrapFn(s.step, currentField), currentField, arguments, path) case s: ListStep[R] => reduceListStep(s.steps) diff --git a/core/src/main/scala/caliban/transformers/Transformer.scala b/core/src/main/scala/caliban/transformers/Transformer.scala index 31cc44f06..a169e7b85 100644 --- a/core/src/main/scala/caliban/transformers/Transformer.scala +++ b/core/src/main/scala/caliban/transformers/Transformer.scala @@ -5,7 +5,9 @@ import caliban.execution.Field import caliban.introspection.adt._ import caliban.schema.Step import caliban.schema.Step.{ FunctionStep, MetadataFunctionStep, NullStep, ObjectStep } + import scala.collection.compat._ +import scala.collection.mutable /** * A transformer is able to modify a type, modifying its schema and the way it is resolved. @@ -13,7 +15,16 @@ import scala.collection.compat._ abstract class Transformer[-R] { self => val typeVisitor: TypeVisitor - def transformStep[R1 <: R](step: ObjectStep[R1], field: Field): ObjectStep[R1] + /** + * Set of type names that this transformer applies to. + * Needed for applying optimizations when combining transformers. + */ + protected val typeNames: collection.Set[String] + + protected def transformStep[R1 <: R](step: ObjectStep[R1], field: Field): ObjectStep[R1] + + def apply[R1 <: R](step: ObjectStep[R1], field: Field): ObjectStep[R1] = + transformStep(step, field) def |+|[R0 <: R](that: Transformer[R0]): Transformer[R0] = (self, that) match { @@ -33,7 +44,9 @@ object Transformer { private case object Empty extends Transformer[Any] { val typeVisitor: TypeVisitor = TypeVisitor.empty - def transformStep[R1](step: ObjectStep[R1], field: Field): ObjectStep[R1] = step + protected val typeNames: Set[String] = Set.empty + + protected def transformStep[R1](step: ObjectStep[R1], field: Field): ObjectStep[R1] = step } object RenameType { @@ -49,7 +62,7 @@ object Transformer { * @param f tuples in the format of `(OldName -> NewName)` */ def apply(f: (String, String)*): Transformer[Any] = - new RenameType(f.toMap) + if (f.isEmpty) Empty else new RenameType(f.toMap) } final private class RenameType(map: Map[String, String]) extends Transformer[Any] { @@ -65,7 +78,9 @@ object Transformer { TypeVisitor.modify(renameType) |+| TypeVisitor.enumValues.modify(renameEnum) } - def transformStep[R](step: ObjectStep[R], field: Field): ObjectStep[R] = + protected val typeNames: Set[String] = map.keySet + + protected def transformStep[R](step: ObjectStep[R], field: Field): ObjectStep[R] = map.getOrElse(step.name, null) match { case null => step case newName => step.copy(name = newName) @@ -88,7 +103,7 @@ object Transformer { */ def apply(f: ((String, String), String)*): Transformer[Any] = - new RenameField(tuplesToMap2(f: _*)) + if (f.isEmpty) Empty else new RenameField(tuplesToMap2(f: _*)) } final private class RenameField(visitorMap: Map[String, Map[String, String]]) extends Transformer[Any] { @@ -109,7 +124,9 @@ object Transformer { TypeVisitor.fields.modifyWith(renameField) |+| TypeVisitor.inputFields.modifyWith(renameInputField) } - def transformStep[R](step: ObjectStep[R], field: Field): ObjectStep[R] = + protected val typeNames: Set[String] = transformMap.keySet + + protected def transformStep[R](step: ObjectStep[R], field: Field): ObjectStep[R] = transformMap.getOrElse(step.name, null) match { case null => step case map => step.copy(fields = name => step.fields(map.getOrElse(name, name))) @@ -131,7 +148,7 @@ object Transformer { * @param f tuples in the format of `(TypeName -> fieldName -> oldArgumentName -> newArgumentName)` */ def apply(f: (((String, String), String), String)*): Transformer[Any] = - new RenameArgument(tuplesToMap3(f: _*)) + if (f.isEmpty) Empty else new RenameArgument(tuplesToMap3(f: _*)) } final private class RenameArgument(visitorMap: Map[String, Map[String, Map[String, String]]]) @@ -150,7 +167,9 @@ object Transformer { } ) - def transformStep[R](step: ObjectStep[R], field: Field): ObjectStep[R] = + protected val typeNames: Set[String] = transformMap.keySet + + protected def transformStep[R](step: ObjectStep[R], field: Field): ObjectStep[R] = transformMap.getOrElse(step.name, null) match { case null => step case map0 => @@ -183,7 +202,7 @@ object Transformer { * @param f tuples in the format of `(TypeName -> fieldToBeExcluded)` */ def apply(f: (String, String)*): Transformer[Any] = - new ExcludeField(f.groupMap(_._1)(_._2).transform((_, l) => l.toSet)) + if (f.isEmpty) Empty else new ExcludeField(f.groupMap(_._1)(_._2).transform((_, l) => l.toSet)) } final private class ExcludeField(map: Map[String, Set[String]]) extends Transformer[Any] { @@ -196,7 +215,9 @@ object Transformer { TypeVisitor.inputFields.filterWith((t, field) => shouldKeep(t.name.getOrElse(""), field.name)) } - def transformStep[R](step: ObjectStep[R], field: Field): ObjectStep[R] = + protected val typeNames: Set[String] = map.keySet + + protected def transformStep[R](step: ObjectStep[R], field: Field): ObjectStep[R] = map.getOrElse(step.name, null) match { case null => step case excl => step.copy(fields = name => if (!excl(name)) step.fields(name) else NullStep) @@ -218,11 +239,13 @@ object Transformer { * @param f tuples in the format of `(TypeName -> fieldName -> argumentToBeExcluded)` */ def apply(f: ((String, String), String)*): Transformer[Any] = - new ExcludeArgument( - f - .groupMap(_._1._1)(v => v._1._2 -> v._2) - .transform((_, v) => v.groupMap(_._1)(_._2).transform((_, v) => v.toSet)) - ) + if (f.isEmpty) Empty + else + new ExcludeArgument( + f + .groupMap(_._1._1)(v => v._1._2 -> v._2) + .transform((_, v) => v.groupMap(_._1)(_._2).transform((_, v) => v.toSet)) + ) } final private class ExcludeArgument(map: Map[String, Map[String, Set[String]]]) extends Transformer[Any] { @@ -239,7 +262,9 @@ object Transformer { ) ) - def transformStep[R](step: ObjectStep[R], field: Field): ObjectStep[R] = + protected val typeNames: Set[String] = map.keySet + + protected def transformStep[R](step: ObjectStep[R], field: Field): ObjectStep[R] = map.getOrElse(step.name, null) match { case null => step case inner => @@ -257,8 +282,17 @@ object Transformer { final private class Combined[-R](left: Transformer[R], right: Transformer[R]) extends Transformer[R] { val typeVisitor: TypeVisitor = left.typeVisitor |+| right.typeVisitor - def transformStep[R1 <: R](step: ObjectStep[R1], field: Field): ObjectStep[R1] = + protected val typeNames: mutable.HashSet[String] = { + val set = mutable.HashSet.from(left.typeNames) + set ++= right.typeNames + set + } + + protected def transformStep[R1 <: R](step: ObjectStep[R1], field: Field): ObjectStep[R1] = right.transformStep(left.transformStep(step, field), field) + + override def apply[R1 <: R](step: ObjectStep[R1], field: Field): ObjectStep[R1] = + if (typeNames(step.name)) transformStep(step, field) else step } private def mapFunctionStep[R](step: Step[R])(f: Map[String, InputValue] => Map[String, InputValue]): Step[R] =