diff --git a/core/src/main/scala/caliban/introspection/adt/__Field.scala b/core/src/main/scala/caliban/introspection/adt/__Field.scala index f9e0d8604..f4c597ded 100644 --- a/core/src/main/scala/caliban/introspection/adt/__Field.scala +++ b/core/src/main/scala/caliban/introspection/adt/__Field.scala @@ -35,4 +35,7 @@ case class __Field( args(__DeprecatedArgs.include) private[caliban] lazy val _type: __Type = `type`() + + private[caliban] lazy val allArgNames: Set[String] = + allArgs.view.map(_.name).toSet } diff --git a/core/src/main/scala/caliban/transformers/Transformer.scala b/core/src/main/scala/caliban/transformers/Transformer.scala index d69d860b2..11a157f49 100644 --- a/core/src/main/scala/caliban/transformers/Transformer.scala +++ b/core/src/main/scala/caliban/transformers/Transformer.scala @@ -1,6 +1,5 @@ package caliban.transformers -import caliban.CalibanError.ValidationError import caliban.InputValue import caliban.execution.Field import caliban.introspection.adt._ @@ -278,6 +277,7 @@ object Transformer { * ) * }}} * + * @note the '''argument must be optional''', otherwise the filter will be silently ignored * @param f tuples in the format of `(TypeName -> fieldName -> argumentToBeExcluded)` */ def apply(f: ((String, String), String)*): Transformer[Any] = @@ -292,15 +292,15 @@ object Transformer { final private class ExcludeArgument(map: Map[String, Map[String, Set[String]]]) extends Transformer[Any] { - private def shouldKeep(typeName: String, fieldName: String, argName: String): Boolean = - !getFromMap2(map, Set.empty[String])(typeName, fieldName).contains(argName) + private def shouldExclude(typeName: String, fieldName: String, arg: __InputValue): Boolean = + arg._type.isNullable && getFromMap2(map, Set.empty[String])(typeName, fieldName).contains(arg.name) val typeVisitor: TypeVisitor = TypeVisitor.fields.modifyWith((t, field) => field.copy(args = field .args(_) - .filter(arg => shouldKeep(t.name.getOrElse(""), field.name, arg.name)) + .filterNot(arg => shouldExclude(t.name.getOrElse(""), field.name, arg)) ) ) @@ -313,9 +313,11 @@ object Transformer { val fields = step.fields step.copy(fields = fieldName => - inner.getOrElse(fieldName, null) match { - case null => fields(fieldName) - case excl => mapFunctionStep(fields(fieldName))(_.filterNot { case (argName, _) => excl(argName) }) + if (inner.contains(fieldName)) { + val args = field.fieldType.allFieldsMap(fieldName).allArgNames + mapFunctionStep(fields(fieldName))(_.filterNot { case (argName, _) => !args.contains(argName) }) + } else { + fields(fieldName) } ) } diff --git a/core/src/test/scala/caliban/transformers/TransformerSpec.scala b/core/src/test/scala/caliban/transformers/TransformerSpec.scala index 6be41795d..3dd2d64f2 100644 --- a/core/src/test/scala/caliban/transformers/TransformerSpec.scala +++ b/core/src/test/scala/caliban/transformers/TransformerSpec.scala @@ -139,28 +139,62 @@ object TransformerSpec extends ZIOSpecDefault { |}""".stripMargin ) }, - test("filter argument") { - case class Args(arg: Option[String]) - case class Query(a: Args => String) - val api: GraphQL[Any] = graphQL(RootResolver(Query(_.arg.getOrElse("missing")))) + suite("ExcludeArgument")( + test("filter nullable argument") { + case class Args(arg: Option[String]) + case class Query(a: Args => String) + val api: GraphQL[Any] = graphQL(RootResolver(Query(_.arg.getOrElse("missing")))) - val transformed: GraphQL[Any] = api.transform(Transformer.ExcludeArgument("Query" -> "a" -> "arg")) - val rendered = transformed.render - for { - interpreter <- transformed.interpreter - result <- interpreter.execute("""{ a }""").map(_.data.toString) - } yield assertTrue( - result == """{"a":"missing"}""", - rendered == - """schema { - | query: Query - |} - | - |type Query { - | a: String! - |}""".stripMargin - ) - }, + val transformed: GraphQL[Any] = api.transform(Transformer.ExcludeArgument("Query" -> "a" -> "arg")) + val rendered = transformed.render + for { + interpreter <- transformed.interpreter + result <- interpreter.execute("""{ a }""").map(_.data.toString) + } yield assertTrue( + result == """{"a":"missing"}""", + rendered == + """schema { + | query: Query + |} + | + |type Query { + | a: String! + |}""".stripMargin + ) + }, + test("cannot filter non-nullable arguments") { + case class Args(arg1: String, arg2: Option[String], arg3: Option[String]) + case class Query(a: Args => String) + val api: GraphQL[Any] = graphQL( + RootResolver( + Query(t => s"a1:${t.arg1} a2:${t.arg2.getOrElse("missing")} a3:${t.arg3.getOrElse("missing")}") + ) + ) + + val transformed: GraphQL[Any] = api.transform( + Transformer.ExcludeArgument( + "Query" -> "a" -> "arg1", + "Query" -> "a" -> "arg2" + ) + ) + val rendered = transformed.render + for { + _ <- Configurator.setSkipValidation(true) + interpreter <- transformed.interpreter + result <- interpreter.execute("""{ a(arg1:"foo", arg2:"bar", arg3:"baz") }""").map(_.data.toString) + } yield assertTrue( + result == """{"a":"a1:foo a2:missing a3:baz"}""", + rendered == + """schema { + | query: Query + |} + | + |type Query { + | a(arg1: String!, arg3: String): String! + |}""".stripMargin + ) + } + ), test("combine transformers") { val transformed: GraphQL[Any] = api .transform(Transformer.RenameType("InnerObject" -> "Renamed"))