Skip to content

Commit

Permalink
refactor: remove redundant class and simplify capture OTel context usage
Browse files Browse the repository at this point in the history
  • Loading branch information
changemyminds committed Mar 15, 2024
1 parent 77ce105 commit 7949ac2
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,25 +136,16 @@ def __wrap_threading_run(call_wrapped, instance, args, kwargs):
def __wrap_thread_pool_submit(call_wrapped, instance, args, kwargs):
# obtain the original function and wrapped kwargs
original_func = args[0]
wrapped_kwargs = {
ThreadingInstrumentor.__WRAPPER_KWARGS: kwargs,
ThreadingInstrumentor.__WRAPPER_CONTEXT: context.get_current(),
}
otel_context = context.get_current()

def wrapped_func(*func_args, **func_kwargs):
original_kwargs = func_kwargs.pop(
ThreadingInstrumentor.__WRAPPER_KWARGS
)
otel_context = func_kwargs.pop(
ThreadingInstrumentor.__WRAPPER_CONTEXT
)
token = None
try:
token = context.attach(otel_context)
return original_func(*func_args, **original_kwargs)
return original_func(*func_args, **func_kwargs)
finally:
context.detach(token)

# replace the original function with the wrapped function
new_args = (wrapped_func,) + args[1:]
return call_wrapped(*new_args, **wrapped_kwargs)
return call_wrapped(*new_args, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,18 @@

import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import List

from opentelemetry import trace
from opentelemetry.instrumentation.threading import ThreadingInstrumentor
from opentelemetry.test.test_base import TestBase


@dataclass
class MockContext:
span_context: trace.SpanContext = None
trace_id: int = None
span_id: int = None


class TestThreading(TestBase):
def setUp(self):
super().setUp()
self._tracer = self.tracer_provider.get_tracer(__name__)
self._mock_contexts: List[MockContext] = []
self._mock_span_contexts: List[trace.SpanContext] = []
ThreadingInstrumentor().instrument()

def tearDown(self):
Expand All @@ -53,56 +45,58 @@ def test_trace_context_propagation_in_timer(self):

def run_threading_test(self, thread: threading.Thread):
with self.get_root_span() as span:
span_context = span.get_span_context()
expected_context = span_context
expected_trace_id = span_context.trace_id
expected_span_id = span_context.span_id
expected_span_context = span.get_span_context()
thread.start()
thread.join()

# check result
self.assertEqual(len(self._mock_contexts), 1)

current_mock_context = self._mock_contexts[0]
self.assertEqual(
current_mock_context.span_context, expected_context
self.assertEqual(len(self._mock_span_contexts), 1)
self.assert_span_context_equality(
self._mock_span_contexts[0], expected_span_context
)
self.assertEqual(current_mock_context.trace_id, expected_trace_id)
self.assertEqual(current_mock_context.span_id, expected_span_id)

def test_trace_context_propagation_in_thread_pool(self):
max_workers = 10
executor = ThreadPoolExecutor(max_workers=max_workers)

expected_contexts: List[trace.SpanContext] = []
expected_span_contexts: List[trace.SpanContext] = []
futures_list = []
for num in range(max_workers):
with self._tracer.start_as_current_span(f"trace_{num}") as span:
span_context = span.get_span_context()
expected_contexts.append(span_context)
expected_span_context = span.get_span_context()
expected_span_contexts.append(expected_span_context)
future = executor.submit(self.fake_func)
futures_list.append(future)

for future in as_completed(futures_list):
future.result()

# check result
self.assertEqual(len(self._mock_contexts), max_workers)
self.assertEqual(len(self._mock_contexts), len(expected_contexts))
for index, mock_context in enumerate(self._mock_contexts):
span_context = expected_contexts[index]
self.assertEqual(mock_context.span_context, span_context)
self.assertEqual(mock_context.trace_id, span_context.trace_id)
self.assertEqual(mock_context.span_id, span_context.span_id)
self.assertEqual(len(self._mock_span_contexts), max_workers)
self.assertEqual(
len(self._mock_span_contexts), len(expected_span_contexts)
)
for index, mock_span_context in enumerate(self._mock_span_contexts):
self.assert_span_context_equality(
mock_span_context, expected_span_contexts[index]
)

def fake_func(self):
span_context = trace.get_current_span().get_span_context()
mock_context = MockContext(
span_context=span_context,
trace_id=span_context.trace_id,
span_id=span_context.span_id,
self._mock_span_contexts.append(span_context)

def assert_span_context_equality(
self,
result_span_context: trace.SpanContext,
expected_span_context: trace.SpanContext,
):
self.assertEqual(result_span_context, expected_span_context)
self.assertEqual(
result_span_context.trace_id, expected_span_context.trace_id
)
self.assertEqual(
result_span_context.span_id, expected_span_context.span_id
)
self._mock_contexts.append(mock_context)

def print_square(self, num):
with self._tracer.start_as_current_span("square"):
Expand Down

0 comments on commit 7949ac2

Please sign in to comment.