From a6393e0736055070a70bd18dfda49dd645c7cca9 Mon Sep 17 00:00:00 2001 From: Tobias Kammerer Date: Sat, 4 Jul 2020 23:00:07 +0200 Subject: [PATCH] Fix for #409 --- .../graphql/kickstart/tools/SchemaParser.kt | 35 ++++++++++----- .../kickstart/tools/SchemaParserSpec.groovy | 44 +++++++++++++++++++ 2 files changed, 67 insertions(+), 12 deletions(-) diff --git a/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt b/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt index d0cd1298..b7405b1e 100644 --- a/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt +++ b/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt @@ -81,7 +81,7 @@ class SchemaParser internal constructor( val inputObjects: MutableList = mutableListOf() inputObjectDefinitions.forEach { if (inputObjects.none { io -> io.name == it.name }) { - inputObjects.add(createInputObject(it, inputObjects)) + inputObjects.add(createInputObject(it, inputObjects, mutableSetOf())) } } val interfaces = interfaceDefinitions.map { createInterfaceObject(it, inputObjects) } @@ -173,7 +173,8 @@ class SchemaParser internal constructor( return output.toTypedArray() } - private fun createInputObject(definition: InputObjectTypeDefinition, inputObjects: List): GraphQLInputObjectType { + private fun createInputObject(definition: InputObjectTypeDefinition, inputObjects: List, + referencingInputObjects: MutableSet): GraphQLInputObjectType { val extensionDefinitions = inputExtensionDefinitions.filter { it.name == definition.name } val builder = GraphQLInputObjectType.newInputObject() @@ -184,6 +185,8 @@ class SchemaParser internal constructor( builder.withDirectives(*buildDirectives(definition.directives, setOf(), Introspection.DirectiveLocation.INPUT_OBJECT)) + referencingInputObjects.add(definition.name) + (extensionDefinitions + definition).forEach { it.inputValueDefinitions.forEach { inputDefinition -> val fieldBuilder = GraphQLInputObjectField.newInputObjectField() @@ -191,7 +194,7 @@ class SchemaParser internal constructor( .definition(inputDefinition) .description(if (inputDefinition.description != null) inputDefinition.description.content else getDocumentation(inputDefinition)) .defaultValue(buildDefaultValue(inputDefinition.defaultValue)) - .type(determineInputType(inputDefinition.type, inputObjects)) + .type(determineInputType(inputDefinition.type, inputObjects, referencingInputObjects)) .withDirectives(*buildDirectives(inputDefinition.directives, setOf(), Introspection.DirectiveLocation.INPUT_FIELD_DEFINITION)) builder.field(fieldBuilder.build()) } @@ -297,7 +300,7 @@ class SchemaParser internal constructor( .definition(argumentDefinition) .description(if (argumentDefinition.description != null) argumentDefinition.description.content else getDocumentation(argumentDefinition)) .defaultValue(buildDefaultValue(argumentDefinition.defaultValue)) - .type(determineInputType(argumentDefinition.type, inputObjects)) + .type(determineInputType(argumentDefinition.type, inputObjects, setOf())) .withDirectives(*buildDirectives(argumentDefinition.directives, setOf(), Introspection.DirectiveLocation.ARGUMENT_DEFINITION)) field.argument(argumentBuilder.build()) } @@ -328,7 +331,7 @@ class SchemaParser internal constructor( is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects)) is InputObjectTypeDefinition -> { log.info("Create input object") - createInputObject(typeDefinition, inputObjects) + createInputObject(typeDefinition, inputObjects, mutableSetOf()) } is TypeName -> { val scalarType = customScalars[typeDefinition.name] @@ -346,16 +349,19 @@ class SchemaParser internal constructor( else -> throw SchemaError("Unknown type: $typeDefinition") } - private fun determineInputType(typeDefinition: Type<*>, inputObjects: List) = - determineInputType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject, inputObjects) as GraphQLInputType + private fun determineInputType(typeDefinition: Type<*>, inputObjects: List, referencingInputObjects: Set) = + determineInputType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject, inputObjects, referencingInputObjects) as GraphQLInputType - private fun determineInputType(expectedType: KClass, typeDefinition: Type<*>, allowedTypeReferences: Set, inputObjects: List): GraphQLType = + private fun determineInputType(expectedType: KClass, + typeDefinition: Type<*>, allowedTypeReferences: Set, + inputObjects: List, + referencingInputObjects: Set): GraphQLType = when (typeDefinition) { is ListType -> GraphQLList(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects)) is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects)) is InputObjectTypeDefinition -> { log.info("Create input object") - createInputObject(typeDefinition, inputObjects) + createInputObject(typeDefinition, inputObjects, referencingInputObjects as MutableSet) } is TypeName -> { val scalarType = customScalars[typeDefinition.name] @@ -373,9 +379,14 @@ class SchemaParser internal constructor( } else { val filteredDefinitions = inputObjectDefinitions.filter { it.name == typeDefinition.name } if (filteredDefinitions.isNotEmpty()) { - val inputObject = createInputObject(filteredDefinitions[0], inputObjects) - (inputObjects as MutableList).add(inputObject) - inputObject + val referencingInputObject = referencingInputObjects.find { it == typeDefinition.name } + if (referencingInputObject != null) { + GraphQLTypeReference(referencingInputObject) + } else { + val inputObject = createInputObject(filteredDefinitions[0], inputObjects, referencingInputObjects as MutableSet) + (inputObjects as MutableList).add(inputObject) + inputObject + } } else { // todo: handle enum type GraphQLTypeReference(typeDefinition.name) diff --git a/src/test/groovy/graphql/kickstart/tools/SchemaParserSpec.groovy b/src/test/groovy/graphql/kickstart/tools/SchemaParserSpec.groovy index b1f31f47..85181928 100644 --- a/src/test/groovy/graphql/kickstart/tools/SchemaParserSpec.groovy +++ b/src/test/groovy/graphql/kickstart/tools/SchemaParserSpec.groovy @@ -368,6 +368,50 @@ class SchemaParserSpec extends Specification { noExceptionThrown() } + def "allow circular relations in input objects"() { + when: + SchemaParser.newParser().schemaString('''\ + input A { + id: ID! + b: B + } + input B { + id: ID! + a: A + } + input C { + id: ID! + c: C + } + type Query {} + type Mutation { + test(input: A!): Boolean + testC(input: C!): Boolean + } + '''.stripIndent()) + .resolvers(new GraphQLMutationResolver() { + static class A { + String id; + B b; + } + static class B { + String id; + A a; + } + static class C { + String id; + C c; + } + boolean test(A a) { return true } + boolean testC(C c) { return true } + }, new GraphQLQueryResolver() {}) + .build() + .makeExecutableSchema() + + then: + noExceptionThrown() + } + enum EnumType { TEST }