From a8dc9b39de7845a7e5a1cadf29adfe841b08ef07 Mon Sep 17 00:00:00 2001 From: Yusuke Tsutsumi Date: Thu, 15 Aug 2019 20:35:38 -0700 Subject: [PATCH] Changed to the propagator API Adding a UnifiedContext, composing DistributedContext and SpanContext. This will enable propagators to extract and inject values from either system, enabling more sophisticated schemes and standards to propagate data. This also removes the need for generics and propagators that only consume one or the other, requiring integrators to do extra work to wire propagators appropriately. Modifying the API of the propagators to consume the context as a mutable argument. By passing in the context rather than returning, this enables the chained use of propagators, allowing for situations such as supporting multiple trace propagation standards simulatenously. --- .../src/opentelemetry/context/__init__.py | 17 ++-- .../src/opentelemetry/context/base_context.py | 6 +- .../context/propagation/binaryformat.py | 19 ++-- .../context/propagation/httptextformat.py | 27 ++++-- .../opentelemetry/context/unified_context.py | 65 +++++++++++++ .../sdk/context/propagation/b3_format.py | 6 +- .../context/propagation/test_b3_format.py | 92 ++++++++++--------- 7 files changed, 154 insertions(+), 78 deletions(-) create mode 100644 opentelemetry-api/src/opentelemetry/context/unified_context.py diff --git a/opentelemetry-api/src/opentelemetry/context/__init__.py b/opentelemetry-api/src/opentelemetry/context/__init__.py index cf6c72dd8da..368b55affc9 100644 --- a/opentelemetry-api/src/opentelemetry/context/__init__.py +++ b/opentelemetry-api/src/opentelemetry/context/__init__.py @@ -11,15 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """ The OpenTelemetry context module provides abstraction layer on top of thread-local storage and contextvars. The long term direction is to switch to contextvars provided by the Python runtime library. A global object ``Context`` is provided to access all the context related -functionalities:: +functionalities: >>> from opentelemetry.context import Context >>> Context.foo = 1 @@ -27,9 +25,8 @@ >>> Context.foo 2 -When explicit thread is used, a helper function -``Context.with_current_context`` can be used to carry the context across -threads:: +When explicit thread is used, a helper function `Context.with_current_context` +can be used to carry the context across threads: from threading import Thread from opentelemetry.context import Context @@ -62,7 +59,7 @@ def work(name): print('Main thread:', Context) -Here goes another example using thread pool:: +Here goes another example using thread pool: import time import threading @@ -97,7 +94,7 @@ def work(name): pool.join() println('Main thread: {}'.format(Context)) -Here goes a simple demo of how async could work in Python 3.7+:: +Here goes a simple demo of how async could work in Python 3.7+: import asyncio @@ -141,9 +138,9 @@ async def main(): import typing from .base_context import BaseRuntimeContext +from .unified_context import UnifiedContext -__all__ = ["Context"] - +__all__ = ["Context", "UnifiedContext"] Context = None # type: typing.Optional[BaseRuntimeContext] diff --git a/opentelemetry-api/src/opentelemetry/context/base_context.py b/opentelemetry-api/src/opentelemetry/context/base_context.py index f1e37aa91f4..7892af54f8c 100644 --- a/opentelemetry-api/src/opentelemetry/context/base_context.py +++ b/opentelemetry-api/src/opentelemetry/context/base_context.py @@ -37,7 +37,7 @@ def set(self, value: "object") -> None: raise NotImplementedError _lock = threading.Lock() - _slots = {} # type: typing.Dict[str, 'BaseRuntimeContext.Slot'] + _slots: typing.Dict[str, Slot] = {} @classmethod def clear(cls) -> None: @@ -48,9 +48,7 @@ def clear(cls) -> None: slot.clear() @classmethod - def register_slot( - cls, name: str, default: "object" = None - ) -> "BaseRuntimeContext.Slot": + def register_slot(cls, name: str, default: "object" = None) -> "Slot": """Register a context slot with an optional default value. :type name: str diff --git a/opentelemetry-api/src/opentelemetry/context/propagation/binaryformat.py b/opentelemetry-api/src/opentelemetry/context/propagation/binaryformat.py index 7f1a65882f3..dbec8f5af49 100644 --- a/opentelemetry-api/src/opentelemetry/context/propagation/binaryformat.py +++ b/opentelemetry-api/src/opentelemetry/context/propagation/binaryformat.py @@ -13,9 +13,8 @@ # limitations under the License. import abc -import typing -from opentelemetry.trace import SpanContext +from opentelemetry.context import UnifiedContext class BinaryFormat(abc.ABC): @@ -27,14 +26,14 @@ class BinaryFormat(abc.ABC): @staticmethod @abc.abstractmethod - def to_bytes(context: SpanContext) -> bytes: + def to_bytes(context: UnifiedContext) -> bytes: """Creates a byte representation of a SpanContext. to_bytes should read values from a SpanContext and return a data format to represent it, in bytes. Args: - context: the SpanContext to serialize + context: the SpanContext to serialize. Returns: A bytes representation of the SpanContext. @@ -43,15 +42,17 @@ def to_bytes(context: SpanContext) -> bytes: @staticmethod @abc.abstractmethod - def from_bytes(byte_representation: bytes) -> typing.Optional[SpanContext]: - """Return a SpanContext that was represented by bytes. + def from_bytes(context: UnifiedContext, + byte_representation: bytes) -> None: + """Populate UnifiedContext that was represented by bytes. - from_bytes should return back a SpanContext that was constructed from - the data serialized in the byte_representation passed. If it is not + from_bytes should populated UnifiedContext with data that was + serialized in the byte_representation passed. If it is not possible to read in a proper SpanContext, return None. Args: - byte_representation: the bytes to deserialize + context: The UnifiedContext to populate. + byte_representation: the bytes to deserialize. Returns: A bytes representation of the SpanContext if it is valid. diff --git a/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py b/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py index f3823a86d17..c8598794b61 100644 --- a/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py +++ b/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py @@ -15,6 +15,7 @@ import abc import typing +from opentelemetry.context import UnifiedContext from opentelemetry.trace import SpanContext Setter = typing.Callable[[object, str, str], None] @@ -35,11 +36,12 @@ class HTTPTextFormat(abc.ABC): import flask import requests from opentelemetry.context.propagation import HTTPTextFormat + from opentelemetry.trace import tracer + from opentelemetry.context import UnifiedContext PROPAGATOR = HTTPTextFormat() - def get_header_from_flask_request(request, key): return request.headers.get_all(key) @@ -48,15 +50,17 @@ def set_header_into_requests_request(request: requests.Request, request.headers[key] = value def example_route(): - span_context = PROPAGATOR.extract( - get_header_from_flask_request, + span = tracer().create_span("") + context = UnifiedContext.create(span) + PROPAGATOR.extract( + context, get_header_from_flask_request, flask.request ) request_to_downstream = requests.Request( "GET", "http://httpbin.org/get" ) PROPAGATOR.inject( - span_context, + context, set_header_into_requests_request, request_to_downstream ) @@ -70,15 +74,20 @@ def example_route(): @abc.abstractmethod def extract( - self, get_from_carrier: Getter, carrier: object - ) -> SpanContext: - """Create a SpanContext from values in the carrier. + self, + context: UnifiedContext, + get_from_carrier: Getter, + carrier: object, + ) -> None: + """Extract values from the carrier into the context. The extract function should retrieve values from the carrier - object using get_from_carrier, and use values to populate a - SpanContext value and return it. + object using get_from_carrier, and use values to populate + attributes of the UnifiedContext passed in. Args: + context: A UnifiedContext instance that will be + populated with values from the carrier. get_from_carrier: a function that can retrieve zero or more values from the carrier. In the case that the value does not exist, return an empty list. diff --git a/opentelemetry-api/src/opentelemetry/context/unified_context.py b/opentelemetry-api/src/opentelemetry/context/unified_context.py new file mode 100644 index 00000000000..7da5200d254 --- /dev/null +++ b/opentelemetry-api/src/opentelemetry/context/unified_context.py @@ -0,0 +1,65 @@ +# Copyright 2019, OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from opentelemetry.distributedcontext import DistributedContext +from opentelemetry.trace import SpanContext + + +class UnifiedContext: + """A unified context object that contains all context relevant to + telemetry. + + The UnifiedContext is a single object that composes all contexts that + are needed by the various forms of telemetry. It is intended to be an + object that can be passed as the argument to any component that needs + to read or modify content values (such as propagators). By unifying + all context in a composed data structure, it expands the flexibility + of the APIs that modify it. + + As it is designed to carry context specific to all telemetry use + cases, it's schema is explicit. Note that this is not intended to + be an object that acts as a singleton that returns different results + based on the thread or coroutine of execution. For that, see `Context`. + + + Args: + distributed: The DistributedContext for this instance. + span: The SpanContext for this instance. + """ + __slots__ = ["distributed", "span"] + + def __init__(self, distributed: DistributedContext, span: SpanContext): + self.distributed = distributed + self.span = span + + @staticmethod + def create(span: SpanContext) -> "UnifiedContext": + """Create an unpopulated UnifiedContext object. + + Example: + + from opentelemetry.trace import tracer + span = tracer.create_span("") + context = UnifiedContext.create(span) + + + Args: + parent_span: the parent SpanContext that will be the + parent of the span in the UnifiedContext. + """ + return UnifiedContext(DistributedContext(), span) + + def __repr__(self) -> str: + return "{}(distributed={}, span={})".format( + type(self).__name__, repr(self.distributed), repr(self.span)) diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py b/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py index 72d02d60700..ae159872ce7 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py @@ -32,7 +32,7 @@ class B3Format(HTTPTextFormat): _SAMPLE_PROPAGATE_VALUES = set(["1", "True", "true", "d"]) @classmethod - def extract(cls, get_from_carrier, carrier): + def extract(cls, context, get_from_carrier, carrier): trace_id = format_trace_id(trace.INVALID_TRACE_ID) span_id = format_span_id(trace.INVALID_SPAN_ID) sampled = 0 @@ -57,7 +57,7 @@ def extract(cls, get_from_carrier, carrier): elif len(fields) == 4: trace_id, span_id, sampled, _parent_span_id = fields else: - return trace.INVALID_SPAN_CONTEXT + return else: trace_id = ( _extract_first_element( @@ -92,7 +92,7 @@ def extract(cls, get_from_carrier, carrier): if sampled in cls._SAMPLE_PROPAGATE_VALUES or flags == "1": options |= trace.TraceOptions.RECORDED - return trace.SpanContext( + context.span = trace.SpanContext( # trace an span ids are encoded in hex, so must be converted trace_id=int(trace_id, 16), span_id=int(span_id, 16), diff --git a/opentelemetry-sdk/tests/context/propagation/test_b3_format.py b/opentelemetry-sdk/tests/context/propagation/test_b3_format.py index 42ff3410f06..613d9b6d307 100644 --- a/opentelemetry-sdk/tests/context/propagation/test_b3_format.py +++ b/opentelemetry-sdk/tests/context/propagation/test_b3_format.py @@ -17,6 +17,9 @@ import opentelemetry.sdk.context.propagation.b3_format as b3_format import opentelemetry.sdk.trace as trace +from opentelemetry.context import UnifiedContext +from opentelemetry.sdk.trace import tracer + FORMAT = b3_format.B3Format() @@ -35,6 +38,11 @@ def setUpClass(cls): trace.generate_span_id() ) + def setUp(self): + span_context = tracer.create_span("").context + self.context = UnifiedContext.create(span_context) + self.carrier = {} + def test_extract_multi_header(self): """Test the extraction of B3 headers.""" carrier = { @@ -42,16 +50,15 @@ def test_extract_multi_header(self): FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.SAMPLED_KEY: "1", } - span_context = FORMAT.extract(get_as_list, carrier) - new_carrier = {} - FORMAT.inject(span_context, dict.__setitem__, new_carrier) + FORMAT.extract(self.context, get_as_list, carrier) + FORMAT.inject(self.context, dict.__setitem__, self.carrier) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], self.serialized_trace_id + self.carrier[FORMAT.TRACE_ID_KEY], self.serialized_trace_id ) self.assertEqual( - new_carrier[FORMAT.SPAN_ID_KEY], self.serialized_span_id + self.carrier[FORMAT.SPAN_ID_KEY], self.serialized_span_id ) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + self.assertEqual(self.carrier[FORMAT.SAMPLED_KEY], "1") def test_extract_single_header(self): """Test the extraction from a single b3 header.""" @@ -60,16 +67,15 @@ def test_extract_single_header(self): self.serialized_trace_id, self.serialized_span_id ) } - span_context = FORMAT.extract(get_as_list, carrier) - new_carrier = {} - FORMAT.inject(span_context, dict.__setitem__, new_carrier) + FORMAT.extract(self.context, get_as_list, carrier) + FORMAT.inject(self.context, dict.__setitem__, self.carrier) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], self.serialized_trace_id + self.carrier[FORMAT.TRACE_ID_KEY], self.serialized_trace_id ) self.assertEqual( - new_carrier[FORMAT.SPAN_ID_KEY], self.serialized_span_id + self.carrier[FORMAT.SPAN_ID_KEY], self.serialized_span_id ) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + self.assertEqual(self.carrier[FORMAT.SAMPLED_KEY], "1") def test_extract_header_precedence(self): """A single b3 header should take precedence over multiple @@ -84,11 +90,10 @@ def test_extract_header_precedence(self): FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.SAMPLED_KEY: "1", } - span_context = FORMAT.extract(get_as_list, carrier) - new_carrier = {} - FORMAT.inject(span_context, dict.__setitem__, new_carrier) + FORMAT.extract(self.context, get_as_list, carrier) + FORMAT.inject(self.context, dict.__setitem__, self.carrier) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], single_header_trace_id + self.carrier[FORMAT.TRACE_ID_KEY], single_header_trace_id ) def test_enabled_sampling(self): @@ -99,10 +104,9 @@ def test_enabled_sampling(self): FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.SAMPLED_KEY: variant, } - span_context = FORMAT.extract(get_as_list, carrier) - new_carrier = {} - FORMAT.inject(span_context, dict.__setitem__, new_carrier) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + FORMAT.extract(self.context, get_as_list, carrier) + FORMAT.inject(self.context, dict.__setitem__, self.carrier) + self.assertEqual(self.carrier[FORMAT.SAMPLED_KEY], "1") def test_disabled_sampling(self): """Test b3 sample key variants that turn off sampling.""" @@ -112,10 +116,9 @@ def test_disabled_sampling(self): FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.SAMPLED_KEY: variant, } - span_context = FORMAT.extract(get_as_list, carrier) - new_carrier = {} - FORMAT.inject(span_context, dict.__setitem__, new_carrier) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "0") + FORMAT.extract(self.context, get_as_list, carrier) + FORMAT.inject(self.context, dict.__setitem__, self.carrier) + self.assertEqual(self.carrier[FORMAT.SAMPLED_KEY], "0") def test_flags(self): """x-b3-flags set to "1" should result in propagation.""" @@ -124,10 +127,9 @@ def test_flags(self): FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.FLAGS_KEY: "1", } - span_context = FORMAT.extract(get_as_list, carrier) - new_carrier = {} - FORMAT.inject(span_context, dict.__setitem__, new_carrier) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + FORMAT.extract(self.context, get_as_list, carrier) + FORMAT.inject(self.context, dict.__setitem__, self.carrier) + self.assertEqual(self.carrier[FORMAT.SAMPLED_KEY], "1") def test_flags_and_sampling(self): """Propagate if b3 flags and sampling are set.""" @@ -136,10 +138,9 @@ def test_flags_and_sampling(self): FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.FLAGS_KEY: "1", } - span_context = FORMAT.extract(get_as_list, carrier) - new_carrier = {} - FORMAT.inject(span_context, dict.__setitem__, new_carrier) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + FORMAT.extract(self.context, get_as_list, carrier) + FORMAT.inject(self.context, dict.__setitem__, self.carrier) + self.assertEqual(self.carrier[FORMAT.SAMPLED_KEY], "1") def test_64bit_trace_id(self): """64 bit trace ids should be padded to 128 bit trace ids.""" @@ -149,21 +150,24 @@ def test_64bit_trace_id(self): FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.FLAGS_KEY: "1", } - span_context = FORMAT.extract(get_as_list, carrier) - new_carrier = {} - FORMAT.inject(span_context, dict.__setitem__, new_carrier) + FORMAT.extract(self.context, get_as_list, carrier) + FORMAT.inject(self.context, dict.__setitem__, self.carrier) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], "0" * 16 + trace_id_64_bit + self.carrier[FORMAT.TRACE_ID_KEY], "0" * 16 + trace_id_64_bit ) def test_invalid_single_header(self): """If an invalid single header is passed, return an invalid SpanContext. """ + self.context.span.trace_id = api_trace.INVALID_TRACE_ID + self.context.span.span_id = api_trace.INVALID_SPAN_ID carrier = {FORMAT.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"} - span_context = FORMAT.extract(get_as_list, carrier) - self.assertEqual(span_context.trace_id, api_trace.INVALID_TRACE_ID) - self.assertEqual(span_context.span_id, api_trace.INVALID_SPAN_ID) + FORMAT.extract(self.context, get_as_list, carrier) + self.assertEqual( + self.context.span.trace_id, api_trace.INVALID_TRACE_ID + ) + self.assertEqual(self.context.span.span_id, api_trace.INVALID_SPAN_ID) def test_missing_trace_id(self): """If a trace id is missing, populate an invalid trace id.""" @@ -171,8 +175,10 @@ def test_missing_trace_id(self): FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.FLAGS_KEY: "1", } - span_context = FORMAT.extract(get_as_list, carrier) - self.assertEqual(span_context.trace_id, api_trace.INVALID_TRACE_ID) + FORMAT.extract(self.context, get_as_list, carrier) + self.assertEqual( + self.context.span.trace_id, api_trace.INVALID_TRACE_ID + ) def test_missing_span_id(self): """If a trace id is missing, populate an invalid trace id.""" @@ -180,5 +186,5 @@ def test_missing_span_id(self): FORMAT.TRACE_ID_KEY: self.serialized_trace_id, FORMAT.FLAGS_KEY: "1", } - span_context = FORMAT.extract(get_as_list, carrier) - self.assertEqual(span_context.span_id, api_trace.INVALID_SPAN_ID) + FORMAT.extract(self.context, get_as_list, carrier) + self.assertEqual(self.context.span.span_id, api_trace.INVALID_SPAN_ID)