Skip to content

Commit

Permalink
Context propagation with automatic ThreadLocals restoration (#3335)
Browse files Browse the repository at this point in the history
Adds a few primitives necessary for propagating reactor `Context` to
`ThreadLocal`s using the context-propagation library. It adds the following
primitives:

* dedicated `Scheduler` task wrapping
* dedicated global `Queue` wrapper
* dedicated alternative to `FluxContextWrite` and `MonoContextWrite` operators

All of these use the context restoration mechanism from context-propagation.
  • Loading branch information
chemicL authored Feb 14, 2023
1 parent d3b4956 commit 5d3a6b5
Show file tree
Hide file tree
Showing 11 changed files with 1,021 additions and 14 deletions.
2 changes: 1 addition & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ logback = "ch.qos.logback:logback-classic:1.2.11"
micrometer-bom = { module = "io.micrometer:micrometer-bom", version.ref = "micrometer" }
micrometer-commons = { module = "io.micrometer:micrometer-commons" }
micrometer-core = { module = "io.micrometer:micrometer-core" }
micrometer-contextPropagation = "io.micrometer:context-propagation:1.0.0"
micrometer-contextPropagation = "io.micrometer:context-propagation:1.0.2"
micrometer-docsGenerator = { module = "io.micrometer:micrometer-docs-generator", version = "1.0.1"}
micrometer-observation-test = { module = "io.micrometer:micrometer-observation-test" }
micrometer-tracing-test = "io.micrometer:micrometer-tracing-integration-test:1.0.1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,31 @@

package reactor.core.publisher;

import java.util.AbstractQueue;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Queue;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;

import io.micrometer.context.ContextAccessor;
import io.micrometer.context.ContextRegistry;
import io.micrometer.context.ContextSnapshot;

import io.micrometer.context.ThreadLocalAccessor;
import reactor.core.observability.SignalListener;
import reactor.util.Logger;
import reactor.util.Loggers;
import reactor.util.annotation.Nullable;
import reactor.util.context.Context;
import reactor.util.context.ContextView;

import static reactor.core.Fuseable.QueueSubscription.NOT_SUPPORTED_MESSAGE;

/**
* Utility private class to detect if the <a href="https://github.com/micrometer-metrics/context-propagation">context-propagation library</a> is on the classpath and to offer
* ContextSnapshot support to {@link Flux} and {@link Mono}.
Expand All @@ -42,11 +52,13 @@ final class ContextPropagation {
static final Logger LOGGER;

static final boolean isContextPropagationAvailable;
static boolean propagateContextToThreadLocals = false;

static final Predicate<Object> PREDICATE_TRUE = v -> true;
static final Function<Context, Context> NO_OP = c -> c;
static final Function<Context, Context> WITH_GLOBAL_REGISTRY_NO_PREDICATE;


static {
LOGGER = Loggers.getLogger(ContextPropagation.class);

Expand Down Expand Up @@ -83,6 +95,17 @@ static boolean isContextPropagationAvailable() {
return isContextPropagationAvailable;
}

static boolean shouldPropagateContextToThreadLocals() {
return isContextPropagationAvailable && propagateContextToThreadLocals;
}

public static Function<Runnable, Runnable> scopePassingOnScheduleHook() {
return delegate -> {
ContextSnapshot contextSnapshot = ContextSnapshot.captureAll();
return contextSnapshot.wrap(delegate);
};
}

/**
* Create a support function that takes a snapshot of thread locals and merges them with the
* provided {@link Context}, resulting in a new {@link Context} which includes entries
Expand Down Expand Up @@ -126,7 +149,7 @@ static Function<Context, Context> contextCapture(Predicate<Object> captureKeyPre
}

static <T, R> BiConsumer<T, SynchronousSink<R>> contextRestoreForHandle(BiConsumer<T, SynchronousSink<R>> handler, Supplier<Context> contextSupplier) {
if (!ContextPropagation.isContextPropagationAvailable()) {
if (propagateContextToThreadLocals || !ContextPropagation.isContextPropagationAvailable()) {
return handler;
}
final Context ctx = contextSupplier.get();
Expand All @@ -141,7 +164,7 @@ static <T, R> BiConsumer<T, SynchronousSink<R>> contextRestoreForHandle(BiConsum
}

static <T> SignalListener<T> contextRestoreForTap(final SignalListener<T> original, Supplier<Context> contextSupplier) {
if (!ContextPropagation.isContextPropagationAvailable()) {
if (propagateContextToThreadLocals || !ContextPropagation.isContextPropagationAvailable()) {
return original;
}
final Context ctx = contextSupplier.get();
Expand Down Expand Up @@ -281,6 +304,151 @@ public Context addToContext(Context originalContext) {
}
}

static final class ContextQueue<T> extends AbstractQueue<T> {

final Queue<Envelope<T>> envelopeQueue;

boolean cleanOnNull;
boolean hasPrevious = false;

Thread lastReader;
ContextSnapshot.Scope scope;

@SuppressWarnings({"unchecked", "rawtypes"})
ContextQueue(Queue<?> queue) {
this.envelopeQueue = (Queue) queue;
}

@Override
public int size() {
return envelopeQueue.size();
}

@Override
public boolean offer(T o) {
ContextSnapshot contextSnapshot = ContextSnapshot.captureAll();
return envelopeQueue.offer(new Envelope<>(o, contextSnapshot));
}

@Override
public T poll() {
Envelope<T> envelope = envelopeQueue.poll();
if (envelope == null) {
if (cleanOnNull && scope != null) {
// clear thread-locals if they were just restored
scope.close();
}
cleanOnNull = true;
lastReader = Thread.currentThread();
hasPrevious = false;
return null;
}


restoreTheContext(envelope);
hasPrevious = true;
return envelope.body;
}

private void restoreTheContext(Envelope<T> envelope) {
ContextSnapshot contextSnapshot = envelope.contextSnapshot;
// tries to read existing Thread for existing ThreadLocals
ContextSnapshot currentContextSnapshot = ContextSnapshot.captureAll();
if (!contextSnapshot.equals(currentContextSnapshot)) {
if (!hasPrevious || !Thread.currentThread().equals(this.lastReader)) {
// means context was restored form the envelope,
// thus it has to be cleared
cleanOnNull = true;
lastReader = Thread.currentThread();
}
scope = contextSnapshot.setThreadLocals();
}
else if (!hasPrevious || !Thread.currentThread().equals(this.lastReader)) {
// means same context was already available, no need to clean anything
cleanOnNull = false;
lastReader = Thread.currentThread();
}
}

@Override
@Nullable
public T peek() {
throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE);
}

@Override
public boolean add(@Nullable T t) {
throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE);
}

@Override
public T remove() {
throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE);
}

@Override
public T element() {
throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE);
}

@Override
public boolean contains(@Nullable Object o) {
throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE);
}

@Override
public Iterator<T> iterator() {
throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE);
}

@Override
public Object[] toArray() {
throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE);
}

@Override
public <T1> T1[] toArray(T1[] a) {
throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE);
}

@Override
public boolean remove(@Nullable Object o) {
throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE);
}

@Override
public boolean containsAll(Collection<?> c) {
throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE);
}

@Override
public boolean addAll(Collection<? extends T> c) {
throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE);
}

@Override
public boolean removeAll(Collection<?> c) {
throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE);
}

@Override
public boolean retainAll(Collection<?> c) {
throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE);
}
}

static class Envelope<T> {

final T body;
final ContextSnapshot contextSnapshot;

Envelope(T body, ContextSnapshot contextSnapshot) {
this.body = body;
this.contextSnapshot = contextSnapshot;
}

}

static final class ContextCaptureNoPredicate implements Function<Context, Context> {
final ContextRegistry globalRegistry;

Expand All @@ -293,4 +461,76 @@ public Context apply(Context context) {
.updateContext(context);
}
}

/*
* Temporary methods not present in context-propagation library that allow
* clearing ThreadLocals not present in Reactor Context. Once context-propagation
* library adds the ability to do this, they can be removed from reactor-core.
*/

@SuppressWarnings("unchecked")
static <C> ContextSnapshot.Scope setThreadLocals(Object context) {
ContextRegistry registry = ContextRegistry.getInstance();
ContextAccessor<?, ?> contextAccessor = registry.getContextAccessorForRead(context);
Map<Object, Object> previousValues = null;
for (ThreadLocalAccessor<?> threadLocalAccessor : registry.getThreadLocalAccessors()) {
Object key = threadLocalAccessor.key();
Object value = ((ContextAccessor<C, ?>) contextAccessor).readValue((C) context, key);
previousValues = setThreadLocal(key, value, threadLocalAccessor, previousValues);
}
return ReactorScopeImpl.from(previousValues, registry);
}

@SuppressWarnings("unchecked")
private static <V> Map<Object, Object> setThreadLocal(Object key, @Nullable V value,
ThreadLocalAccessor<?> accessor, @Nullable Map<Object, Object> previousValues) {

previousValues = (previousValues != null ? previousValues : new HashMap<>());
previousValues.put(key, accessor.getValue());
if (value != null) {
((ThreadLocalAccessor<V>) accessor).setValue(value);
}
else {
accessor.reset();
}
return previousValues;
}

private static class ReactorScopeImpl implements ContextSnapshot.Scope {

private final Map<Object, Object> previousValues;

private final ContextRegistry contextRegistry;

private ReactorScopeImpl(Map<Object, Object> previousValues,
ContextRegistry contextRegistry) {
this.previousValues = previousValues;
this.contextRegistry = contextRegistry;
}

@Override
public void close() {
for (ThreadLocalAccessor<?> accessor : this.contextRegistry.getThreadLocalAccessors()) {
if (this.previousValues.containsKey(accessor.key())) {
Object previousValue = this.previousValues.get(accessor.key());
resetThreadLocalValue(accessor, previousValue);
}
}
}

@SuppressWarnings("unchecked")
private <V> void resetThreadLocalValue(ThreadLocalAccessor<?> accessor, @Nullable V previousValue) {
if (previousValue != null) {
((ThreadLocalAccessor<V>) accessor).restore(previousValue);
}
else {
accessor.reset();
}
}

public static ContextSnapshot.Scope from(@Nullable Map<Object, Object> previousValues, ContextRegistry registry) {
return (previousValues != null ? new ReactorScopeImpl(previousValues, registry) : () -> {
});
}
}
}
13 changes: 12 additions & 1 deletion reactor-core/src/main/java/reactor/core/publisher/Flux.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2016-2022 VMware Inc. or its affiliates, All Rights Reserved.
* Copyright (c) 2016-2023 VMware Inc. or its affiliates, All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -44,6 +44,7 @@
import java.util.function.Supplier;
import java.util.logging.Level;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import io.micrometer.core.instrument.MeterRegistry;
Expand Down Expand Up @@ -4165,6 +4166,11 @@ public final Flux<T> contextCapture() {
if (!ContextPropagation.isContextPropagationAvailable()) {
return this;
}
if (ContextPropagation.propagateContextToThreadLocals) {
return onAssembly(new FluxContextWriteRestoringThreadLocals<>(
this, ContextPropagation.contextCapture()
));
}
return onAssembly(new FluxContextWrite<>(this, ContextPropagation.contextCapture()));
}

Expand Down Expand Up @@ -4211,6 +4217,11 @@ public final Flux<T> contextWrite(ContextView contextToAppend) {
* @see Context
*/
public final Flux<T> contextWrite(Function<Context, Context> contextModifier) {
if (ContextPropagation.shouldPropagateContextToThreadLocals()) {
return onAssembly(new FluxContextWriteRestoringThreadLocals<>(
this, contextModifier
));
}
return onAssembly(new FluxContextWrite<>(this, contextModifier));
}

Expand Down
Loading

0 comments on commit 5d3a6b5

Please sign in to comment.