Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
KammererTob committed Jul 4, 2020
1 parent 313afc5 commit a6393e0
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 12 deletions.
35 changes: 23 additions & 12 deletions src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class SchemaParser internal constructor(
val inputObjects: MutableList<GraphQLInputObjectType> = mutableListOf()
inputObjectDefinitions.forEach {
if (inputObjects.none { io -> io.name == it.name }) {
inputObjects.add(createInputObject(it, inputObjects))
inputObjects.add(createInputObject(it, inputObjects, mutableSetOf()))
}
}
val interfaces = interfaceDefinitions.map { createInterfaceObject(it, inputObjects) }
Expand Down Expand Up @@ -173,7 +173,8 @@ class SchemaParser internal constructor(
return output.toTypedArray()
}

private fun createInputObject(definition: InputObjectTypeDefinition, inputObjects: List<GraphQLInputObjectType>): GraphQLInputObjectType {
private fun createInputObject(definition: InputObjectTypeDefinition, inputObjects: List<GraphQLInputObjectType>,
referencingInputObjects: MutableSet<String>): GraphQLInputObjectType {
val extensionDefinitions = inputExtensionDefinitions.filter { it.name == definition.name }

val builder = GraphQLInputObjectType.newInputObject()
Expand All @@ -184,14 +185,16 @@ class SchemaParser internal constructor(

builder.withDirectives(*buildDirectives(definition.directives, setOf(), Introspection.DirectiveLocation.INPUT_OBJECT))

referencingInputObjects.add(definition.name)

(extensionDefinitions + definition).forEach {
it.inputValueDefinitions.forEach { inputDefinition ->
val fieldBuilder = GraphQLInputObjectField.newInputObjectField()
.name(inputDefinition.name)
.definition(inputDefinition)
.description(if (inputDefinition.description != null) inputDefinition.description.content else getDocumentation(inputDefinition))
.defaultValue(buildDefaultValue(inputDefinition.defaultValue))
.type(determineInputType(inputDefinition.type, inputObjects))
.type(determineInputType(inputDefinition.type, inputObjects, referencingInputObjects))
.withDirectives(*buildDirectives(inputDefinition.directives, setOf(), Introspection.DirectiveLocation.INPUT_FIELD_DEFINITION))
builder.field(fieldBuilder.build())
}
Expand Down Expand Up @@ -297,7 +300,7 @@ class SchemaParser internal constructor(
.definition(argumentDefinition)
.description(if (argumentDefinition.description != null) argumentDefinition.description.content else getDocumentation(argumentDefinition))
.defaultValue(buildDefaultValue(argumentDefinition.defaultValue))
.type(determineInputType(argumentDefinition.type, inputObjects))
.type(determineInputType(argumentDefinition.type, inputObjects, setOf()))
.withDirectives(*buildDirectives(argumentDefinition.directives, setOf(), Introspection.DirectiveLocation.ARGUMENT_DEFINITION))
field.argument(argumentBuilder.build())
}
Expand Down Expand Up @@ -328,7 +331,7 @@ class SchemaParser internal constructor(
is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
is InputObjectTypeDefinition -> {
log.info("Create input object")
createInputObject(typeDefinition, inputObjects)
createInputObject(typeDefinition, inputObjects, mutableSetOf())
}
is TypeName -> {
val scalarType = customScalars[typeDefinition.name]
Expand All @@ -346,16 +349,19 @@ class SchemaParser internal constructor(
else -> throw SchemaError("Unknown type: $typeDefinition")
}

private fun determineInputType(typeDefinition: Type<*>, inputObjects: List<GraphQLInputObjectType>) =
determineInputType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject, inputObjects) as GraphQLInputType
private fun determineInputType(typeDefinition: Type<*>, inputObjects: List<GraphQLInputObjectType>, referencingInputObjects: Set<String>) =
determineInputType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject, inputObjects, referencingInputObjects) as GraphQLInputType

private fun <T : Any> determineInputType(expectedType: KClass<T>, typeDefinition: Type<*>, allowedTypeReferences: Set<String>, inputObjects: List<GraphQLInputObjectType>): GraphQLType =
private fun <T : Any> determineInputType(expectedType: KClass<T>,
typeDefinition: Type<*>, allowedTypeReferences: Set<String>,
inputObjects: List<GraphQLInputObjectType>,
referencingInputObjects: Set<String>): GraphQLType =
when (typeDefinition) {
is ListType -> GraphQLList(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
is InputObjectTypeDefinition -> {
log.info("Create input object")
createInputObject(typeDefinition, inputObjects)
createInputObject(typeDefinition, inputObjects, referencingInputObjects as MutableSet<String>)
}
is TypeName -> {
val scalarType = customScalars[typeDefinition.name]
Expand All @@ -373,9 +379,14 @@ class SchemaParser internal constructor(
} else {
val filteredDefinitions = inputObjectDefinitions.filter { it.name == typeDefinition.name }
if (filteredDefinitions.isNotEmpty()) {
val inputObject = createInputObject(filteredDefinitions[0], inputObjects)
(inputObjects as MutableList).add(inputObject)
inputObject
val referencingInputObject = referencingInputObjects.find { it == typeDefinition.name }
if (referencingInputObject != null) {
GraphQLTypeReference(referencingInputObject)
} else {
val inputObject = createInputObject(filteredDefinitions[0], inputObjects, referencingInputObjects as MutableSet<String>)
(inputObjects as MutableList).add(inputObject)
inputObject
}
} else {
// todo: handle enum type
GraphQLTypeReference(typeDefinition.name)
Expand Down
44 changes: 44 additions & 0 deletions src/test/groovy/graphql/kickstart/tools/SchemaParserSpec.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,50 @@ class SchemaParserSpec extends Specification {
noExceptionThrown()
}

def "allow circular relations in input objects"() {
when:
SchemaParser.newParser().schemaString('''\
input A {
id: ID!
b: B
}
input B {
id: ID!
a: A
}
input C {
id: ID!
c: C
}
type Query {}
type Mutation {
test(input: A!): Boolean
testC(input: C!): Boolean
}
'''.stripIndent())
.resolvers(new GraphQLMutationResolver() {
static class A {
String id;
B b;
}
static class B {
String id;
A a;
}
static class C {
String id;
C c;
}
boolean test(A a) { return true }
boolean testC(C c) { return true }
}, new GraphQLQueryResolver() {})
.build()
.makeExecutableSchema()

then:
noExceptionThrown()
}

enum EnumType {
TEST
}
Expand Down

0 comments on commit a6393e0

Please sign in to comment.