From 6db2b0dcd097e85ace3960d38deee32145afa751 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Fri, 24 Mar 2023 17:32:11 -0600 Subject: [PATCH] Align Filter Chain Observability Lineage Closes gh-12849 --- .../ObservationWebFilterChainDecorator.java | 161 ++++++++++++++++-- ...servationWebFilterChainDecoratorTests.java | 150 ++++++++++++++++ 2 files changed, 299 insertions(+), 12 deletions(-) diff --git a/web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java b/web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java index 9724fe99f2f..12c78b9faf3 100644 --- a/web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java +++ b/web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java @@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import io.micrometer.common.KeyValue; import io.micrometer.common.KeyValues; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationConvention; @@ -265,20 +266,23 @@ class SimpleAroundWebFilterObservation implements AroundWebFilterObservation { } @Override - public void start() { + public Observation start() { if (this.currentObservation.compareAndSet(ObservationReference.NOOP, this.before)) { this.before.start(); - return; + return this.before.observation; } if (this.currentObservation.compareAndSet(this.before, this.after)) { this.before.stop(); this.after.start(); + return this.after.observation; } + return Observation.NOOP; } @Override - public void error(Throwable ex) { + public Observation error(Throwable ex) { this.currentObservation.get().error(ex); + return this.currentObservation.get().observation; } @Override @@ -286,6 +290,46 @@ public void stop() { this.currentObservation.get().stop(); } + @Override + public Observation contextualName(String contextualName) { + return this.currentObservation.get().observation.contextualName(contextualName); + } + + @Override + public Observation parentObservation(Observation parentObservation) { + return this.currentObservation.get().observation.parentObservation(parentObservation); + } + + @Override + public Observation lowCardinalityKeyValue(KeyValue keyValue) { + return this.currentObservation.get().observation.lowCardinalityKeyValue(keyValue); + } + + @Override + public Observation highCardinalityKeyValue(KeyValue keyValue) { + return this.currentObservation.get().observation.highCardinalityKeyValue(keyValue); + } + + @Override + public Observation observationConvention(ObservationConvention observationConvention) { + return this.currentObservation.get().observation.observationConvention(observationConvention); + } + + @Override + public Observation event(Event event) { + return this.currentObservation.get().observation.event(event); + } + + @Override + public Context getContext() { + return this.currentObservation.get().observation.getContext(); + } + + @Override + public Scope openScope() { + return this.currentObservation.get().observation.openScope(); + } + @Override public WebFilterChain wrap(WebFilterChain chain) { return (exchange) -> { @@ -313,7 +357,8 @@ public WebFilter wrap(WebFilter filter) { .doOnError((t) -> { error(t); stop(); - }); + }) + .contextWrite((context) -> context.put(ObservationThreadLocalAccessor.KEY, this)); // @formatter:on }; } @@ -328,6 +373,11 @@ public Observation after() { return this.after.observation; } + @Override + public String toString() { + return this.currentObservation.get().observation.toString(); + } + private static final class ObservationReference { private static final ObservationReference NOOP = new ObservationReference(Observation.NOOP); @@ -364,7 +414,7 @@ private void stop() { } - interface WebFilterObservation { + interface WebFilterObservation extends Observation { WebFilterObservation NOOP = new WebFilterObservation() { }; @@ -376,13 +426,59 @@ static WebFilterObservation create(Observation observation) { return new SimpleWebFilterObservation(observation); } - default void start() { + @Override + default Observation contextualName(String contextualName) { + return Observation.NOOP; + } + + @Override + default Observation parentObservation(Observation parentObservation) { + return Observation.NOOP; + } + + @Override + default Observation lowCardinalityKeyValue(KeyValue keyValue) { + return Observation.NOOP; } - default void error(Throwable ex) { + @Override + default Observation highCardinalityKeyValue(KeyValue keyValue) { + return Observation.NOOP; } + @Override + default Observation observationConvention(ObservationConvention observationConvention) { + return Observation.NOOP; + } + + @Override + default Observation error(Throwable error) { + return Observation.NOOP; + } + + @Override + default Observation event(Event event) { + return Observation.NOOP; + } + + @Override + default Observation start() { + return Observation.NOOP; + } + + @Override + default Context getContext() { + return new Observation.Context(); + } + + @Override default void stop() { + + } + + @Override + default Scope openScope() { + return Scope.NOOP; } default WebFilter wrap(WebFilter filter) { @@ -402,13 +498,13 @@ class SimpleWebFilterObservation implements WebFilterObservation { } @Override - public void start() { - this.observation.start(); + public Observation start() { + return this.observation.start(); } @Override - public void error(Throwable ex) { - this.observation.error(ex); + public Observation error(Throwable ex) { + return this.observation.error(ex); } @Override @@ -416,6 +512,46 @@ public void stop() { this.observation.stop(); } + @Override + public Observation contextualName(String contextualName) { + return this.observation.contextualName(contextualName); + } + + @Override + public Observation parentObservation(Observation parentObservation) { + return this.observation.parentObservation(parentObservation); + } + + @Override + public Observation lowCardinalityKeyValue(KeyValue keyValue) { + return this.observation.lowCardinalityKeyValue(keyValue); + } + + @Override + public Observation highCardinalityKeyValue(KeyValue keyValue) { + return this.observation.highCardinalityKeyValue(keyValue); + } + + @Override + public Observation observationConvention(ObservationConvention observationConvention) { + return this.observation.observationConvention(observationConvention); + } + + @Override + public Observation event(Event event) { + return this.observation.event(event); + } + + @Override + public Context getContext() { + return this.observation.getContext(); + } + + @Override + public Scope openScope() { + return this.observation.openScope(); + } + @Override public WebFilter wrap(WebFilter filter) { if (this.observation.isNoop()) { @@ -442,7 +578,8 @@ public WebFilterChain wrap(WebFilterChain chain) { .doOnCancel(this.observation::stop).doOnError((t) -> { this.observation.error(t); this.observation.stop(); - }); + }).contextWrite( + (context) -> context.put(ObservationThreadLocalAccessor.KEY, this.observation)); }; } diff --git a/web/src/test/java/org/springframework/security/web/server/ObservationWebFilterChainDecoratorTests.java b/web/src/test/java/org/springframework/security/web/server/ObservationWebFilterChainDecoratorTests.java index 08aba40d9f8..d4a6e702ad5 100644 --- a/web/src/test/java/org/springframework/security/web/server/ObservationWebFilterChainDecoratorTests.java +++ b/web/src/test/java/org/springframework/security/web/server/ObservationWebFilterChainDecoratorTests.java @@ -16,15 +16,22 @@ package org.springframework.security.web.server; +import java.util.ArrayList; +import java.util.List; + +import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationHandler; import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; +import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; @@ -64,4 +71,147 @@ void decorateWhenNoopThenDoesNotObserve() { verifyNoInteractions(handler); } + // gh-12849 + @Test + void decorateWhenCustomAfterFilterThenObserves() { + AccumulatingObservationHandler handler = new AccumulatingObservationHandler(); + ObservationRegistry registry = ObservationRegistry.create(); + registry.observationConfig().observationHandler(handler); + ObservationWebFilterChainDecorator decorator = new ObservationWebFilterChainDecorator(registry); + WebFilter mock = mock(WebFilter.class); + given(mock.filter(any(), any())).willReturn(Mono.empty()); + WebFilterChain chain = mock(WebFilterChain.class); + given(chain.filter(any())).willReturn(Mono.empty()); + WebFilterChain decorated = decorator.decorate(chain, + List.of((e, c) -> c.filter(e).then(Mono.deferContextual((context) -> { + Observation parentObservation = context.getOrDefault(ObservationThreadLocalAccessor.KEY, null); + Observation observation = Observation.createNotStarted("custom", registry) + .parentObservation(parentObservation).contextualName("custom").start(); + return Mono.just("3").doOnSuccess((v) -> observation.stop()).doOnCancel(observation::stop) + .doOnError((t) -> { + observation.error(t); + observation.stop(); + }).then(Mono.empty()); + })))); + Observation http = Observation.start("http", registry).contextualName("http"); + try { + decorated.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/").build())) + .contextWrite((context) -> context.put(ObservationThreadLocalAccessor.KEY, http)).block(); + } + finally { + http.stop(); + } + handler.assertSpanStart(0, "http", null); + handler.assertSpanStart(1, "spring.security.filterchains", "http"); + handler.assertSpanStop(2, "security filterchain before"); + handler.assertSpanStart(3, "secured request", "security filterchain before"); + handler.assertSpanStop(4, "secured request"); + handler.assertSpanStart(5, "spring.security.filterchains", "http"); + handler.assertSpanStart(6, "custom", "spring.security.filterchains"); + handler.assertSpanStop(7, "custom"); + handler.assertSpanStop(8, "security filterchain after"); + handler.assertSpanStop(9, "http"); + } + + static class AccumulatingObservationHandler implements ObservationHandler { + + List contexts = new ArrayList<>(); + + @Override + public boolean supportsContext(Observation.Context context) { + return true; + } + + @Override + public void onStart(Observation.Context context) { + this.contexts.add(new Event("start", context)); + } + + @Override + public void onError(Observation.Context context) { + this.contexts.add(new Event("error", context)); + } + + @Override + public void onEvent(Observation.Event event, Observation.Context context) { + this.contexts.add(new Event("event", context)); + } + + @Override + public void onScopeOpened(Observation.Context context) { + this.contexts.add(new Event("opened", context)); + } + + @Override + public void onScopeClosed(Observation.Context context) { + this.contexts.add(new Event("closed", context)); + } + + @Override + public void onScopeReset(Observation.Context context) { + this.contexts.add(new Event("reset", context)); + } + + @Override + public void onStop(Observation.Context context) { + this.contexts.add(new Event("stop", context)); + } + + private void assertSpanStart(int index, String name, String parentName) { + Event event = this.contexts.get(index); + assertThat(event.event).isEqualTo("start"); + if (event.contextualName == null) { + assertThat(event.name).isEqualTo(name); + } + else { + assertThat(event.contextualName).isEqualTo(name); + } + if (parentName == null) { + return; + } + if (event.parentContextualName == null) { + assertThat(event.parentName).isEqualTo(parentName); + } + else { + assertThat(event.parentContextualName).isEqualTo(parentName); + } + } + + private void assertSpanStop(int index, String name) { + Event event = this.contexts.get(index); + assertThat(event.event).isEqualTo("stop"); + if (event.contextualName == null) { + assertThat(event.name).isEqualTo(name); + } + else { + assertThat(event.contextualName).isEqualTo(name); + } + } + + static class Event { + + String event; + + String name; + + String contextualName; + + String parentName; + + String parentContextualName; + + Event(String event, Observation.Context context) { + this.event = event; + this.name = context.getName(); + this.contextualName = context.getContextualName(); + if (context.getParentObservation() != null) { + this.parentName = context.getParentObservation().getContextView().getName(); + this.parentContextualName = context.getParentObservation().getContextView().getContextualName(); + } + } + + } + + } + }