Skip to content

Commit

Permalink
Propagate CoroutineContext in CoWebFilter
Browse files Browse the repository at this point in the history
This provides an elegant and dynamic way to customize the
CoroutineContext in WebFlux with the annotation programming
model.

Closes spring-projectsgh-27522
  • Loading branch information
sdeleuze committed Sep 7, 2023
1 parent 9d768a8 commit b0aa004
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.springframework.web.server

import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.reactor.awaitSingleOrNull
import kotlinx.coroutines.reactor.mono
import reactor.core.publisher.Mono
Expand All @@ -26,6 +28,7 @@ import reactor.core.publisher.Mono
* using coroutines.
*
* @author Arjen Poutsma
* @author Sebastien Deleuze
* @since 6.0.5
*/
abstract class CoWebFilter : WebFilter {
Expand All @@ -34,6 +37,7 @@ abstract class CoWebFilter : WebFilter {
return mono(Dispatchers.Unconfined) {
filter(exchange, object : CoWebFilterChain {
override suspend fun filter(exchange: ServerWebExchange) {
exchange.attributes[COROUTINE_CONTEXT_ATTRIBUTE] = currentCoroutineContext().minusKey(Job.Key)
chain.filter(exchange).awaitSingleOrNull()
}
})}.then()
Expand All @@ -47,6 +51,12 @@ abstract class CoWebFilter : WebFilter {
*/
protected abstract suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain)

companion object {

@JvmField
val COROUTINE_CONTEXT_ATTRIBUTE = CoWebFilter::class.java.getName() + ".context"
}

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.web.server

import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.withContext
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test
import org.mockito.BDDMockito.given
Expand All @@ -24,9 +26,11 @@ import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRe
import org.springframework.web.testfixture.server.MockServerWebExchange
import reactor.core.publisher.Mono
import reactor.test.StepVerifier
import kotlin.coroutines.CoroutineContext

/**
* @author Arjen Poutsma
* @author Sebastien Deleuze
*/
class CoWebFilterTests {

Expand All @@ -45,6 +49,26 @@ class CoWebFilterTests {

assertThat(exchange.attributes["foo"]).isEqualTo("bar")
}

@Test
fun filterWithContext() {
val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com"))

val chain = Mockito.mock(WebFilterChain::class.java)
given(chain.filter(exchange)).willReturn(Mono.empty())

val filter = MyCoWebFilterWithContext()
val result = filter.filter(exchange, chain)

StepVerifier.create(result).verifyComplete()

val context = exchange.attributes[CoWebFilter.COROUTINE_CONTEXT_ATTRIBUTE] as CoroutineContext
assertThat(context).isNotNull()
val coroutineName = context[CoroutineName.Key] as CoroutineName
assertThat(coroutineName).isNotNull()
assertThat(coroutineName.name).isEqualTo("foo")
}

}


Expand All @@ -53,4 +77,12 @@ private class MyCoWebFilter : CoWebFilter() {
exchange.attributes["foo"] = "bar"
chain.filter(exchange)
}
}
}

private class MyCoWebFilterWithContext : CoWebFilter() {
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) {
withContext(CoroutineName("foo")) {
chain.filter(exchange)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.Objects;
import java.util.stream.Stream;

import kotlin.coroutines.CoroutineContext;
import kotlin.reflect.KFunction;
import kotlin.reflect.KParameter;
import kotlin.reflect.jvm.KCallablesJvm;
Expand All @@ -48,6 +49,7 @@
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.reactive.BindingContext;
import org.springframework.web.reactive.HandlerResult;
import org.springframework.web.server.CoWebFilter;
import org.springframework.web.server.ServerWebExchange;

/**
Expand Down Expand Up @@ -152,7 +154,7 @@ public void setMethodValidator(@Nullable MethodValidator methodValidator) {
* @param providedArgs optional list of argument values to match by type
* @return a Mono with a {@link HandlerResult}
*/
@SuppressWarnings({"KotlinInternalInJava", "unchecked"})
@SuppressWarnings("unchecked")
public Mono<HandlerResult> invoke(
ServerWebExchange exchange, BindingContext bindingContext, Object... providedArgs) {

Expand All @@ -167,12 +169,7 @@ public Mono<HandlerResult> invoke(
boolean isSuspendingFunction = KotlinDetector.isSuspendingFunction(method);
try {
if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())) {
if (isSuspendingFunction) {
value = CoroutinesUtils.invokeSuspendingFunction(method, getBean(), args);
}
else {
value = KotlinDelegate.invokeFunction(method, getBean(), args);
}
value = KotlinDelegate.invokeFunction(method, getBean(), args, isSuspendingFunction, exchange);
}
else {
value = method.invoke(getBean(), args);
Expand Down Expand Up @@ -297,25 +294,38 @@ private static class KotlinDelegate {

@Nullable
@SuppressWarnings("deprecation")
public static Object invokeFunction(Method method, Object target, Object[] args) {
KFunction<?> function = Objects.requireNonNull(ReflectJvmMapping.getKotlinFunction(method));
if (method.isAccessible() && !KCallablesJvm.isAccessible(function)) {
KCallablesJvm.setAccessible(function, true);
public static Object invokeFunction(Method method, Object target, Object[] args, boolean isSuspendingFunction,
ServerWebExchange exchange) {

if (isSuspendingFunction) {
Object coroutineContext = exchange.getAttribute(CoWebFilter.COROUTINE_CONTEXT_ATTRIBUTE);
if (coroutineContext == null) {
return CoroutinesUtils.invokeSuspendingFunction(method, target, args);
}
else {
return CoroutinesUtils.invokeSuspendingFunction((CoroutineContext) coroutineContext, method, target, args);
}
}
Map<KParameter, Object> argMap = CollectionUtils.newHashMap(args.length + 1);
int index = 0;
for (KParameter parameter : function.getParameters()) {
switch (parameter.getKind()) {
case INSTANCE -> argMap.put(parameter, target);
case VALUE -> {
if (!parameter.isOptional() || args[index] != null) {
argMap.put(parameter, args[index]);
else {
KFunction<?> function = Objects.requireNonNull(ReflectJvmMapping.getKotlinFunction(method));
if (method.isAccessible() && !KCallablesJvm.isAccessible(function)) {
KCallablesJvm.setAccessible(function, true);
}
Map<KParameter, Object> argMap = CollectionUtils.newHashMap(args.length + 1);
int index = 0;
for (KParameter parameter : function.getParameters()) {
switch (parameter.getKind()) {
case INSTANCE -> argMap.put(parameter, target);
case VALUE -> {
if (!parameter.isOptional() || args[index] != null) {
argMap.put(parameter, args[index]);
}
index++;
}
index++;
}
}
return function.callBy(argMap);
}
return function.callBy(argMap);
}
}

Expand Down

0 comments on commit b0aa004

Please sign in to comment.