diff --git a/spring-web/src/main/java/org/springframework/web/server/adapter/HttpWebHandlerAdapter.java b/spring-web/src/main/java/org/springframework/web/server/adapter/HttpWebHandlerAdapter.java index 2333544a8bb2..a3147d8edaeb 100644 --- a/spring-web/src/main/java/org/springframework/web/server/adapter/HttpWebHandlerAdapter.java +++ b/spring-web/src/main/java/org/springframework/web/server/adapter/HttpWebHandlerAdapter.java @@ -17,14 +17,16 @@ package org.springframework.web.server.adapter; import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.reactivestreams.Publisher; +import reactor.core.observability.DefaultSignalListener; import reactor.core.publisher.Mono; +import reactor.util.context.Context; import org.springframework.context.ApplicationContext; import org.springframework.core.NestedExceptionUtils; @@ -302,7 +304,9 @@ public Mono handle(ServerHttpRequest request, ServerHttpResponse response) ServerRequestObservationContext.CURRENT_OBSERVATION_CONTEXT_ATTRIBUTE, observationContext); return getDelegate().handle(exchange) - .transformDeferred(call -> transform(exchange, observationContext, call)) + .doOnSuccess(aVoid -> logResponse(exchange)) + .onErrorResume(ex -> handleUnresolvedError(exchange, observationContext, ex)) + .tap(() -> new ObservationSignalListener(observationContext)) .then(exchange.cleanupMultipart()) .then(Mono.defer(response::setComplete)); } @@ -324,42 +328,6 @@ protected String formatRequest(ServerHttpRequest request) { return "HTTP " + request.getMethod() + " \"" + request.getPath() + query + "\""; } - private Publisher transform(ServerWebExchange exchange, ServerRequestObservationContext observationContext, Mono call) { - Observation observation = ServerHttpObservationDocumentation.HTTP_REACTIVE_SERVER_REQUESTS.observation( - this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry); - observation.start(); - return call - .doOnSuccess(aVoid -> { - logResponse(exchange); - stopObservation(observation, exchange); - }) - .onErrorResume(ex -> handleUnresolvedError(exchange, observationContext, ex)) - .doOnCancel(() -> cancelObservation(observationContext, observation)) - .contextWrite(context -> context.put(ObservationThreadLocalAccessor.KEY, observation)); - } - - private void stopObservation(Observation observation, ServerWebExchange exchange) { - Throwable throwable = exchange.getAttribute(ExceptionHandlingWebHandler.HANDLED_WEB_EXCEPTION); - if (throwable != null) { - observation.error(throwable); - } - ServerHttpResponse response = exchange.getResponse(); - if (response.isCommitted()) { - observation.stop(); - } - else { - response.beforeCommit(() -> { - observation.stop(); - return Mono.empty(); - }); - } - } - - private void cancelObservation(ServerRequestObservationContext observationContext, Observation observation) { - observationContext.setConnectionAborted(true); - observation.stop(); - } - private void logResponse(ServerWebExchange exchange) { LogFormatUtils.traceDebug(logger, traceOn -> { HttpStatusCode status = exchange.getResponse().getStatusCode(); @@ -415,4 +383,66 @@ private boolean isDisconnectedClientError(Throwable ex) { return DISCONNECTED_CLIENT_EXCEPTIONS.contains(ex.getClass().getSimpleName()); } + private final class ObservationSignalListener extends DefaultSignalListener { + + private final ServerRequestObservationContext observationContext; + + private final Observation observation; + + private AtomicBoolean observationRecorded = new AtomicBoolean(); + + public ObservationSignalListener(ServerRequestObservationContext observationContext) { + this.observationContext = observationContext; + this.observation = ServerHttpObservationDocumentation.HTTP_REACTIVE_SERVER_REQUESTS.observation(observationConvention, + DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, observationRegistry); + } + + @Override + public void doOnSubscription() throws Throwable { + this.observation.start(); + } + + @Override + public Context addToContext(Context originalContext) { + return originalContext.put(ObservationThreadLocalAccessor.KEY, this.observation); + } + + @Override + public void doOnCancel() throws Throwable { + if (this.observationRecorded.compareAndSet(false, true)) { + this.observationContext.setConnectionAborted(true); + this.observation.stop(); + } + } + + @Override + public void doOnComplete() throws Throwable { + if (this.observationRecorded.compareAndSet(false, true)) { + ServerHttpResponse response = this.observationContext.getResponse(); + Throwable throwable = (Throwable) this.observationContext.getAttributes() + .get(ExceptionHandlingWebHandler.HANDLED_WEB_EXCEPTION); + if (throwable != null) { + this.observation.error(throwable); + } + if (response.isCommitted()) { + this.observation.stop(); + } + else { + response.beforeCommit(() -> { + this.observation.stop(); + return Mono.empty(); + }); + } + } + } + + @Override + public void doOnError(Throwable error) throws Throwable { + if (this.observationRecorded.compareAndSet(false, true)) { + this.observationContext.setError(error); + this.observation.stop(); + } + } + } + }