diff --git a/pom.xml b/pom.xml index 2f9e9f1f..f6e19888 100644 --- a/pom.xml +++ b/pom.xml @@ -17,7 +17,7 @@ 1.6.21 1.6.1-native-mt 2.13.2.20220328 - 17.3 + 18.0 1.0.3 ${java.version} diff --git a/src/main/kotlin/graphql/kickstart/tools/ScannedSchemaObjects.kt b/src/main/kotlin/graphql/kickstart/tools/ScannedSchemaObjects.kt index ed7e138a..6874504f 100644 --- a/src/main/kotlin/graphql/kickstart/tools/ScannedSchemaObjects.kt +++ b/src/main/kotlin/graphql/kickstart/tools/ScannedSchemaObjects.kt @@ -5,6 +5,7 @@ import graphql.kickstart.tools.util.BiMap import graphql.kickstart.tools.util.JavaType import graphql.language.FieldDefinition import graphql.language.ObjectTypeDefinition +import graphql.language.SDLNamedDefinition import graphql.language.TypeDefinition import graphql.schema.GraphQLScalarType @@ -13,7 +14,7 @@ import graphql.schema.GraphQLScalarType */ internal data class ScannedSchemaObjects( val dictionary: TypeClassDictionary, - val definitions: Set>, + val definitions: Set>, val customScalars: CustomScalarMap, val rootInfo: RootTypeInfo, val fieldResolversByType: Map>, diff --git a/src/main/kotlin/graphql/kickstart/tools/SchemaClassScanner.kt b/src/main/kotlin/graphql/kickstart/tools/SchemaClassScanner.kt index 2472bef8..ffaeb5e4 100644 --- a/src/main/kotlin/graphql/kickstart/tools/SchemaClassScanner.kt +++ b/src/main/kotlin/graphql/kickstart/tools/SchemaClassScanner.kt @@ -33,6 +33,8 @@ internal class SchemaClassScanner( private val initialDictionary = initialDictionary.mapValues { InitialDictionaryEntry(it.value) } private val extensionDefinitions = allDefinitions.filterIsInstance() private val inputExtensionDefinitions = allDefinitions.filterIsInstance() + private val directiveDefinitions = allDefinitions.filterIsInstance() + private val scalarDefinitions = allDefinitions.filterIsInstance() private val definitionsByName = (allDefinitions.filterIsInstance>() - extensionDefinitions - inputExtensionDefinitions).associateBy { it.name } private val objectDefinitions = (allDefinitions.filterIsInstance() - extensionDefinitions) @@ -42,7 +44,7 @@ internal class SchemaClassScanner( private val fieldResolverScanner = FieldResolverScanner(options) private val typeClassMatcher = TypeClassMatcher(definitionsByName) private val dictionary = mutableMapOf, DictionaryEntry>() - private val unvalidatedTypes = mutableSetOf>() + private val unvalidatedTypes = mutableSetOf>(*scalarDefinitions.toTypedArray()) private val queue = linkedSetOf() private val fieldResolversByType = mutableMapOf>() @@ -193,7 +195,9 @@ internal class SchemaClassScanner( validateRootResolversWereUsed(rootTypeHolder.mutation, fieldResolvers) validateRootResolversWereUsed(rootTypeHolder.subscription, fieldResolvers) - return ScannedSchemaObjects(dictionary, observedDefinitions + extensionDefinitions + inputExtensionDefinitions, scalars, rootInfo, fieldResolversByType.toMap(), unusedDefinitions) + val definitions = observedDefinitions + extensionDefinitions + inputExtensionDefinitions + directiveDefinitions + + return ScannedSchemaObjects(dictionary, definitions, scalars, rootInfo, fieldResolversByType.toMap(), unusedDefinitions) } private fun validateRootResolversWereUsed(rootType: RootType?, fieldResolvers: List) { diff --git a/src/main/kotlin/graphql/kickstart/tools/SchemaObjects.kt b/src/main/kotlin/graphql/kickstart/tools/SchemaObjects.kt index 9ccd284f..efc53693 100644 --- a/src/main/kotlin/graphql/kickstart/tools/SchemaObjects.kt +++ b/src/main/kotlin/graphql/kickstart/tools/SchemaObjects.kt @@ -1,9 +1,6 @@ package graphql.kickstart.tools -import graphql.schema.GraphQLCodeRegistry -import graphql.schema.GraphQLObjectType -import graphql.schema.GraphQLSchema -import graphql.schema.GraphQLType +import graphql.schema.* /** * @author Andrew Potter @@ -13,6 +10,7 @@ data class SchemaObjects( val mutation: GraphQLObjectType?, val subscription: GraphQLObjectType?, val dictionary: Set, + val directives: Set, val codeRegistryBuilder: GraphQLCodeRegistry.Builder, val description: String? ) { @@ -26,6 +24,7 @@ data class SchemaObjects( .mutation(mutation) .subscription(subscription) .additionalTypes(dictionary) + .additionalDirectives(directives) .codeRegistry(codeRegistryBuilder.build()) .build() } diff --git a/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt b/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt index 6dbf549e..4636e965 100644 --- a/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt +++ b/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt @@ -45,6 +45,7 @@ class SchemaParser internal constructor( private val inputObjectDefinitions = (definitions.filterIsInstance() - inputExtensionDefinitions) private val enumDefinitions = definitions.filterIsInstance() private val interfaceDefinitions = definitions.filterIsInstance() + private val directiveDefinitions = definitions.filterIsInstance() private val unionDefinitions = definitions.filterIsInstance() @@ -82,6 +83,8 @@ class SchemaParser internal constructor( val unions = unionDefinitions.map { createUnionObject(it, objects) } val enums = enumDefinitions.map { createEnumObject(it) } + val directives = directiveDefinitions.map { createDirective(it, inputObjects) }.toSet() + // Assign type resolver to interfaces now that we know all of the object types interfaces.forEach { codeRegistryBuilder.typeResolver(it, InterfaceTypeResolver(dictionary.inverse(), it)) } unions.forEach { codeRegistryBuilder.typeResolver(it, UnionTypeResolver(dictionary.inverse(), it)) } @@ -101,7 +104,7 @@ class SchemaParser internal constructor( val additionalObjects = objects.filter { o -> o != query && o != subscription && o != mutation } val types = (additionalObjects.toSet() as Set) + inputObjects + enums + interfaces + unions - return SchemaObjects(query, mutation, subscription, types, codeRegistryBuilder, rootInfo.getDescription()) + return SchemaObjects(query, mutation, subscription, types, directives, codeRegistryBuilder, rootInfo.getDescription()) } /** @@ -123,6 +126,7 @@ class SchemaParser internal constructor( .description(getDocumentation(objectDefinition, options)) builder.withDirectives(*buildDirectives(objectDefinition.directives, Introspection.DirectiveLocation.OBJECT)) + builder.withAppliedDirectives(*buildAppliedDirectives(objectDefinition.directives)) objectDefinition.implements.forEach { implementsDefinition -> val interfaceName = (implementsDefinition as TypeName).name @@ -163,6 +167,7 @@ class SchemaParser internal constructor( .description(getDocumentation(definition, options)) builder.withDirectives(*buildDirectives(definition.directives, Introspection.DirectiveLocation.INPUT_OBJECT)) + builder.withAppliedDirectives(*buildAppliedDirectives(definition.directives)) referencingInputObjects.add(definition.name) @@ -175,6 +180,7 @@ class SchemaParser internal constructor( .apply { inputDefinition.defaultValue?.let { v -> defaultValueLiteral(v) } } .type(determineInputType(inputDefinition.type, inputObjects, referencingInputObjects)) .withDirectives(*buildDirectives(inputDefinition.directives, Introspection.DirectiveLocation.INPUT_FIELD_DEFINITION)) + .withAppliedDirectives(*buildAppliedDirectives(inputDefinition.directives)) builder.field(fieldBuilder.build()) } } @@ -194,6 +200,7 @@ class SchemaParser internal constructor( .description(getDocumentation(definition, options)) builder.withDirectives(*buildDirectives(definition.directives, Introspection.DirectiveLocation.ENUM)) + builder.withAppliedDirectives(*buildAppliedDirectives(definition.directives)) definition.enumValueDefinitions.forEach { enumDefinition -> val enumName = enumDefinition.name @@ -201,6 +208,7 @@ class SchemaParser internal constructor( ?: throw SchemaError("Expected value for name '$enumName' in enum '${type.unwrap().simpleName}' but found none!") val enumValueDirectives = buildDirectives(enumDefinition.directives, Introspection.DirectiveLocation.ENUM_VALUE) + val enumValueAppliedDirectives = buildAppliedDirectives(enumDefinition.directives) getDeprecated(enumDefinition.directives).let { val enumValueDefinition = GraphQLEnumValueDefinition.newEnumValueDefinition() .name(enumName) @@ -208,6 +216,7 @@ class SchemaParser internal constructor( .value(enumValue) .deprecationReason(it) .withDirectives(*enumValueDirectives) + .withAppliedDirectives(*enumValueAppliedDirectives) .definition(enumDefinition) .build() @@ -226,6 +235,7 @@ class SchemaParser internal constructor( .description(getDocumentation(interfaceDefinition, options)) builder.withDirectives(*buildDirectives(interfaceDefinition.directives, Introspection.DirectiveLocation.INTERFACE)) + builder.withAppliedDirectives(*buildAppliedDirectives(interfaceDefinition.directives)) interfaceDefinition.fieldDefinitions.forEach { fieldDefinition -> builder.field { field -> createField(field, fieldDefinition, inputObjects) } @@ -247,6 +257,7 @@ class SchemaParser internal constructor( .description(getDocumentation(definition, options)) builder.withDirectives(*buildDirectives(definition.directives, Introspection.DirectiveLocation.UNION)) + builder.withAppliedDirectives(*buildAppliedDirectives(definition.directives)) getLeafUnionObjects(definition, types).forEach { builder.possibleType(it) } return schemaGeneratorDirectiveHelper.onUnion(builder.build(), schemaDirectiveParameters) @@ -288,14 +299,44 @@ class SchemaParser internal constructor( .type(determineInputType(argumentDefinition.type, inputObjects, setOf())) .apply { argumentDefinition.defaultValue?.let { defaultValueLiteral(it) } } .withDirectives(*buildDirectives(argumentDefinition.directives, Introspection.DirectiveLocation.ARGUMENT_DEFINITION)) + .withAppliedDirectives(*buildAppliedDirectives(argumentDefinition.directives)) field.argument(argumentBuilder.build()) } field.withDirectives(*buildDirectives(fieldDefinition.directives, Introspection.DirectiveLocation.FIELD_DEFINITION)) + field.withAppliedDirectives(*buildAppliedDirectives(fieldDefinition.directives)) return field } + private fun createDirective(definition: DirectiveDefinition, inputObjects: List): GraphQLDirective { + val locations = definition.directiveLocations.map { Introspection.DirectiveLocation.valueOf(it.name) }.toTypedArray() + + val graphQLDirective = GraphQLDirective.newDirective() + .name(definition.name) + .description(getDocumentation(definition, options)) + .definition(definition) + .comparatorRegistry(runtimeWiring.comparatorRegistry) + .validLocations(*locations) + .repeatable(definition.isRepeatable) + .apply { + definition.inputValueDefinitions.forEach { arg -> + argument(GraphQLArgument.newArgument() + .name(arg.name) + .definition(arg) + .description(getDocumentation(arg, options)) + .type(determineInputType(arg.type, inputObjects, setOf())) + .apply { arg.defaultValue?.let { defaultValueLiteral(it) } } + .withDirectives(*buildDirectives(arg.directives, Introspection.DirectiveLocation.ARGUMENT_DEFINITION)) + .withAppliedDirectives(*buildAppliedDirectives(arg.directives)) + .build()) + } + } + .build() + + return graphQLDirective + } + private fun buildDirectives(directives: List, directiveLocation: Introspection.DirectiveLocation): Array { val names = mutableSetOf() @@ -326,14 +367,43 @@ class SchemaParser internal constructor( return output.toTypedArray() } + private fun buildAppliedDirectives(directives: List): Array { + val names = mutableSetOf() + + val output = mutableListOf() + for (directive in directives) { + if (!names.contains(directive.name)) { + names.add(directive.name) + val graphQLDirective = GraphQLAppliedDirective.newDirective() + .name(directive.name) + .description(getDocumentation(directive, options)) + .comparatorRegistry(runtimeWiring.comparatorRegistry) + .apply { + directive.arguments.forEach { arg -> + argument(GraphQLAppliedDirectiveArgument.newArgument() + .name(arg.name) + .type(buildDirectiveInputType(arg.value)) + .valueLiteral(arg.value) + .build()) + } + } + .build() + + output.add(graphQLDirective) + } + } + + return output.toTypedArray() + } + private fun buildDirectiveInputType(value: Value<*>): GraphQLInputType? { - when (value) { - is NullValue -> return Scalars.GraphQLString - is FloatValue -> return Scalars.GraphQLFloat - is StringValue -> return Scalars.GraphQLString - is IntValue -> return Scalars.GraphQLInt - is BooleanValue -> return Scalars.GraphQLBoolean - is ArrayValue -> return GraphQLList.list(buildDirectiveInputType(getArrayValueWrappedType(value))) + return when (value) { + is NullValue -> Scalars.GraphQLString + is FloatValue -> Scalars.GraphQLFloat + is StringValue -> Scalars.GraphQLString + is IntValue -> Scalars.GraphQLInt + is BooleanValue -> Scalars.GraphQLBoolean + is ArrayValue -> GraphQLList.list(buildDirectiveInputType(getArrayValueWrappedType(value))) else -> throw SchemaError("Directive values of type '${value::class.simpleName}' are not supported yet.") } } @@ -448,14 +518,10 @@ class SchemaParser internal constructor( * indicating no deprecation directive was found within the directives list. */ private fun getDeprecated(directives: List): String? = - getDirective(directives, "deprecated")?.let { directive -> + directives.find { it.name == "deprecated" }?.let { directive -> (directive.arguments.find { it.name == "reason" }?.value as? StringValue)?.value ?: DEFAULT_DEPRECATION_MESSAGE } - - private fun getDirective(directives: List, name: String): Directive? = directives.find { - it.name == name - } } class SchemaError(message: String, cause: Throwable? = null) : RuntimeException(message, cause) diff --git a/src/main/kotlin/graphql/kickstart/tools/directive/SchemaDirectiveWiringEnvironmentImpl.java b/src/main/kotlin/graphql/kickstart/tools/directive/SchemaDirectiveWiringEnvironmentImpl.java index 291f201b..6544a2b4 100644 --- a/src/main/kotlin/graphql/kickstart/tools/directive/SchemaDirectiveWiringEnvironmentImpl.java +++ b/src/main/kotlin/graphql/kickstart/tools/directive/SchemaDirectiveWiringEnvironmentImpl.java @@ -3,14 +3,7 @@ import graphql.Internal; import graphql.language.NamedNode; import graphql.language.NodeParentTree; -import graphql.schema.DataFetcher; -import graphql.schema.FieldCoordinates; -import graphql.schema.GraphQLCodeRegistry; -import graphql.schema.GraphQLDirective; -import graphql.schema.GraphQLDirectiveContainer; -import graphql.schema.GraphQLFieldDefinition; -import graphql.schema.GraphQLFieldsContainer; -import graphql.schema.GraphqlElementParentTree; +import graphql.schema.*; import graphql.schema.idl.SchemaDirectiveWiringEnvironment; import graphql.schema.idl.TypeDefinitionRegistry; import graphql.util.FpKit; @@ -31,6 +24,7 @@ public class SchemaDirectiveWiringEnvironmentImpl directives; + private final Map appliedDirectives; private final NodeParentTree> nodeParentTree; private final TypeDefinitionRegistry typeDefinitionRegistry; private final Map context; @@ -40,11 +34,18 @@ public class SchemaDirectiveWiringEnvironmentImpl directives, GraphQLDirective registeredDirective, SchemaGeneratorDirectiveHelper.Parameters parameters) { + public SchemaDirectiveWiringEnvironmentImpl( + T element, + List directives, + List appliedDirectives, + GraphQLDirective registeredDirective, + SchemaGeneratorDirectiveHelper.Parameters parameters + ) { this.element = element; this.registeredDirective = registeredDirective; this.typeDefinitionRegistry = parameters.getTypeRegistry(); this.directives = FpKit.getByName(directives, GraphQLDirective::getName); + this.appliedDirectives = FpKit.getByName(appliedDirectives, GraphQLAppliedDirective::getName); this.context = parameters.getContext(); this.codeRegistry = parameters.getCodeRegistry(); this.nodeParentTree = parameters.getNodeParentTree(); @@ -73,6 +74,16 @@ public GraphQLDirective getDirective(String directiveName) { return directives.get(directiveName); } + @Override + public Map getAppliedDirectives() { + return appliedDirectives; + } + + @Override + public GraphQLAppliedDirective getAppliedDirective(String directiveName) { + return appliedDirectives.get(directiveName); + } + @Override public boolean containsDirective(String directiveName) { return directives.containsKey(directiveName); diff --git a/src/main/kotlin/graphql/kickstart/tools/directive/SchemaGeneratorDirectiveHelper.java b/src/main/kotlin/graphql/kickstart/tools/directive/SchemaGeneratorDirectiveHelper.java index 03482fe0..d4f48692 100644 --- a/src/main/kotlin/graphql/kickstart/tools/directive/SchemaGeneratorDirectiveHelper.java +++ b/src/main/kotlin/graphql/kickstart/tools/directive/SchemaGeneratorDirectiveHelper.java @@ -3,29 +3,10 @@ import graphql.Internal; import graphql.language.NamedNode; import graphql.language.NodeParentTree; -import graphql.schema.GraphQLArgument; -import graphql.schema.GraphQLCodeRegistry; -import graphql.schema.GraphQLDirective; -import graphql.schema.GraphQLDirectiveContainer; -import graphql.schema.GraphQLEnumType; -import graphql.schema.GraphQLEnumValueDefinition; -import graphql.schema.GraphQLFieldDefinition; -import graphql.schema.GraphQLFieldsContainer; -import graphql.schema.GraphQLInputObjectField; -import graphql.schema.GraphQLInputObjectType; -import graphql.schema.GraphQLInterfaceType; -import graphql.schema.GraphQLObjectType; -import graphql.schema.GraphQLScalarType; -import graphql.schema.GraphQLSchemaElement; -import graphql.schema.GraphQLUnionType; -import graphql.schema.GraphqlElementParentTree; +import graphql.schema.*; import graphql.schema.idl.*; -import java.util.ArrayDeque; -import java.util.Deque; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import static graphql.Assert.assertNotNull; import static graphql.collect.ImmutableKit.map; @@ -67,7 +48,11 @@ public static boolean schemaDirectiveWirin } Parameters params = new Parameters(typeRegistry, runtimeWiring, new HashMap<>(), null); - SchemaDirectiveWiringEnvironment env = new SchemaDirectiveWiringEnvironmentImpl<>(directiveContainer, directiveContainer.getDirectives(), null, params); + SchemaDirectiveWiringEnvironment env = new SchemaDirectiveWiringEnvironmentImpl<>(directiveContainer, + directiveContainer.getDirectives(), + directiveContainer.getAppliedDirectives(), + null, + params); // do they dynamically provide a wiring for this element? return wiringFactory.providesSchemaDirectiveWiring(env); } @@ -204,8 +189,16 @@ public GraphQLObjectType onObject(GraphQLObjectType objectType, Parameters param GraphqlElementParentTree elementParentTree = buildRuntimeTree(newObjectType); Parameters newParams = params.newParams(newObjectType, nodeParentTree, elementParentTree); - return wireDirectives(params, newObjectType, newObjectType.getDirectives(), - (outputElement, directives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, directives, registeredDirective, newParams), SchemaDirectiveWiring::onObject); + return wireDirectives(params, + newObjectType, + newObjectType.getDirectives(), + newObjectType.getAppliedDirectives(), + (outputElement, directives, appliedDirectives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, + directives, + appliedDirectives, + registeredDirective, + newParams), + SchemaDirectiveWiring::onObject); } public GraphQLInterfaceType onInterface(GraphQLInterfaceType interfaceType, Parameters params) { @@ -221,8 +214,16 @@ public GraphQLInterfaceType onInterface(GraphQLInterfaceType interfaceType, Para GraphqlElementParentTree elementParentTree = buildRuntimeTree(newInterfaceType); Parameters newParams = params.newParams(newInterfaceType, nodeParentTree, elementParentTree); - return wireDirectives(params, newInterfaceType, newInterfaceType.getDirectives(), - (outputElement, directives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, directives, registeredDirective, newParams), SchemaDirectiveWiring::onInterface); + return wireDirectives(params, + newInterfaceType, + newInterfaceType.getDirectives(), + newInterfaceType.getAppliedDirectives(), + (outputElement, directives, appliedDirectives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, + directives, + appliedDirectives, + registeredDirective, + newParams), + SchemaDirectiveWiring::onInterface); } public GraphQLEnumType onEnum(final GraphQLEnumType enumType, Parameters params) { @@ -247,8 +248,16 @@ public GraphQLEnumType onEnum(final GraphQLEnumType enumType, Parameters params) GraphqlElementParentTree elementParentTree = buildRuntimeTree(newEnumType); Parameters newParams = params.newParams(nodeParentTree, elementParentTree); - return wireDirectives(params, newEnumType, newEnumType.getDirectives(), - (outputElement, directives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, directives, registeredDirective, newParams), SchemaDirectiveWiring::onEnum); + return wireDirectives(params, + newEnumType, + newEnumType.getDirectives(), + newEnumType.getAppliedDirectives(), + (outputElement, directives, appliedDirectives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, + directives, + appliedDirectives, + registeredDirective, + newParams), + SchemaDirectiveWiring::onEnum); } public GraphQLInputObjectType onInputObjectType(GraphQLInputObjectType inputObjectType, Parameters params) { @@ -271,8 +280,16 @@ public GraphQLInputObjectType onInputObjectType(GraphQLInputObjectType inputObje GraphqlElementParentTree elementParentTree = buildRuntimeTree(newInputObjectType); Parameters newParams = params.newParams(nodeParentTree, elementParentTree); - return wireDirectives(params, newInputObjectType, newInputObjectType.getDirectives(), - (outputElement, directives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, directives, registeredDirective, newParams), SchemaDirectiveWiring::onInputObjectType); + return wireDirectives(params, + newInputObjectType, + newInputObjectType.getDirectives(), + newInputObjectType.getAppliedDirectives(), + (outputElement, directives, appliedDirectives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, + directives, + appliedDirectives, + registeredDirective, + newParams), + SchemaDirectiveWiring::onInputObjectType); } @@ -281,8 +298,16 @@ public GraphQLUnionType onUnion(GraphQLUnionType element, Parameters params) { GraphqlElementParentTree elementParentTree = buildRuntimeTree(element); Parameters newParams = params.newParams(nodeParentTree, elementParentTree); - return wireDirectives(params, element, element.getDirectives(), - (outputElement, directives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, directives, registeredDirective, newParams), SchemaDirectiveWiring::onUnion); + return wireDirectives(params, + element, + element.getDirectives(), + element.getAppliedDirectives(), + (outputElement, directives, appliedDirectives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, + directives, + appliedDirectives, + registeredDirective, + newParams), + SchemaDirectiveWiring::onUnion); } public GraphQLScalarType onScalar(GraphQLScalarType element, Parameters params) { @@ -290,36 +315,81 @@ public GraphQLScalarType onScalar(GraphQLScalarType element, Parameters params) GraphqlElementParentTree elementParentTree = buildRuntimeTree(element); Parameters newParams = params.newParams(nodeParentTree, elementParentTree); - return wireDirectives(params, element, element.getDirectives(), - (outputElement, directives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, directives, registeredDirective, newParams), SchemaDirectiveWiring::onScalar); + return wireDirectives(params, + element, + element.getDirectives(), + element.getAppliedDirectives(), + (outputElement, directives, appliedDirectives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, + directives, + appliedDirectives, + registeredDirective, + newParams), + SchemaDirectiveWiring::onScalar); } private GraphQLFieldDefinition onField(GraphQLFieldDefinition fieldDefinition, Parameters params) { - return wireDirectives(params, fieldDefinition, fieldDefinition.getDirectives(), - (outputElement, directives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, directives, registeredDirective, params), SchemaDirectiveWiring::onField); + return wireDirectives(params, + fieldDefinition, + fieldDefinition.getDirectives(), + fieldDefinition.getAppliedDirectives(), + (outputElement, directives, appliedDirectives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, + directives, + appliedDirectives, + registeredDirective, + params), + SchemaDirectiveWiring::onField); } private GraphQLInputObjectField onInputObjectField(GraphQLInputObjectField element, Parameters params) { - return wireDirectives(params, element, element.getDirectives(), - (outputElement, directives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, directives, registeredDirective, params), SchemaDirectiveWiring::onInputObjectField); + return wireDirectives(params, + element, + element.getDirectives(), + element.getAppliedDirectives(), + (outputElement, directives, appliedDirectives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, + directives, + appliedDirectives, + registeredDirective, + params), + SchemaDirectiveWiring::onInputObjectField); } private GraphQLEnumValueDefinition onEnumValue(GraphQLEnumValueDefinition enumValueDefinition, Parameters params) { - return wireDirectives(params, enumValueDefinition, enumValueDefinition.getDirectives(), - (outputElement, directives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, directives, registeredDirective, params), SchemaDirectiveWiring::onEnumValue); + return wireDirectives(params, + enumValueDefinition, + enumValueDefinition.getDirectives(), + enumValueDefinition.getAppliedDirectives(), + (outputElement, directives, appliedDirectives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, + directives, + appliedDirectives, + registeredDirective, + params), + SchemaDirectiveWiring::onEnumValue); } private GraphQLArgument onArgument(GraphQLArgument argument, Parameters params) { - return wireDirectives(params, argument, argument.getDirectives(), - (outputElement, directives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, directives, registeredDirective, params), SchemaDirectiveWiring::onArgument); + return wireDirectives(params, + argument, + argument.getDirectives(), + argument.getAppliedDirectives(), + (outputElement, directives, appliedDirectives, registeredDirective) -> new SchemaDirectiveWiringEnvironmentImpl<>(outputElement, + directives, + appliedDirectives, + registeredDirective, + params), + SchemaDirectiveWiring::onArgument); } - // // builds a type safe SchemaDirectiveWiringEnvironment // interface EnvBuilder { - SchemaDirectiveWiringEnvironment apply(T outputElement, List allDirectives, GraphQLDirective registeredDirective); + + SchemaDirectiveWiringEnvironment apply( + T outputElement, + List allDirectives, + List allAppliedDirectives, + GraphQLDirective registeredDirective + ); } // @@ -330,10 +400,12 @@ interface EnvInvoker { } private T wireDirectives( - Parameters parameters, T element, - List allDirectives, - EnvBuilder envBuilder, - EnvInvoker invoker) { + Parameters parameters, T element, + List allDirectives, + List allAppliedDirectives, + EnvBuilder envBuilder, + EnvInvoker invoker + ) { RuntimeWiring runtimeWiring = parameters.getRuntimeWiring(); WiringFactory wiringFactory = runtimeWiring.getWiringFactory(); @@ -347,19 +419,19 @@ private T wireDirectives( for (GraphQLDirective directive : allDirectives) { schemaDirectiveWiring = mapOfWiring.get(directive.getName()); if (schemaDirectiveWiring != null) { - env = envBuilder.apply(outputObject, allDirectives, directive); + env = envBuilder.apply(outputObject, allDirectives, allAppliedDirectives, directive); outputObject = invokeWiring(outputObject, invoker, schemaDirectiveWiring, env); } } // - // now call any statically added to the the runtime + // now call any statically added to the runtime for (SchemaDirectiveWiring directiveWiring : runtimeWiring.getDirectiveWiring()) { - env = envBuilder.apply(outputObject, allDirectives, null); + env = envBuilder.apply(outputObject, allDirectives, allAppliedDirectives, null); outputObject = invokeWiring(outputObject, invoker, directiveWiring, env); } // // wiring factory is last (if present) - env = envBuilder.apply(outputObject, allDirectives, null); + env = envBuilder.apply(outputObject, allDirectives, allAppliedDirectives, null); if (wiringFactory.providesSchemaDirectiveWiring(env)) { schemaDirectiveWiring = assertNotNull(wiringFactory.getSchemaDirectiveWiring(env), () -> "Your WiringFactory MUST provide a non null SchemaDirectiveWiring"); outputObject = invokeWiring(outputObject, invoker, schemaDirectiveWiring, env); @@ -391,5 +463,4 @@ private boolean isNotTheSameObjects(List starting, List ending) { } return false; } - } diff --git a/src/test/kotlin/graphql/kickstart/tools/SchemaClassScannerTest.kt b/src/test/kotlin/graphql/kickstart/tools/SchemaClassScannerTest.kt index f5893618..576d1807 100644 --- a/src/test/kotlin/graphql/kickstart/tools/SchemaClassScannerTest.kt +++ b/src/test/kotlin/graphql/kickstart/tools/SchemaClassScannerTest.kt @@ -434,6 +434,11 @@ class SchemaClassScannerTest { # these directives are defined in the Apollo Federation Specification: # https://www.apollographql.com/docs/apollo-server/federation/federation-spec/ + scalar _FieldSet + directive @key(fields: _FieldSet!) repeatable on OBJECT | INTERFACE + directive @extends on OBJECT | INTERFACE + directive @external on FIELD_DEFINITION + type User @key(fields: "id") @extends { id: ID! @external recentPurchasedProducts: [Product] @@ -449,6 +454,7 @@ class SchemaClassScannerTest { }) .options(SchemaParserOptions.newOptions().includeUnusedTypes(true).build()) .dictionary(User::class) + .scalars(fieldSet) .build() .makeExecutableSchema() @@ -471,6 +477,16 @@ class SchemaClassScannerTest { var street: String? = null } + private val fieldSet: GraphQLScalarType = GraphQLScalarType.newScalar() + .name("_FieldSet") + .description("_FieldSet") + .coercing(object : Coercing { + override fun parseValue(input: Any) = input.toString() + override fun serialize(dataFetcherResult: Any) = dataFetcherResult as String + override fun parseLiteral(input: Any) = input.toString() + }) + .build() + @Test fun `scanner should handle unused types with interfaces when option is true`() { val schema = SchemaParser.newParser() diff --git a/src/test/kotlin/graphql/kickstart/tools/SchemaParserTest.kt b/src/test/kotlin/graphql/kickstart/tools/SchemaParserTest.kt index d382a3cb..4df062d7 100644 --- a/src/test/kotlin/graphql/kickstart/tools/SchemaParserTest.kt +++ b/src/test/kotlin/graphql/kickstart/tools/SchemaParserTest.kt @@ -258,7 +258,7 @@ class SchemaParserTest { val sourceLocation = schema.getObjectType("Query") .getFieldDefinition("id") - .definition.sourceLocation + .definition!!.sourceLocation assertNotNull(sourceLocation) assertEquals(sourceLocation.line, 2) assertEquals(sourceLocation.column, 5) @@ -275,7 +275,7 @@ class SchemaParserTest { val sourceLocation = schema.getObjectType("Query") .getFieldDefinition("id") - .definition.sourceLocation + .definition!!.sourceLocation assertNotNull(sourceLocation) assertEquals(sourceLocation.line, 2) assertEquals(sourceLocation.column, 3) @@ -441,8 +441,8 @@ class SchemaParserTest { assert(poodleTraitObject.interfaces.containsAll(listOf(mammalTraitInterface, traitInterface))) assert(dogInterface.interfaces.contains(animalInterface)) assert(mammalTraitInterface.interfaces.contains(traitInterface)) - assert(traitInterface.definition.implements.isEmpty()) - assert(animalInterface.definition.implements.isEmpty()) + assert(traitInterface.definition!!.implements.isEmpty()) + assert(animalInterface.definition!!.implements.isEmpty()) } class MultiLevelInterfaceResolver : GraphQLQueryResolver { diff --git a/src/test/resources/RelayConnection.graphqls b/src/test/resources/RelayConnection.graphqls index 04aee5ed..e08d5099 100644 --- a/src/test/resources/RelayConnection.graphqls +++ b/src/test/resources/RelayConnection.graphqls @@ -1,8 +1,10 @@ +directive @connection(for: String!) on FIELD + type Query { - users(first: Int, after: String): UserRelayConnection @connection(for: "User") + users(first: Int, after: String): UserRelayConnection @connection(for: "User") } type User { - id: ID! - name: String + id: ID! + name: String }