diff --git a/src/main/kotlin/graphql/kickstart/tools/DictionaryTypeResolver.kt b/src/main/kotlin/graphql/kickstart/tools/DictionaryTypeResolver.kt index e5f0a901..d7c40b6e 100644 --- a/src/main/kotlin/graphql/kickstart/tools/DictionaryTypeResolver.kt +++ b/src/main/kotlin/graphql/kickstart/tools/DictionaryTypeResolver.kt @@ -13,8 +13,7 @@ import graphql.schema.TypeResolver * @author Andrew Potter */ internal abstract class DictionaryTypeResolver( - private val dictionary: BiMap>, - private val types: Map + private val dictionary: BiMap> ) : TypeResolver { private fun getTypeDefinition(clazz: Class): TypeDefinition<*>? { return dictionary[clazz] @@ -25,7 +24,7 @@ internal abstract class DictionaryTypeResolver( override fun getType(env: TypeResolutionEnvironment): GraphQLObjectType? { val clazz = env.getObject().javaClass val name = getTypeDefinition(clazz)?.name ?: clazz.simpleName - return types[name] ?: throw TypeResolverError(getError(name)) + return env.schema.getObjectType(name) ?: throw TypeResolverError(getError(name)) } abstract fun getError(name: String): String @@ -33,22 +32,18 @@ internal abstract class DictionaryTypeResolver( internal class InterfaceTypeResolver( dictionary: BiMap>, - private val thisInterface: GraphQLInterfaceType, - types: List + private val thisInterface: GraphQLInterfaceType ) : DictionaryTypeResolver( - dictionary, - types.filter { type -> type.interfaces.any { it.name == thisInterface.name } }.associateBy { it.name } + dictionary ) { override fun getError(name: String) = "Expected object type with name '$name' to implement interface '${thisInterface.name}', but it doesn't!" } internal class UnionTypeResolver( dictionary: BiMap>, - private val thisUnion: GraphQLUnionType, - types: List + private val thisUnion: GraphQLUnionType ) : DictionaryTypeResolver( - dictionary, - types.filter { type -> thisUnion.types.any { it.name == type.name } }.associateBy { it.name } + dictionary ) { override fun getError(name: String) = "Expected object type with name '$name' to exist for union '${thisUnion.name}', but it doesn't!" } diff --git a/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt b/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt index 673d7010..75adc3cb 100644 --- a/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt +++ b/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt @@ -85,8 +85,8 @@ class SchemaParser internal constructor( val enums = enumDefinitions.map { createEnumObject(it) } // Assign type resolver to interfaces now that we know all of the object types - interfaces.forEach { codeRegistryBuilder.typeResolver(it, InterfaceTypeResolver(dictionary.inverse(), it, objects)) } - unions.forEach { codeRegistryBuilder.typeResolver(it, UnionTypeResolver(dictionary.inverse(), it, objects)) } + interfaces.forEach { codeRegistryBuilder.typeResolver(it, InterfaceTypeResolver(dictionary.inverse(), it)) } + unions.forEach { codeRegistryBuilder.typeResolver(it, UnionTypeResolver(dictionary.inverse(), it)) } // Find query type and mutation/subscription type (if mutation/subscription type exists) val queryName = rootInfo.getQueryName() diff --git a/src/test/kotlin/graphql/kickstart/tools/EndToEndTest.kt b/src/test/kotlin/graphql/kickstart/tools/EndToEndTest.kt index d162c6b0..34cb2a51 100644 --- a/src/test/kotlin/graphql/kickstart/tools/EndToEndTest.kt +++ b/src/test/kotlin/graphql/kickstart/tools/EndToEndTest.kt @@ -1,9 +1,11 @@ package graphql.kickstart.tools +import com.fasterxml.jackson.module.kotlin.jacksonMapperBuilder import graphql.* import graphql.execution.AsyncExecutionStrategy -import graphql.schema.GraphQLEnumType -import graphql.schema.GraphQLSchema +import graphql.schema.* +import graphql.util.TraversalControl +import graphql.util.TraverserContext import org.junit.Test import org.reactivestreams.Publisher import org.reactivestreams.Subscriber @@ -670,4 +672,40 @@ class EndToEndTest { val exceptionWhileDataFetching = result.errors[0] as ExceptionWhileDataFetching assert(exceptionWhileDataFetching.exception is IllegalArgumentException) } + + class Transformer : GraphQLTypeVisitorStub() { + override fun visitGraphQLObjectType(node: GraphQLObjectType?, context: TraverserContext?): TraversalControl { + val newNode = node?.transform { builder -> builder.description(node.description + " [MODIFIED]") } + return changeNode(context, newNode) + } + } + + @Test + fun `transformed schema should execute query`() { + val transformedSchema = SchemaTransformer().transform(schema, Transformer()) + val transformedGql: GraphQL = GraphQL.newGraphQL(transformedSchema) + .queryExecutionStrategy(AsyncExecutionStrategy()) + .build() + + val data = assertNoGraphQlErrors(transformedGql) { + """ + { + otherUnionItems { + ... on Item { + itemId: id + } + ... on ThirdItem { + thirdItemId: id + } + } + } + """ + } + + assertEquals(data["otherUnionItems"], listOf( + mapOf("itemId" to 0), + mapOf("itemId" to 1), + mapOf("thirdItemId" to 100) + )) + } }