diff --git a/src/main/kotlin/graphql/kickstart/tools/directive/DirectiveWiringHelper.kt b/src/main/kotlin/graphql/kickstart/tools/directive/DirectiveWiringHelper.kt index 3de8f7ce..0d0a4a0b 100644 --- a/src/main/kotlin/graphql/kickstart/tools/directive/DirectiveWiringHelper.kt +++ b/src/main/kotlin/graphql/kickstart/tools/directive/DirectiveWiringHelper.kt @@ -77,20 +77,21 @@ class DirectiveWiringHelper( private fun wireDirectives(wrapper: WiringWrapper): T { val directivesContainer = wrapper.graphQlType.definition as DirectivesContainer<*> val directives = buildDirectives(directivesContainer.directives, wrapper.directiveLocation) + val directivesByName = directives.associateBy { it.name } var output = wrapper.graphQlType // first the specific named directives - directives.forEach { directive -> - val env = buildEnvironment(wrapper, directives, directive) - val wiring = runtimeWiring.registeredDirectiveWiring[directive.name] + wrapper.graphQlType.appliedDirectives.forEach { appliedDirective -> + val env = buildEnvironment(wrapper, directives, directivesByName[appliedDirective.name], appliedDirective) + val wiring = runtimeWiring.registeredDirectiveWiring[appliedDirective.name] wiring?.let { output = wrapper.invoker(it, env) } } // now call any statically added to the runtime runtimeWiring.directiveWiring.forEach { staticWiring -> - val env = buildEnvironment(wrapper, directives, null) + val env = buildEnvironment(wrapper, directives, null, null) output = wrapper.invoker(staticWiring, env) } // wiring factory is last (if present) - val env = buildEnvironment(wrapper, directives, null) + val env = buildEnvironment(wrapper, directives, null, null) if (runtimeWiring.wiringFactory.providesSchemaDirectiveWiring(env)) { val factoryWiring = runtimeWiring.wiringFactory.getSchemaDirectiveWiring(env) output = wrapper.invoker(factoryWiring, env) @@ -131,7 +132,7 @@ class DirectiveWiringHelper( return output } - private fun buildEnvironment(wrapper: WiringWrapper, directives: List, directive: GraphQLDirective?): SchemaDirectiveWiringEnvironmentImpl { + private fun buildEnvironment(wrapper: WiringWrapper, directives: List, directive: GraphQLDirective?, appliedDirective: GraphQLAppliedDirective?): SchemaDirectiveWiringEnvironmentImpl { val nodeParentTree = buildAstTree(*listOfNotNull( wrapper.fieldsContainer?.definition, wrapper.inputFieldsContainer?.definition, @@ -154,7 +155,7 @@ class DirectiveWiringHelper( is GraphQLFieldsContainer -> schemaDirectiveParameters.newParams(wrapper.graphQlType, nodeParentTree, elementParentTree) else -> schemaDirectiveParameters.newParams(nodeParentTree, elementParentTree) } - return SchemaDirectiveWiringEnvironmentImpl(wrapper.graphQlType, directives, wrapper.graphQlType.appliedDirectives, directive, params) + return SchemaDirectiveWiringEnvironmentImpl(wrapper.graphQlType, directives, wrapper.graphQlType.appliedDirectives, directive, appliedDirective, params) } fun buildDirectiveInputType(value: Value<*>): GraphQLInputType? { diff --git a/src/main/kotlin/graphql/kickstart/tools/directive/SchemaDirectiveWiringEnvironmentImpl.kt b/src/main/kotlin/graphql/kickstart/tools/directive/SchemaDirectiveWiringEnvironmentImpl.kt index 2d9c4675..143fc59d 100644 --- a/src/main/kotlin/graphql/kickstart/tools/directive/SchemaDirectiveWiringEnvironmentImpl.kt +++ b/src/main/kotlin/graphql/kickstart/tools/directive/SchemaDirectiveWiringEnvironmentImpl.kt @@ -6,13 +6,13 @@ import graphql.schema.* import graphql.schema.idl.RuntimeWiring import graphql.schema.idl.SchemaDirectiveWiringEnvironment import graphql.schema.idl.TypeDefinitionRegistry -import graphql.util.FpKit class SchemaDirectiveWiringEnvironmentImpl( private val element: T, directives: List, appliedDirectives: List, private val registeredDirective: GraphQLDirective?, + private val registeredAppliedDirective: GraphQLAppliedDirective?, parameters: Parameters ) : SchemaDirectiveWiringEnvironment { private val directives: Map @@ -27,8 +27,8 @@ class SchemaDirectiveWiringEnvironmentImpl( init { typeDefinitionRegistry = parameters.typeRegistry - this.directives = FpKit.getByName(directives) { obj: GraphQLDirective -> obj.name } - this.appliedDirectives = FpKit.getByName(appliedDirectives) { obj: GraphQLAppliedDirective -> obj.name } + this.directives = directives.associateBy { it.name } + this.appliedDirectives = appliedDirectives.associateBy { it.name } context = parameters.context codeRegistry = parameters.codeRegistry nodeParentTree = parameters.nodeParentTree @@ -39,7 +39,7 @@ class SchemaDirectiveWiringEnvironmentImpl( override fun getElement(): T = element override fun getDirective(): GraphQLDirective? = registeredDirective - override fun getAppliedDirective(): GraphQLAppliedDirective? = appliedDirectives[registeredDirective?.name] + override fun getAppliedDirective(): GraphQLAppliedDirective? = registeredAppliedDirective override fun getDirectives(): Map = LinkedHashMap(directives) override fun getDirective(directiveName: String): GraphQLDirective = directives[directiveName]!! override fun getAppliedDirectives(): Map = appliedDirectives