diff --git a/sdk/core/azure-core/azure/core/tracing/context.py b/sdk/core/azure-core/azure/core/tracing/context.py index a00f88f2d854..f8761cd71502 100644 --- a/sdk/core/azure-core/azure/core/tracing/context.py +++ b/sdk/core/azure-core/azure/core/tracing/context.py @@ -29,8 +29,8 @@ class ContextProtocol(Protocol): Implements set and get variables in a thread safe way. """ - def __init__(self, name, default, lock): - # type: (string, Any, threading.Lock) -> None + def __init__(self, name, default): + # type: (string, Any) -> None pass def clear(self): @@ -54,11 +54,10 @@ class _AsyncContext(object): Uses contextvars to set and get variables globally in a thread safe way. """ - def __init__(self, name, default, lock): + def __init__(self, name, default): self.name = name self.contextvar = contextvars.ContextVar(name) self.default = default if callable(default) else (lambda: default) - self.lock = lock def clear(self): # type: () -> None @@ -78,8 +77,7 @@ def get(self): def set(self, value): # type: (Any) -> None """Set the value in the context.""" - with self.lock: - self.contextvar.set(value) + self.contextvar.set(value) class _ThreadLocalContext(object): @@ -88,11 +86,10 @@ class _ThreadLocalContext(object): """ _thread_local = threading.local() - def __init__(self, name, default, lock): - # type: (str, Any, threading.Lock) -> None + def __init__(self, name, default): + # type: (str, Any) -> None self.name = name self.default = default if callable(default) else (lambda: default) - self.lock = lock def clear(self): # type: () -> None @@ -112,16 +109,14 @@ def get(self): def set(self, value): # type: (Any) -> None """Set the value in the context.""" - with self.lock: - setattr(self._thread_local, self.name, value) + setattr(self._thread_local, self.name, value) -class TracingContext: - _lock = threading.Lock() - +class TracingContext(object): def __init__(self): # type: () -> None - self.current_span = TracingContext._get_context_class("current_span", None) + context_class = _AsyncContext if contextvars else _ThreadLocalContext + self.current_span = context_class("current_span", None) def with_current_context(self, func): # type: (Callable[[Any], Any]) -> Any @@ -146,17 +141,4 @@ def call_with_current_context(*args, **kwargs): return call_with_current_context - @classmethod - def _get_context_class(cls, name, default_val): - # type: (str, Any) -> ContextProtocol - """ - Returns an instance of the the context class that stores the variable. - :param name: The key to store the variable in the context class - :param default_val: The default value of the variable if unset - :return: An instance that implements the context protocol class - """ - context_class = _AsyncContext if contextvars else _ThreadLocalContext - return context_class(name, default_val, cls._lock) - - tracing_context = TracingContext() diff --git a/sdk/core/azure-core/tests/test_tracing_context.py b/sdk/core/azure-core/tests/test_tracing_context.py index 59aa2c8a8b60..984d2ad480d3 100644 --- a/sdk/core/azure-core/tests/test_tracing_context.py +++ b/sdk/core/azure-core/tests/test_tracing_context.py @@ -34,19 +34,14 @@ def __exit__(self, exc_type, exc_val, exc_tb): class TestContext(unittest.TestCase): - def test_get_context_class(self): - with ContextHelper(): - slot = tracing_context._get_context_class("temp", 1) - assert slot.get() == 1 - slot.set(2) - assert slot.get() == 2 - def test_current_span(self): with ContextHelper(): - assert tracing_context.current_span.get() is None + assert not tracing_context.current_span.get() val = mock.Mock(spec=AbstractSpan) tracing_context.current_span.set(val) assert tracing_context.current_span.get() == val + tracing_context.current_span.clear() + assert not tracing_context.current_span.get() def test_with_current_context(self): with ContextHelper(tracer_to_use=mock.Mock(AbstractSpan)):