diff --git a/auto-api/src/main/java/io/opentelemetry/instrumentation/auto/api/WeakMap.java b/auto-api/src/main/java/io/opentelemetry/instrumentation/auto/api/WeakMap.java index e68dda5c5cf2..c9315c547adf 100644 --- a/auto-api/src/main/java/io/opentelemetry/instrumentation/auto/api/WeakMap.java +++ b/auto-api/src/main/java/io/opentelemetry/instrumentation/auto/api/WeakMap.java @@ -37,6 +37,8 @@ public interface WeakMap { V computeIfAbsent(K key, ValueSupplier supplier); + V remove(K key); + class Provider { private static final Logger log = LoggerFactory.getLogger(Provider.class); @@ -145,6 +147,11 @@ public V computeIfAbsent(K key, ValueSupplier supplier) } } + @Override + public V remove(K key) { + return map.remove(key); + } + @Override public String toString() { return map.toString(); diff --git a/auto-api/src/test/groovy/io/opentelemetry/instrumentation/auto/api/WeakMapTest.groovy b/auto-api/src/test/groovy/io/opentelemetry/instrumentation/auto/api/WeakMapTest.groovy index 441ef5784e92..ca8d972e661c 100644 --- a/auto-api/src/test/groovy/io/opentelemetry/instrumentation/auto/api/WeakMapTest.groovy +++ b/auto-api/src/test/groovy/io/opentelemetry/instrumentation/auto/api/WeakMapTest.groovy @@ -55,6 +55,19 @@ class WeakMapTest extends Specification { supplier.counter == 2 } + def "remove a value"() { + given: + weakMap.put('key', 42) + + when: + def removed = weakMap.remove('key') + + then: + removed == 42 + weakMap.get('key') == null + weakMap.size() == 0 + } + class CounterSupplier implements WeakMap.ValueSupplier { def counter = 0 diff --git a/javaagent-tooling/src/main/java/io/opentelemetry/javaagent/tooling/WeakMapSuppliers.java b/javaagent-tooling/src/main/java/io/opentelemetry/javaagent/tooling/WeakMapSuppliers.java index fc79d9b2a76c..656438a478b3 100644 --- a/javaagent-tooling/src/main/java/io/opentelemetry/javaagent/tooling/WeakMapSuppliers.java +++ b/javaagent-tooling/src/main/java/io/opentelemetry/javaagent/tooling/WeakMapSuppliers.java @@ -121,6 +121,11 @@ public V computeIfAbsent(K key, ValueSupplier supplier) } } } + + @Override + public V remove(K key) { + return map.remove(key); + } } static class Inline implements WeakMap.Implementation { diff --git a/javaagent-tooling/src/main/java/io/opentelemetry/javaagent/tooling/context/FieldBackedProvider.java b/javaagent-tooling/src/main/java/io/opentelemetry/javaagent/tooling/context/FieldBackedProvider.java index 2553e767a9d8..1e69f1d9707d 100644 --- a/javaagent-tooling/src/main/java/io/opentelemetry/javaagent/tooling/context/FieldBackedProvider.java +++ b/javaagent-tooling/src/main/java/io/opentelemetry/javaagent/tooling/context/FieldBackedProvider.java @@ -936,7 +936,11 @@ private Object mapGet(Object key) { } private void mapPut(Object key, Object value) { - map.put(key, value); + if (value == null) { + map.remove(key); + } else { + map.put(key, value); + } } private Object mapSynchronizeInstance(Object key) { diff --git a/testing-common/src/test/groovy/context/FieldBackedProviderTest.groovy b/testing-common/src/test/groovy/context/FieldBackedProviderTest.groovy index bade24167ab3..a8803bae2a8c 100644 --- a/testing-common/src/test/groovy/context/FieldBackedProviderTest.groovy +++ b/testing-common/src/test/groovy/context/FieldBackedProviderTest.groovy @@ -126,6 +126,22 @@ class FieldBackedProviderTest extends AgentTestRunner { new UntransformableKeyClass() | _ } + def "remove test"() { + given: + instance1.putContextCount(10) + + when: + instance1.removeContextCount() + + then: + instance1.getContextCount() == 0 + + where: + instance1 | _ + new KeyClass() | _ + new UntransformableKeyClass() | _ + } + def "works with cglib enhanced instances which duplicates context getter and setter methods"() { setup: Enhancer enhancer = new Enhancer() diff --git a/testing-common/src/test/java/context/ContextTestInstrumentation.java b/testing-common/src/test/java/context/ContextTestInstrumentation.java index 8e056299c92f..dc9f7b9f1b2d 100644 --- a/testing-common/src/test/java/context/ContextTestInstrumentation.java +++ b/testing-common/src/test/java/context/ContextTestInstrumentation.java @@ -57,6 +57,7 @@ public Map, String> transfor named("incrementContextCount"), StoreAndIncrementApiUsageAdvice.class.getName()); transformers.put(named("getContextCount"), GetApiUsageAdvice.class.getName()); transformers.put(named("putContextCount"), PutApiUsageAdvice.class.getName()); + transformers.put(named("removeContextCount"), RemoveApiUsageAdvice.class.getName()); transformers.put( named("incorrectKeyClassUsage"), IncorrectKeyClassContextApiUsageAdvice.class.getName()); transformers.put( @@ -116,7 +117,8 @@ public static void methodExit( @Advice.This KeyClass thiz, @Advice.Return(readOnly = false) int contextCount) { ContextStore contextStore = InstrumentationContext.get(KeyClass.class, Context.class); - contextCount = contextStore.get(thiz).count; + Context context = contextStore.get(thiz); + contextCount = context == null ? 0 : context.count; } } @@ -131,6 +133,15 @@ public static void methodExit(@Advice.This KeyClass thiz, @Advice.Argument(0) in } } + public static class RemoveApiUsageAdvice { + @Advice.OnMethodExit + public static void methodExit(@Advice.This KeyClass thiz) { + ContextStore contextStore = + InstrumentationContext.get(KeyClass.class, Context.class); + contextStore.put(thiz, null); + } + } + public static class IncorrectKeyClassContextApiUsageAdvice { @Advice.OnMethodExit public static void methodExit() { @@ -191,6 +202,10 @@ public int getContextCount() { public void putContextCount(int value) { // implementation replaced with test instrumentation } + + public void removeContextCount() { + // implementation replaced with test instrumentation + } } /**