From 516a20370383e8ef4bdf09b31e290628b97b877a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Deleuze?= Date: Sun, 3 Mar 2024 22:23:57 +0100 Subject: [PATCH] Support nullable Kotlin value class arguments This commit skips the value class parameter instantiation for nullable types when a null argument is passed. Closes gh-32353 --- .../springframework/core/CoroutinesUtils.java | 12 ++++---- .../core/CoroutinesUtilsTests.kt | 30 ++++++++++++++++++- .../support/InvocableHandlerMethod.java | 12 ++++---- .../InvocableHandlerMethodKotlinTests.kt | 10 +++++++ .../result/method/InvocableHandlerMethod.java | 12 ++++---- .../InvocableHandlerMethodKotlinTests.kt | 30 ++++++++++++++++--- 6 files changed, 86 insertions(+), 20 deletions(-) diff --git a/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java b/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java index 0599f824ead9..687101c5b075 100644 --- a/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java +++ b/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java @@ -115,18 +115,20 @@ public static Publisher invokeSuspendingFunction(CoroutineContext context, Me switch (parameter.getKind()) { case INSTANCE -> argMap.put(parameter, target); case VALUE, EXTENSION_RECEIVER -> { - if (!parameter.isOptional() || args[index] != null) { + Object arg = args[index]; + if (!(parameter.isOptional() && arg == null)) { if (parameter.getType().getClassifier() instanceof KClass kClass) { Class javaClass = JvmClassMappingKt.getJavaClass(kClass); - if (KotlinDetector.isInlineClass(javaClass)) { - argMap.put(parameter, KClasses.getPrimaryConstructor(kClass).call(args[index])); + if (KotlinDetector.isInlineClass(javaClass) + && !(parameter.getType().isMarkedNullable() && arg == null)) { + argMap.put(parameter, KClasses.getPrimaryConstructor(kClass).call(arg)); } else { - argMap.put(parameter, args[index]); + argMap.put(parameter, arg); } } else { - argMap.put(parameter, args[index]); + argMap.put(parameter, arg); } } index++; diff --git a/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt b/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt index c92e60366195..2091fbe0dd80 100644 --- a/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt +++ b/spring-core/src/test/kotlin/org/springframework/core/CoroutinesUtilsTests.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -82,6 +82,15 @@ class CoroutinesUtilsTests { .verify() } + @Test + fun invokeSuspendingFunctionWithNullableParameter() { + val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("suspendingFunctionWithNullable", String::class.java, Continuation::class.java) + val mono = CoroutinesUtils.invokeSuspendingFunction(method, this, null, null) as Mono + runBlocking { + Assertions.assertThat(mono.awaitSingleOrNull()).isNull() + } + } + @Test fun invokeNonSuspendingFunction() { val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("nonSuspendingFunction", String::class.java) @@ -165,6 +174,15 @@ class CoroutinesUtilsTests { } } + @Test + fun invokeSuspendingFunctionWithNullableValueClassParameter() { + val method = CoroutinesUtilsTests::class.java.declaredMethods.first { it.name.startsWith("suspendingFunctionWithNullableValueClass") } + val mono = CoroutinesUtils.invokeSuspendingFunction(method, this, null, null) as Mono + runBlocking { + Assertions.assertThat(mono.awaitSingleOrNull()).isNull() + } + } + @Test fun invokeSuspendingFunctionWithExtension() { val method = CoroutinesUtilsTests::class.java.getDeclaredMethod("suspendingFunctionWithExtension", @@ -190,6 +208,11 @@ class CoroutinesUtilsTests { return value } + suspend fun suspendingFunctionWithNullable(value: String?): String? { + delay(1) + return value + } + suspend fun suspendingFunctionWithFlow(): Flow { delay(1) return flowOf("foo", "bar") @@ -222,6 +245,11 @@ class CoroutinesUtilsTests { return value.value } + suspend fun suspendingFunctionWithNullableValueClass(value: ValueClass?): String? { + delay(1) + return value?.value + } + suspend fun CustomException.suspendingFunctionWithExtension(): String { delay(1) return "${this.message}" diff --git a/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java b/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java index caa6b85417b3..5bdc816d23f7 100644 --- a/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java +++ b/spring-web/src/main/java/org/springframework/web/method/support/InvocableHandlerMethod.java @@ -313,18 +313,20 @@ public static Object invokeFunction(Method method, Object target, Object[] args) switch (parameter.getKind()) { case INSTANCE -> argMap.put(parameter, target); case VALUE, EXTENSION_RECEIVER -> { - if (!parameter.isOptional() || args[index] != null) { + Object arg = args[index]; + if (!(parameter.isOptional() && arg == null)) { if (parameter.getType().getClassifier() instanceof KClass kClass) { Class javaClass = JvmClassMappingKt.getJavaClass(kClass); - if (KotlinDetector.isInlineClass(javaClass)) { - argMap.put(parameter, KClasses.getPrimaryConstructor(kClass).call(args[index])); + if (KotlinDetector.isInlineClass(javaClass) + && !(parameter.getType().isMarkedNullable() && arg == null)) { + argMap.put(parameter, KClasses.getPrimaryConstructor(kClass).call(arg)); } else { - argMap.put(parameter, args[index]); + argMap.put(parameter, arg); } } else { - argMap.put(parameter, args[index]); + argMap.put(parameter, arg); } } index++; diff --git a/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt b/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt index e23284d8c31e..a71fceb1135e 100644 --- a/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt +++ b/spring-web/src/test/kotlin/org/springframework/web/method/support/InvocableHandlerMethodKotlinTests.kt @@ -105,6 +105,13 @@ class InvocableHandlerMethodKotlinTests { Assertions.assertThatIllegalArgumentException().isThrownBy { invocable.invokeForRequest(request, null) } } + @Test + fun valueClassWithNullable() { + composite.addResolver(StubArgumentResolver(LongValueClass::class.java, null)) + val value = getInvocable(ValueClassHandler::class.java, LongValueClass::class.java).invokeForRequest(request, null) + Assertions.assertThat(value).isNull() + } + @Test fun propertyAccessor() { val value = getInvocable(PropertyAccessorHandler::class.java).invokeForRequest(request, null) @@ -173,6 +180,9 @@ class InvocableHandlerMethodKotlinTests { fun valueClassWithInit(valueClass: ValueClassWithInit) = valueClass + fun valueClassWithNullable(limit: LongValueClass?) = + limit?.value + } private class PropertyAccessorHandler { diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java index 1b9501c28d64..bab33e84d9fc 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java @@ -324,18 +324,20 @@ public static Object invokeFunction(Method method, Object target, Object[] args, switch (parameter.getKind()) { case INSTANCE -> argMap.put(parameter, target); case VALUE, EXTENSION_RECEIVER -> { - if (!parameter.isOptional() || args[index] != null) { + Object arg = args[index]; + if (!(parameter.isOptional() && arg == null)) { if (parameter.getType().getClassifier() instanceof KClass kClass) { Class javaClass = JvmClassMappingKt.getJavaClass(kClass); - if (KotlinDetector.isInlineClass(javaClass)) { - argMap.put(parameter, KClasses.getPrimaryConstructor(kClass).call(args[index])); + if (KotlinDetector.isInlineClass(javaClass) + && !(parameter.getType().isMarkedNullable() && arg == null)) { + argMap.put(parameter, KClasses.getPrimaryConstructor(kClass).call(arg)); } else { - argMap.put(parameter, args[index]); + argMap.put(parameter, arg); } } else { - argMap.put(parameter, args[index]); + argMap.put(parameter, arg); } } index++; diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt index d021c9a6d29f..22e55f1d5390 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/InvocableHandlerMethodKotlinTests.kt @@ -19,7 +19,6 @@ package org.springframework.web.reactive.result import io.mockk.every import io.mockk.mockk import kotlinx.coroutines.delay -import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test import org.springframework.core.MethodParameter @@ -178,11 +177,19 @@ class InvocableHandlerMethodKotlinTests { @Test fun nullReturnValue() { - val method = NullResultController::nullable.javaMethod!! + val method = NullResultController::nullableReturnValue.javaMethod!! val result = invoke(NullResultController(), method) assertHandlerResultValue(result, null) } + @Test + fun nullParameter() { + this.resolvers.add(stubResolver(null, String::class.java)) + val method = NullResultController::nullableParameter.javaMethod!! + val result = invoke(NullResultController(), method, null) + assertHandlerResultValue(result, null) + } + @Test fun valueClass() { this.resolvers.add(stubResolver(1L, Long::class.java)) @@ -192,7 +199,7 @@ class InvocableHandlerMethodKotlinTests { } @Test - fun valueClassDefaultValue() { + fun valueClassWithDefaultValue() { this.resolvers.add(stubResolver(null, Double::class.java)) val method = ValueClassController::valueClassWithDefault.javaMethod!! val result = invoke(ValueClassController(), method) @@ -207,6 +214,14 @@ class InvocableHandlerMethodKotlinTests { assertExceptionThrown(result, IllegalArgumentException::class) } + @Test + fun valueClassWithNullable() { + this.resolvers.add(stubResolver(null, LongValueClass::class.java)) + val method = ValueClassController::valueClassWithNullable.javaMethod!! + val result = invoke(ValueClassController(), method, null) + assertHandlerResultValue(result, "null") + } + @Test fun propertyAccessor() { this.resolvers.add(stubResolver(null, String::class.java)) @@ -321,9 +336,13 @@ class InvocableHandlerMethodKotlinTests { fun unit() { } - fun nullable(): String? { + fun nullableReturnValue(): String? { return null } + + fun nullableParameter(value: String?): String? { + return value + } } class ValueClassController { @@ -337,6 +356,9 @@ class InvocableHandlerMethodKotlinTests { fun valueClassWithInit(valueClass: ValueClassWithInit) = valueClass + fun valueClassWithNullable(limit: LongValueClass?) = + "${limit?.value}" + } class PropertyAccessorController {