From d0d2422710b5aedbae0c437d040c06e1b5bb1ebb Mon Sep 17 00:00:00 2001 From: Phillip Verheyden Date: Sat, 27 Nov 2021 23:10:41 -0600 Subject: [PATCH] Add traceresponse headers for asgi apps (FastAPI, Starlette) This asgi version is modeled after the original wsgi version in #436 and corresponds to the SERVER span. Also cleans up some of the existing ASGI functionality to reduce complexity and make future contributions more straightforward. --- CHANGELOG.md | 7 + .../instrumentation/asgi/__init__.py | 141 ++++++++++++------ .../tests/test_asgi_middleware.py | 78 ++++++++++ 3 files changed, 184 insertions(+), 42 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 82daa12a0b..c543bf5764 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased](https://github.com/open-telemetry/opentelemetry-python/compare/v1.8.0-0.27b0...HEAD) +### Added + +- `opentelemetry-instrumentation-asgi` now returns a `traceresponse` response header. + ([#817](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/817)) + ### Fixed - `opentelemetry-instrumentation-flask` Flask: Conditionally create SERVER spans @@ -28,6 +33,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [1.7.1-0.26b1](https://github.com/open-telemetry/opentelemetry-python/releases/tag/v1.7.0-0.26b0) - 2021-11-11 +### Added + - `opentelemetry-instrumentation-aws-lambda` Add instrumentation for AWS Lambda Service - pkg metadata files (Part 1/2) ([#739](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/739)) - Add support for Python 3.10 diff --git a/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py b/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py index 0e1d3b7dc5..a953165473 100644 --- a/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py @@ -103,11 +103,14 @@ def client_response_hook(span: Span, message: dict): from opentelemetry import context, trace from opentelemetry.instrumentation.asgi.version import __version__ # noqa +from opentelemetry.instrumentation.propagators import ( + get_global_response_propagator, +) from opentelemetry.instrumentation.utils import http_status_to_status_code from opentelemetry.propagate import extract -from opentelemetry.propagators.textmap import Getter +from opentelemetry.propagators.textmap import Getter, Setter from opentelemetry.semconv.trace import SpanAttributes -from opentelemetry.trace import Span +from opentelemetry.trace import Span, set_span_in_context from opentelemetry.trace.status import Status, StatusCode from opentelemetry.util.http import remove_url_credentials @@ -152,6 +155,30 @@ def keys(self, carrier: dict) -> typing.List[str]: asgi_getter = ASGIGetter() +class ASGISetter(Setter): + def set( + self, carrier: dict, key: str, value: str + ) -> None: # pylint: disable=no-self-use + """Sets response header values on an ASGI scope according to `the spec `_. + + Args: + carrier: ASGI scope object + key: response header name to set + value: response header value + Returns: + None + """ + headers = carrier.get("headers") + if not headers: + headers = [] + carrier["headers"] = headers + + headers.append([key.lower().encode(), value.encode()]) + + +asgi_setter = ASGISetter() + + def collect_request_attributes(scope): """Collects HTTP request attributes from the ASGI scope and returns a dictionary to be used as span creation attributes.""" @@ -295,54 +322,84 @@ async def __call__(self, scope, receive, send): return await self.app(scope, receive, send) token = context.attach(extract(scope, getter=asgi_getter)) - span_name, additional_attributes = self.default_span_details(scope) + server_span_name, additional_attributes = self.default_span_details( + scope + ) try: with self.tracer.start_as_current_span( - span_name, + server_span_name, kind=trace.SpanKind.SERVER, - ) as span: - if span.is_recording(): + ) as server_span: + if server_span.is_recording(): attributes = collect_request_attributes(scope) attributes.update(additional_attributes) for key, value in attributes.items(): - span.set_attribute(key, value) + server_span.set_attribute(key, value) if callable(self.server_request_hook): - self.server_request_hook(span, scope) - - @wraps(receive) - async def wrapped_receive(): - with self.tracer.start_as_current_span( - " ".join((span_name, scope["type"], "receive")) - ) as receive_span: - if callable(self.client_request_hook): - self.client_request_hook(receive_span, scope) - message = await receive() - if receive_span.is_recording(): - if message["type"] == "websocket.receive": - set_status_code(receive_span, 200) - receive_span.set_attribute("type", message["type"]) - return message - - @wraps(send) - async def wrapped_send(message): - with self.tracer.start_as_current_span( - " ".join((span_name, scope["type"], "send")) - ) as send_span: - if callable(self.client_response_hook): - self.client_response_hook(send_span, message) - if send_span.is_recording(): - if message["type"] == "http.response.start": - status_code = message["status"] - set_status_code(span, status_code) - set_status_code(send_span, status_code) - elif message["type"] == "websocket.send": - set_status_code(span, 200) - set_status_code(send_span, 200) - send_span.set_attribute("type", message["type"]) - await send(message) - - await self.app(scope, wrapped_receive, wrapped_send) + self.server_request_hook(server_span, scope) + + otel_receive = self._get_otel_receive( + server_span_name, scope, receive + ) + + otel_send = self._get_otel_send( + server_span, + server_span_name, + scope, + send, + ) + + await self.app(scope, otel_receive, otel_send) finally: context.detach(token) + + def _get_otel_receive(self, server_span_name, scope, receive): + @wraps(receive) + async def otel_receive(): + with self.tracer.start_as_current_span( + " ".join((server_span_name, scope["type"], "receive")) + ) as receive_span: + if callable(self.client_request_hook): + self.client_request_hook(receive_span, scope) + message = await receive() + if receive_span.is_recording(): + if message["type"] == "websocket.receive": + set_status_code(receive_span, 200) + receive_span.set_attribute("type", message["type"]) + return message + + return otel_receive + + def _get_otel_send(self, server_span, server_span_name, scope, send): + @wraps(send) + async def otel_send(message): + with self.tracer.start_as_current_span( + " ".join((server_span_name, scope["type"], "send")) + ) as send_span: + if callable(self.client_response_hook): + self.client_response_hook(send_span, message) + if send_span.is_recording(): + if message["type"] == "http.response.start": + status_code = message["status"] + set_status_code(server_span, status_code) + set_status_code(send_span, status_code) + elif message["type"] == "websocket.send": + set_status_code(server_span, 200) + set_status_code(send_span, 200) + send_span.set_attribute("type", message["type"]) + + propagator = get_global_response_propagator() + if propagator: + propagator.inject( + message, + context=set_span_in_context( + server_span, trace.context_api.Context() + ), + setter=asgi_setter, + ) + + await send(message) + + return otel_send diff --git a/instrumentation/opentelemetry-instrumentation-asgi/tests/test_asgi_middleware.py b/instrumentation/opentelemetry-instrumentation-asgi/tests/test_asgi_middleware.py index baeb6dd94e..aa33c34894 100644 --- a/instrumentation/opentelemetry-instrumentation-asgi/tests/test_asgi_middleware.py +++ b/instrumentation/opentelemetry-instrumentation-asgi/tests/test_asgi_middleware.py @@ -18,6 +18,11 @@ import opentelemetry.instrumentation.asgi as otel_asgi from opentelemetry import trace as trace_api +from opentelemetry.instrumentation.propagators import ( + TraceResponsePropagator, + get_global_response_propagator, + set_global_response_propagator, +) from opentelemetry.sdk import resources from opentelemetry.semconv.trace import SpanAttributes from opentelemetry.test.asgitestutil import ( @@ -25,6 +30,7 @@ setup_testing_defaults, ) from opentelemetry.test.test_base import TestBase +from opentelemetry.trace import format_span_id, format_trace_id async def http_app(scope, receive, send): @@ -287,6 +293,38 @@ def update_expected_user_agent(expected): outputs = self.get_all_output() self.validate_outputs(outputs, modifiers=[update_expected_user_agent]) + def test_traceresponse_header(self): + """Test a traceresponse header is sent when a global propagator is set.""" + + orig = get_global_response_propagator() + set_global_response_propagator(TraceResponsePropagator()) + + app = otel_asgi.OpenTelemetryMiddleware(simple_asgi) + self.seed_app(app) + self.send_default_request() + + span = self.memory_exporter.get_finished_spans()[-1] + self.assertEqual(trace_api.SpanKind.SERVER, span.kind) + + response_start, response_body, *_ = self.get_all_output() + self.assertEqual(response_body["body"], b"*") + self.assertEqual(response_start["status"], 200) + + traceresponse = "00-{0}-{1}-01".format( + format_trace_id(span.get_span_context().trace_id), + format_span_id(span.get_span_context().span_id), + ) + self.assertListEqual( + response_start["headers"], + [ + [b"Content-Type", b"text/plain"], + [b"traceresponse", f"{traceresponse}".encode()], + [b"access-control-expose-headers", b"traceresponse"], + ], + ) + + set_global_response_propagator(orig) + def test_websocket(self): self.scope = { "type": "websocket", @@ -359,6 +397,46 @@ def test_websocket(self): self.assertEqual(span.kind, expected["kind"]) self.assertDictEqual(dict(span.attributes), expected["attributes"]) + def test_websocket_traceresponse_header(self): + """Test a traceresponse header is set for websocket messages""" + + orig = get_global_response_propagator() + set_global_response_propagator(TraceResponsePropagator()) + + self.scope = { + "type": "websocket", + "http_version": "1.1", + "scheme": "ws", + "path": "/", + "query_string": b"", + "headers": [], + "client": ("127.0.0.1", 32767), + "server": ("127.0.0.1", 80), + } + app = otel_asgi.OpenTelemetryMiddleware(simple_asgi) + self.seed_app(app) + self.send_input({"type": "websocket.connect"}) + self.send_input({"type": "websocket.receive", "text": "ping"}) + self.send_input({"type": "websocket.disconnect"}) + _, socket_send, *_ = self.get_all_output() + + span = self.memory_exporter.get_finished_spans()[-1] + self.assertEqual(trace_api.SpanKind.SERVER, span.kind) + + traceresponse = "00-{0}-{1}-01".format( + format_trace_id(span.get_span_context().trace_id), + format_span_id(span.get_span_context().span_id), + ) + self.assertListEqual( + socket_send["headers"], + [ + [b"traceresponse", f"{traceresponse}".encode()], + [b"access-control-expose-headers", b"traceresponse"], + ], + ) + + set_global_response_propagator(orig) + def test_lifespan(self): self.scope["type"] = "lifespan" app = otel_asgi.OpenTelemetryMiddleware(simple_asgi)