From 85cb6cc5fb337ec14505b0903834e73da6b45579 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Deleuze?= Date: Thu, 21 Dec 2023 12:15:26 +0100 Subject: [PATCH] Support Kotlin extensions in web handlers This commit restores support for Kotlin extensions in web handlers, and adds support for invoking reflectively suspending extension functions, as well as the other features supported as of Spring Framework 6.1 like value classes and default value for parameters. Closes gh-31876 --- .../springframework/core/CoroutinesUtils.java | 10 ++- .../core/CoroutinesUtilsTests.kt | 32 +++++++++ .../support/InvocableHandlerMethod.java | 10 ++- .../InvocableHandlerMethodKotlinTests.kt | 29 +++++++++ .../result/method/InvocableHandlerMethod.java | 9 ++- .../InvocableHandlerMethodKotlinTests.kt | 65 ++++++++++++++----- 6 files changed, 130 insertions(+), 25 deletions(-) diff --git a/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java b/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java index e6de675a306e..8c78ceac9bed 100644 --- a/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java +++ b/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java @@ -117,8 +117,11 @@ public static Publisher invokeSuspendingFunction(CoroutineContext context, Me int index = 0; for (KParameter parameter : function.getParameters()) { switch (parameter.getKind()) { - case INSTANCE -> argMap.put(parameter, target); - case VALUE -> { + case INSTANCE: + argMap.put(parameter, target); + break; + case VALUE: + case EXTENSION_RECEIVER: if (!parameter.isOptional() || args[index] != null) { if (parameter.getType().getClassifier() instanceof KClass kClass && kClass.isValue()) { Class javaClass = JvmClassMappingKt.getJavaClass(kClass); @@ -131,7 +134,8 @@ public static Publisher invokeSuspendingFunction(CoroutineContext context, Me } } index++; - } + break; + } } return KCallables.callSuspendBy(function, argMap, continuation); diff --git a/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt b/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt index b5fad73d9ccd..fdce5caf8d85 100644 --- a/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt +++ b/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt @@ -154,6 +154,26 @@ class CoroutinesUtilsTests { } } + @Test + fun invokeSuspendingFunctionWithExtension() { + val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("suspendingFunctionWithExtension", + CustomException::class.java, Continuation::class.java) + val mono = CoroutinesUtils.invokeSuspendingFunction(method, this, CustomException("foo")) as Mono + runBlocking { + Assertions.assertThat(mono.awaitSingleOrNull()).isEqualTo("foo") + } + } + + @Test + fun invokeSuspendingFunctionWithExtensionAndParameter() { + val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("suspendingFunctionWithExtensionAndParameter", + CustomException::class.java, Int::class.java, Continuation::class.java) + val mono = CoroutinesUtils.invokeSuspendingFunction(method, this, CustomException("foo"), 20) as Mono + runBlocking { + Assertions.assertThat(mono.awaitSingleOrNull()).isEqualTo("foo-20") + } + } + suspend fun suspendingFunction(value: String): String { delay(1) return value @@ -186,7 +206,19 @@ class CoroutinesUtilsTests { return value.value } + suspend fun CustomException.suspendingFunctionWithExtension(): String { + delay(1) + return "${this.message}" + } + + suspend fun CustomException.suspendingFunctionWithExtensionAndParameter(limit: Int): String { + delay(1) + return "${this.message}-$limit" + } + @JvmInline value class ValueClass(val value: String) + class CustomException(message: String) : Throwable(message) + } diff --git a/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java b/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java index ed160be6f55f..303373de3897 100644 --- a/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java +++ b/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java @@ -318,8 +318,11 @@ public static Object invokeFunction(Method method, Object target, Object[] args) int index = 0; for (KParameter parameter : function.getParameters()) { switch (parameter.getKind()) { - case INSTANCE -> argMap.put(parameter, target); - case VALUE -> { + case INSTANCE: + argMap.put(parameter, target); + break; + case VALUE: + case EXTENSION_RECEIVER: if (!parameter.isOptional() || args[index] != null) { if (parameter.getType().getClassifier() instanceof KClass kClass && kClass.isValue()) { Class javaClass = JvmClassMappingKt.getJavaClass(kClass); @@ -332,7 +335,8 @@ public static Object invokeFunction(Method method, Object target, Object[] args) } } index++; - } + break; + } } Object result = function.callBy(argMap); diff --git a/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt b/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt index 7ee5d15c88d9..2ba7f249d422 100644 --- a/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt +++ b/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt @@ -104,6 +104,22 @@ class InvocableHandlerMethodKotlinTests { Assertions.assertThat(value).isEqualTo("foo") } + @Test + fun extension() { + composite.addResolver(StubArgumentResolver(CustomException::class.java, CustomException("foo"))) + val value = getInvocable(ExtensionHandler::class.java, CustomException::class.java).invokeForRequest(request, null) + Assertions.assertThat(value).isEqualTo("foo") + } + + @Test + fun extensionWithParameter() { + composite.addResolver(StubArgumentResolver(CustomException::class.java, CustomException("foo"))) + composite.addResolver(StubArgumentResolver(Int::class.java, 20)) + val value = getInvocable(ExtensionHandler::class.java, CustomException::class.java, Int::class.java) + .invokeForRequest(request, null) + Assertions.assertThat(value).isEqualTo("foo-20") + } + private fun getInvocable(clazz: Class<*>, vararg argTypes: Class<*>): InvocableHandlerMethod { val method = ResolvableMethod.on(clazz).argTypes(*argTypes).resolveMethod() val handlerMethod = InvocableHandlerMethod(clazz.constructors.first().newInstance(), method) @@ -150,10 +166,23 @@ class InvocableHandlerMethodKotlinTests { get() = "foo" } + private class ExtensionHandler { + + fun CustomException.handle(): String { + return "${this.message}" + } + + fun CustomException.handleWithParameter(limit: Int): String { + return "${this.message}-$limit" + } + } + @JvmInline value class LongValueClass(val value: Long) @JvmInline value class DoubleValueClass(val value: Double) + class CustomException(message: String) : Throwable(message) + } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java index 490e897091d1..0d4c2f47dd7c 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java @@ -329,8 +329,11 @@ public static Object invokeFunction(Method method, Object target, Object[] args, int index = 0; for (KParameter parameter : function.getParameters()) { switch (parameter.getKind()) { - case INSTANCE -> argMap.put(parameter, target); - case VALUE -> { + case INSTANCE: + argMap.put(parameter, target); + break; + case VALUE: + case EXTENSION_RECEIVER: if (!parameter.isOptional() || args[index] != null) { if (parameter.getType().getClassifier() instanceof KClass kClass && kClass.isValue()) { Class javaClass = JvmClassMappingKt.getJavaClass(kClass); @@ -343,7 +346,7 @@ public static Object invokeFunction(Method method, Object target, Object[] args, } } index++; - } + break; } } Object result = function.callBy(argMap); diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt index 57ef009b935c..61daeb514006 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt @@ -21,6 +21,7 @@ import io.mockk.mockk import kotlinx.coroutines.delay import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test +import org.springframework.core.MethodParameter import org.springframework.core.ReactiveAdapterRegistry import org.springframework.http.HttpStatus import org.springframework.http.server.reactive.ServerHttpResponse @@ -34,6 +35,7 @@ import org.springframework.web.reactive.result.method.InvocableHandlerMethod import org.springframework.web.reactive.result.method.annotation.ContinuationHandlerMethodArgumentResolver import org.springframework.web.reactive.result.method.annotation.RequestParamMethodArgumentResolver import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.get +import org.springframework.web.testfixture.method.ResolvableMethod import org.springframework.web.testfixture.server.MockServerWebExchange import reactor.core.publisher.Mono import reactor.test.StepVerifier @@ -55,7 +57,7 @@ class InvocableHandlerMethodKotlinTests { @Test fun resolveNoArg() { - this.resolvers.add(stubResolver(Mono.empty())) + this.resolvers.add(stubResolver(null, String::class.java)) val method = CoroutinesController::singleArg.javaMethod!! val result = invoke(CoroutinesController(), method, null) assertHandlerResultValue(result, "success:null") @@ -116,7 +118,7 @@ class InvocableHandlerMethodKotlinTests { @Test fun defaultValue() { - this.resolvers.add(stubResolver(Mono.empty())) + this.resolvers.add(stubResolver(null, String::class.java)) val method = DefaultValueController::handle.javaMethod!! val result = invoke(DefaultValueController(), method) assertHandlerResultValue(result, "default") @@ -124,7 +126,7 @@ class InvocableHandlerMethodKotlinTests { @Test fun defaultValueOverridden() { - this.resolvers.add(stubResolver(Mono.empty())) + this.resolvers.add(stubResolver(null, String::class.java)) val method = DefaultValueController::handle.javaMethod!! exchange = MockServerWebExchange.from(get("http://localhost:8080/path").queryParam("value", "override")) val result = invoke(DefaultValueController(), method) @@ -133,7 +135,7 @@ class InvocableHandlerMethodKotlinTests { @Test fun defaultValues() { - this.resolvers.add(stubResolver(Mono.empty())) + this.resolvers.add(stubResolver(null, Int::class.java)) val method = DefaultValueController::handleMultiple.javaMethod!! val result = invoke(DefaultValueController(), method) assertHandlerResultValue(result, "10-20") @@ -141,7 +143,7 @@ class InvocableHandlerMethodKotlinTests { @Test fun defaultValuesOverridden() { - this.resolvers.add(stubResolver(Mono.empty())) + this.resolvers.add(stubResolver(null, Int::class.java)) val method = DefaultValueController::handleMultiple.javaMethod!! exchange = MockServerWebExchange.from(get("http://localhost:8080/path").queryParam("limit2", "40")) val result = invoke(DefaultValueController(), method) @@ -150,7 +152,7 @@ class InvocableHandlerMethodKotlinTests { @Test fun suspendingDefaultValue() { - this.resolvers.add(stubResolver(Mono.empty())) + this.resolvers.add(stubResolver(null, String::class.java)) val method = DefaultValueController::handleSuspending.javaMethod!! val result = invoke(DefaultValueController(), method) assertHandlerResultValue(result, "default") @@ -158,7 +160,7 @@ class InvocableHandlerMethodKotlinTests { @Test fun suspendingDefaultValueOverridden() { - this.resolvers.add(stubResolver(Mono.empty())) + this.resolvers.add(stubResolver(null, String::class.java)) val method = DefaultValueController::handleSuspending.javaMethod!! exchange = MockServerWebExchange.from(get("http://localhost:8080/path").queryParam("value", "override")) val result = invoke(DefaultValueController(), method) @@ -181,7 +183,7 @@ class InvocableHandlerMethodKotlinTests { @Test fun valueClass() { - this.resolvers.add(stubResolver(1L)) + this.resolvers.add(stubResolver(1L, Long::class.java)) val method = ValueClassController::valueClass.javaMethod!! val result = invoke(ValueClassController(), method,1L) assertHandlerResultValue(result, "1") @@ -189,7 +191,7 @@ class InvocableHandlerMethodKotlinTests { @Test fun valueClassDefaultValue() { - this.resolvers.add(stubResolver(Mono.empty())) + this.resolvers.add(stubResolver(null, Double::class.java)) val method = ValueClassController::valueClassWithDefault.javaMethod!! val result = invoke(ValueClassController(), method) assertHandlerResultValue(result, "3.1") @@ -197,12 +199,31 @@ class InvocableHandlerMethodKotlinTests { @Test fun propertyAccessor() { - this.resolvers.add(stubResolver(Mono.empty())) + this.resolvers.add(stubResolver(null, String::class.java)) val method = PropertyAccessorController::prop.getter.javaMethod!! val result = invoke(PropertyAccessorController(), method) assertHandlerResultValue(result, "foo") } + @Test + fun extension() { + this.resolvers.add(stubResolver(CustomException("foo"))) + val method = ResolvableMethod.on(ExtensionHandler::class.java).argTypes(CustomException::class.java).resolveMethod() + val result = invoke(ExtensionHandler(), method) + assertHandlerResultValue(result, "foo") + } + + @Test + fun extensionWithParameter() { + this.resolvers.add(stubResolver(CustomException("foo"))) + this.resolvers.add(stubResolver(20, Int::class.java)) + val method = ResolvableMethod.on(ExtensionHandler::class.java) + .argTypes(CustomException::class.java, Int::class.javaPrimitiveType) + .resolveMethod() + val result = invoke(ExtensionHandler(), method) + assertHandlerResultValue(result, "foo-20") + } + private fun invokeForResult(handler: Any, method: Method, vararg providedArgs: Any): HandlerResult? { return invoke(handler, method, *providedArgs).block(Duration.ofSeconds(5)) @@ -214,14 +235,13 @@ class InvocableHandlerMethodKotlinTests { return invocable.invoke(this.exchange, BindingContext(), *providedArgs) } - private fun stubResolver(stubValue: Any?): HandlerMethodArgumentResolver { - return stubResolver(Mono.justOrEmpty(stubValue)) - } + private fun stubResolver(stubValue: Any): HandlerMethodArgumentResolver = + stubResolver(stubValue, stubValue::class.java) - private fun stubResolver(stubValue: Mono): HandlerMethodArgumentResolver { + private fun stubResolver(stubValue: Any?, stubClass: Class<*>): HandlerMethodArgumentResolver { val resolver = mockk() - every { resolver.supportsParameter(any()) } returns true - every { resolver.resolveArgument(any(), any(), any()) } returns stubValue + every { resolver.supportsParameter(any()) } answers { (it.invocation.args[0] as MethodParameter).getParameterType() == stubClass } + every { resolver.resolveArgument(any(), any(), any()) } returns Mono.justOrEmpty(stubValue) return resolver } @@ -309,9 +329,22 @@ class InvocableHandlerMethodKotlinTests { get() = "foo" } + class ExtensionHandler { + + fun CustomException.handle(): String { + return "${this.message}" + } + + fun CustomException.handleWithParameter(limit: Int): String { + return "${this.message}-$limit" + } + } + @JvmInline value class LongValueClass(val value: Long) @JvmInline value class DoubleValueClass(val value: Double) + + class CustomException(message: String) : Throwable(message) } \ No newline at end of file