diff --git a/core/src/main/scala/caliban/schema/RootSchemaBuilder.scala b/core/src/main/scala/caliban/schema/RootSchemaBuilder.scala index 05c950878..eaedff91d 100644 --- a/core/src/main/scala/caliban/schema/RootSchemaBuilder.scala +++ b/core/src/main/scala/caliban/schema/RootSchemaBuilder.scala @@ -17,10 +17,11 @@ case class RootSchemaBuilder[-R]( ) def types: List[__Type] = { - val empty = additionalTypes - (query.map(_.opType).fold(empty)(collectTypes(_)) ++ - mutation.map(_.opType).fold(empty)(collectTypes(_)) ++ - subscription.map(_.opType).fold(empty)(collectTypes(_))) + val init = additionalTypes.foldLeft(List.empty[__Type]) { case (acc, t) => collectTypes(t, acc) } + (init ++ + query.map(_.opType).fold(List.empty[__Type])(collectTypes(_, init)) ++ + mutation.map(_.opType).fold(List.empty[__Type])(collectTypes(_, init)) ++ + subscription.map(_.opType).fold(List.empty[__Type])(collectTypes(_, init))) .groupBy(t => (t.name, t.kind, t.origin)) .flatMap(_._2.headOption) .toList diff --git a/core/src/test/scala/caliban/schema/SchemaSpec.scala b/core/src/test/scala/caliban/schema/SchemaSpec.scala index 22b3e8257..ef758a1f8 100644 --- a/core/src/test/scala/caliban/schema/SchemaSpec.scala +++ b/core/src/test/scala/caliban/schema/SchemaSpec.scala @@ -181,6 +181,40 @@ object SchemaSpec extends DefaultRunnableSpec { | a: String! |}""".stripMargin assertTrue(gql.render == expected) + }, + test("Pass interface to withAdditionalTypes") { + @GQLInterface + sealed trait Interface + + case class A(s: String) extends Interface + case class B(s: String) extends Interface + + case class Query(a: A, b: B) + + val interfaceType = Schema.gen[Any, Interface].toType_() + + val gql = graphQL(RootResolver(Query(A("a"), B("b")))).withAdditionalTypes(List(interfaceType)) + val expected = """schema { + | query: Query + |} + | + |interface Interface { + | s: String! + |} + | + |type A implements Interface { + | s: String! + |} + | + |type B implements Interface { + | s: String! + |} + | + |type Query { + | a: A! + | b: B! + |}""".stripMargin + assertTrue(gql.render == expected) } ) diff --git a/vuepress/docs/faq/README.md b/vuepress/docs/faq/README.md index 4080e609d..9427d9558 100644 --- a/vuepress/docs/faq/README.md +++ b/vuepress/docs/faq/README.md @@ -107,8 +107,8 @@ case class B(s: String) extends Interface case class Query(a: A, b: B) -val interfaceSchema = Schema.gen[Interface] +val interfaceType = Schema.gen[Interface].toType_() -val api = graphQL(RootResolver(Query(A("a"), B("b")))).withAdditionalTypes(List(interfaceSchema.toType_())) +val api = graphQL(RootResolver(Query(A("a"), B("b")))).withAdditionalTypes(List(interfaceType)) ```