Skip to content

Commit

Permalink
Support Kotlin extensions in web handlers
Browse files Browse the repository at this point in the history
This commit restores support for Kotlin extensions in
web handlers, and adds support for invoking reflectively
suspending extension functions, as well as the other
features supported as of Spring Framework 6.1 like
value classes and default value for parameters.

Closes gh-31876
  • Loading branch information
sdeleuze committed Dec 21, 2023
1 parent 5f8a031 commit 85cb6cc
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,11 @@ public static Publisher<?> invokeSuspendingFunction(CoroutineContext context, Me
int index = 0;
for (KParameter parameter : function.getParameters()) {
switch (parameter.getKind()) {
case INSTANCE -> argMap.put(parameter, target);
case VALUE -> {
case INSTANCE:
argMap.put(parameter, target);
break;
case VALUE:
case EXTENSION_RECEIVER:
if (!parameter.isOptional() || args[index] != null) {
if (parameter.getType().getClassifier() instanceof KClass<?> kClass && kClass.isValue()) {
Class<?> javaClass = JvmClassMappingKt.getJavaClass(kClass);
Expand All @@ -131,7 +134,8 @@ public static Publisher<?> invokeSuspendingFunction(CoroutineContext context, Me
}
}
index++;
}
break;

}
}
return KCallables.callSuspendBy(function, argMap, continuation);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,26 @@ class CoroutinesUtilsTests {
}
}

@Test
fun invokeSuspendingFunctionWithExtension() {
val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("suspendingFunctionWithExtension",
CustomException::class.java, Continuation::class.java)
val mono = CoroutinesUtils.invokeSuspendingFunction(method, this, CustomException("foo")) as Mono
runBlocking {
Assertions.assertThat(mono.awaitSingleOrNull()).isEqualTo("foo")
}
}

@Test
fun invokeSuspendingFunctionWithExtensionAndParameter() {
val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("suspendingFunctionWithExtensionAndParameter",
CustomException::class.java, Int::class.java, Continuation::class.java)
val mono = CoroutinesUtils.invokeSuspendingFunction(method, this, CustomException("foo"), 20) as Mono
runBlocking {
Assertions.assertThat(mono.awaitSingleOrNull()).isEqualTo("foo-20")
}
}

suspend fun suspendingFunction(value: String): String {
delay(1)
return value
Expand Down Expand Up @@ -186,7 +206,19 @@ class CoroutinesUtilsTests {
return value.value
}

suspend fun CustomException.suspendingFunctionWithExtension(): String {
delay(1)
return "${this.message}"
}

suspend fun CustomException.suspendingFunctionWithExtensionAndParameter(limit: Int): String {
delay(1)
return "${this.message}-$limit"
}

@JvmInline
value class ValueClass(val value: String)

class CustomException(message: String) : Throwable(message)

}
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,11 @@ public static Object invokeFunction(Method method, Object target, Object[] args)
int index = 0;
for (KParameter parameter : function.getParameters()) {
switch (parameter.getKind()) {
case INSTANCE -> argMap.put(parameter, target);
case VALUE -> {
case INSTANCE:
argMap.put(parameter, target);
break;
case VALUE:
case EXTENSION_RECEIVER:
if (!parameter.isOptional() || args[index] != null) {
if (parameter.getType().getClassifier() instanceof KClass<?> kClass && kClass.isValue()) {
Class<?> javaClass = JvmClassMappingKt.getJavaClass(kClass);
Expand All @@ -332,7 +335,8 @@ public static Object invokeFunction(Method method, Object target, Object[] args)
}
}
index++;
}
break;

}
}
Object result = function.callBy(argMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,22 @@ class InvocableHandlerMethodKotlinTests {
Assertions.assertThat(value).isEqualTo("foo")
}

@Test
fun extension() {
composite.addResolver(StubArgumentResolver(CustomException::class.java, CustomException("foo")))
val value = getInvocable(ExtensionHandler::class.java, CustomException::class.java).invokeForRequest(request, null)
Assertions.assertThat(value).isEqualTo("foo")
}

@Test
fun extensionWithParameter() {
composite.addResolver(StubArgumentResolver(CustomException::class.java, CustomException("foo")))
composite.addResolver(StubArgumentResolver(Int::class.java, 20))
val value = getInvocable(ExtensionHandler::class.java, CustomException::class.java, Int::class.java)
.invokeForRequest(request, null)
Assertions.assertThat(value).isEqualTo("foo-20")
}

private fun getInvocable(clazz: Class<*>, vararg argTypes: Class<*>): InvocableHandlerMethod {
val method = ResolvableMethod.on(clazz).argTypes(*argTypes).resolveMethod()
val handlerMethod = InvocableHandlerMethod(clazz.constructors.first().newInstance(), method)
Expand Down Expand Up @@ -150,10 +166,23 @@ class InvocableHandlerMethodKotlinTests {
get() = "foo"
}

private class ExtensionHandler {

fun CustomException.handle(): String {
return "${this.message}"
}

fun CustomException.handleWithParameter(limit: Int): String {
return "${this.message}-$limit"
}
}

@JvmInline
value class LongValueClass(val value: Long)

@JvmInline
value class DoubleValueClass(val value: Double)

class CustomException(message: String) : Throwable(message)

}
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,11 @@ public static Object invokeFunction(Method method, Object target, Object[] args,
int index = 0;
for (KParameter parameter : function.getParameters()) {
switch (parameter.getKind()) {
case INSTANCE -> argMap.put(parameter, target);
case VALUE -> {
case INSTANCE:
argMap.put(parameter, target);
break;
case VALUE:
case EXTENSION_RECEIVER:
if (!parameter.isOptional() || args[index] != null) {
if (parameter.getType().getClassifier() instanceof KClass<?> kClass && kClass.isValue()) {
Class<?> javaClass = JvmClassMappingKt.getJavaClass(kClass);
Expand All @@ -343,7 +346,7 @@ public static Object invokeFunction(Method method, Object target, Object[] args,
}
}
index++;
}
break;
}
}
Object result = function.callBy(argMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import io.mockk.mockk
import kotlinx.coroutines.delay
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test
import org.springframework.core.MethodParameter
import org.springframework.core.ReactiveAdapterRegistry
import org.springframework.http.HttpStatus
import org.springframework.http.server.reactive.ServerHttpResponse
Expand All @@ -34,6 +35,7 @@ import org.springframework.web.reactive.result.method.InvocableHandlerMethod
import org.springframework.web.reactive.result.method.annotation.ContinuationHandlerMethodArgumentResolver
import org.springframework.web.reactive.result.method.annotation.RequestParamMethodArgumentResolver
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.get
import org.springframework.web.testfixture.method.ResolvableMethod
import org.springframework.web.testfixture.server.MockServerWebExchange
import reactor.core.publisher.Mono
import reactor.test.StepVerifier
Expand All @@ -55,7 +57,7 @@ class InvocableHandlerMethodKotlinTests {

@Test
fun resolveNoArg() {
this.resolvers.add(stubResolver(Mono.empty()))
this.resolvers.add(stubResolver(null, String::class.java))
val method = CoroutinesController::singleArg.javaMethod!!
val result = invoke(CoroutinesController(), method, null)
assertHandlerResultValue(result, "success:null")
Expand Down Expand Up @@ -116,15 +118,15 @@ class InvocableHandlerMethodKotlinTests {

@Test
fun defaultValue() {
this.resolvers.add(stubResolver(Mono.empty()))
this.resolvers.add(stubResolver(null, String::class.java))
val method = DefaultValueController::handle.javaMethod!!
val result = invoke(DefaultValueController(), method)
assertHandlerResultValue(result, "default")
}

@Test
fun defaultValueOverridden() {
this.resolvers.add(stubResolver(Mono.empty()))
this.resolvers.add(stubResolver(null, String::class.java))
val method = DefaultValueController::handle.javaMethod!!
exchange = MockServerWebExchange.from(get("http://localhost:8080/path").queryParam("value", "override"))
val result = invoke(DefaultValueController(), method)
Expand All @@ -133,15 +135,15 @@ class InvocableHandlerMethodKotlinTests {

@Test
fun defaultValues() {
this.resolvers.add(stubResolver(Mono.empty()))
this.resolvers.add(stubResolver(null, Int::class.java))
val method = DefaultValueController::handleMultiple.javaMethod!!
val result = invoke(DefaultValueController(), method)
assertHandlerResultValue(result, "10-20")
}

@Test
fun defaultValuesOverridden() {
this.resolvers.add(stubResolver(Mono.empty()))
this.resolvers.add(stubResolver(null, Int::class.java))
val method = DefaultValueController::handleMultiple.javaMethod!!
exchange = MockServerWebExchange.from(get("http://localhost:8080/path").queryParam("limit2", "40"))
val result = invoke(DefaultValueController(), method)
Expand All @@ -150,15 +152,15 @@ class InvocableHandlerMethodKotlinTests {

@Test
fun suspendingDefaultValue() {
this.resolvers.add(stubResolver(Mono.empty()))
this.resolvers.add(stubResolver(null, String::class.java))
val method = DefaultValueController::handleSuspending.javaMethod!!
val result = invoke(DefaultValueController(), method)
assertHandlerResultValue(result, "default")
}

@Test
fun suspendingDefaultValueOverridden() {
this.resolvers.add(stubResolver(Mono.empty()))
this.resolvers.add(stubResolver(null, String::class.java))
val method = DefaultValueController::handleSuspending.javaMethod!!
exchange = MockServerWebExchange.from(get("http://localhost:8080/path").queryParam("value", "override"))
val result = invoke(DefaultValueController(), method)
Expand All @@ -181,28 +183,47 @@ class InvocableHandlerMethodKotlinTests {

@Test
fun valueClass() {
this.resolvers.add(stubResolver(1L))
this.resolvers.add(stubResolver(1L, Long::class.java))
val method = ValueClassController::valueClass.javaMethod!!
val result = invoke(ValueClassController(), method,1L)
assertHandlerResultValue(result, "1")
}

@Test
fun valueClassDefaultValue() {
this.resolvers.add(stubResolver(Mono.empty()))
this.resolvers.add(stubResolver(null, Double::class.java))
val method = ValueClassController::valueClassWithDefault.javaMethod!!
val result = invoke(ValueClassController(), method)
assertHandlerResultValue(result, "3.1")
}

@Test
fun propertyAccessor() {
this.resolvers.add(stubResolver(Mono.empty()))
this.resolvers.add(stubResolver(null, String::class.java))
val method = PropertyAccessorController::prop.getter.javaMethod!!
val result = invoke(PropertyAccessorController(), method)
assertHandlerResultValue(result, "foo")
}

@Test
fun extension() {
this.resolvers.add(stubResolver(CustomException("foo")))
val method = ResolvableMethod.on(ExtensionHandler::class.java).argTypes(CustomException::class.java).resolveMethod()
val result = invoke(ExtensionHandler(), method)
assertHandlerResultValue(result, "foo")
}

@Test
fun extensionWithParameter() {
this.resolvers.add(stubResolver(CustomException("foo")))
this.resolvers.add(stubResolver(20, Int::class.java))
val method = ResolvableMethod.on(ExtensionHandler::class.java)
.argTypes(CustomException::class.java, Int::class.javaPrimitiveType)
.resolveMethod()
val result = invoke(ExtensionHandler(), method)
assertHandlerResultValue(result, "foo-20")
}


private fun invokeForResult(handler: Any, method: Method, vararg providedArgs: Any): HandlerResult? {
return invoke(handler, method, *providedArgs).block(Duration.ofSeconds(5))
Expand All @@ -214,14 +235,13 @@ class InvocableHandlerMethodKotlinTests {
return invocable.invoke(this.exchange, BindingContext(), *providedArgs)
}

private fun stubResolver(stubValue: Any?): HandlerMethodArgumentResolver {
return stubResolver(Mono.justOrEmpty(stubValue))
}
private fun stubResolver(stubValue: Any): HandlerMethodArgumentResolver =
stubResolver(stubValue, stubValue::class.java)

private fun stubResolver(stubValue: Mono<Any>): HandlerMethodArgumentResolver {
private fun stubResolver(stubValue: Any?, stubClass: Class<*>): HandlerMethodArgumentResolver {
val resolver = mockk<HandlerMethodArgumentResolver>()
every { resolver.supportsParameter(any()) } returns true
every { resolver.resolveArgument(any(), any(), any()) } returns stubValue
every { resolver.supportsParameter(any()) } answers { (it.invocation.args[0] as MethodParameter).getParameterType() == stubClass }
every { resolver.resolveArgument(any(), any(), any()) } returns Mono.justOrEmpty(stubValue)
return resolver
}

Expand Down Expand Up @@ -309,9 +329,22 @@ class InvocableHandlerMethodKotlinTests {
get() = "foo"
}

class ExtensionHandler {

fun CustomException.handle(): String {
return "${this.message}"
}

fun CustomException.handleWithParameter(limit: Int): String {
return "${this.message}-$limit"
}
}

@JvmInline
value class LongValueClass(val value: Long)

@JvmInline
value class DoubleValueClass(val value: Double)

class CustomException(message: String) : Throwable(message)
}

0 comments on commit 85cb6cc

Please sign in to comment.