diff --git a/spring-web/src/main/java/org/springframework/web/filter/reactive/ServerHttpObservationFilter.java b/spring-web/src/main/java/org/springframework/web/filter/reactive/ServerHttpObservationFilter.java index c290926307dd..6f40fa81ca60 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/reactive/ServerHttpObservationFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/reactive/ServerHttpObservationFilter.java @@ -121,16 +121,17 @@ public ObservationSignalListener(ServerRequestObservationContext observationCont DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, observationRegistry); } - @Override - public void doOnSubscription() throws Throwable { - this.observation.start(); - } @Override public Context addToContext(Context originalContext) { return originalContext.put(ObservationThreadLocalAccessor.KEY, this.observation); } + @Override + public void doFirst() throws Throwable { + this.observation.start(); + } + @Override public void doOnCancel() throws Throwable { if (this.observationRecorded.compareAndSet(false, true)) { @@ -142,16 +143,7 @@ public void doOnCancel() throws Throwable { @Override public void doOnComplete() throws Throwable { if (this.observationRecorded.compareAndSet(false, true)) { - ServerHttpResponse response = this.observationContext.getResponse(); - if (response.isCommitted()) { - this.observation.stop(); - } - else { - response.beforeCommit(() -> { - this.observation.stop(); - return Mono.empty(); - }); - } + doOnTerminate(this.observationContext); } } @@ -162,8 +154,21 @@ public void doOnError(Throwable error) throws Throwable { this.observationContext.setConnectionAborted(true); } this.observationContext.setError(error); + doOnTerminate(this.observationContext); + } + } + + private void doOnTerminate(ServerRequestObservationContext context) { + ServerHttpResponse response = context.getResponse(); + if (response.isCommitted()) { this.observation.stop(); } + else { + response.beforeCommit(() -> { + this.observation.stop(); + return Mono.empty(); + }); + } } } diff --git a/spring-web/src/test/java/org/springframework/web/filter/reactive/ServerHttpObservationFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/reactive/ServerHttpObservationFilterTests.java index 4bef48c1149d..9a68eb862e45 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/reactive/ServerHttpObservationFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/reactive/ServerHttpObservationFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ import java.util.Optional; +import io.micrometer.observation.Observation; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; @@ -27,6 +28,7 @@ import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.http.server.reactive.observation.ServerRequestObservationContext; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilterChain; @@ -65,7 +67,10 @@ void filterShouldAddNewObservationToReactorContext() { ServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.post("/test/resource")); exchange.getResponse().setRawStatusCode(200); WebFilterChain filterChain = webExchange -> Mono.deferContextual(contextView -> { - assertThat(contextView.getOrEmpty(ObservationThreadLocalAccessor.KEY)).isPresent(); + Observation observation = contextView.get(ObservationThreadLocalAccessor.KEY); + assertThat(observation).isNotNull(); + // check that the observation was started + assertThat(observation.getContext().getLowCardinalityKeyValue("outcome")).isNotNull(); return Mono.empty(); }); this.filter.filter(exchange, filterChain).block(); @@ -99,6 +104,25 @@ void filterShouldRecordObservationWhenCancelled() { assertThatHttpObservation().hasLowCardinalityKeyValue("outcome", "UNKNOWN"); } + @Test + void filterShouldStopObservationOnResponseCommit() { + ServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.post("/test/resource")); + WebFilterChain filterChain = createFilterChain(filterExchange -> { + throw new IllegalArgumentException("server error"); + }); + StepVerifier.create(this.filter.filter(exchange, filterChain).doOnError(throwable -> { + ServerHttpResponse response = exchange.getResponse(); + response.setRawStatusCode(500); + response.setComplete().block(); + })) + .expectError(IllegalArgumentException.class) + .verify(); + Optional observationContext = ServerHttpObservationFilter.findObservationContext(exchange); + assertThat(observationContext.get().getError()).isInstanceOf(IllegalArgumentException.class); + assertThatHttpObservation().hasLowCardinalityKeyValue("outcome", "SERVER_ERROR"); + } + + private WebFilterChain createFilterChain(ThrowingConsumer exchangeConsumer) { return filterExchange -> { try {