From 837770ffc7ba56d1885e64835182c88c5b031bd8 Mon Sep 17 00:00:00 2001 From: Michael Stella Date: Tue, 15 Dec 2020 17:46:44 -0500 Subject: [PATCH] Bugfix for #257 - properly stream responses Fixes #257 --- .../instrumentation/grpc/__init__.py | 27 +++++++++++++---- .../instrumentation/grpc/_server.py | 30 +++++++++++++++++++ 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/__init__.py b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/__init__.py index 955ef22ac9..d1ef5bbf9f 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/__init__.py @@ -144,6 +144,9 @@ def serve(): # pylint:disable=unused-argument # isort:skip +import logging +log = logging.getLogger(__name__) + class GrpcInstrumentorServer(BaseInstrumentor): """ @@ -184,19 +187,32 @@ class GrpcInstrumentorClient(BaseInstrumentor): grpc_client_instrumentor = GrpcInstrumentorClient() grpc.client_instrumentor.instrument() + Instrumetor arguments: + wrap_secure (bool): False to disable wrapping secure channels + wrap_insecure (bool): False to disable wrapping insecure channels + exporter: OpenTelemetry metrics exporter + interval (int): metrics export interval + """ def _instrument(self, **kwargs): exporter = kwargs.get("exporter", None) interval = kwargs.get("interval", 30) - if kwargs.get("channel_type") == "secure": + + # preserve the old argument + if "wrap_secure" not in kwargs and kwargs.get("channel_type", "") == "secure": + kwargs["wrap_secure"] = True + kwargs["wrap_insecure"] = False + + if kwargs.get("wrap_secure", True): + log.info("wrapping secure channels") _wrap( "grpc", "secure_channel", partial(self.wrapper_fn, exporter, interval), ) - - else: + if kwargs.get("wrap_insecure", True): + log.info("wrapping insecure channels") _wrap( "grpc", "insecure_channel", @@ -204,10 +220,11 @@ def _instrument(self, **kwargs): ) def _uninstrument(self, **kwargs): - if kwargs.get("channel_type") == "secure": + if kwargs.get("wrap_secure", True): unwrap(grpc, "secure_channel") - else: + #else: + if kwargs.get("wrap_insecure", True): unwrap(grpc, "insecure_channel") def wrapper_fn( diff --git a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py index 3fe859f574..9895b8853f 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py @@ -239,6 +239,15 @@ def intercept_service(self, continuation, handler_call_details): def telemetry_wrapper(behavior, request_streaming, response_streaming): def telemetry_interceptor(request_or_iterator, context): + # handle streaming responses specially + if response_streaming: + return self._intercept_server_stream( + behavior, + handler_call_details, + request_or_iterator, + context, + ) + with self._set_remote_context(context): with self._start_span( handler_call_details, context @@ -249,6 +258,7 @@ def telemetry_interceptor(request_or_iterator, context): # And now we run the actual RPC. try: return behavior(request_or_iterator, context) + except Exception as error: # Bare exceptions are likely to be gRPC aborts, which # we handle in our context wrapper. @@ -263,3 +273,23 @@ def telemetry_interceptor(request_or_iterator, context): return _wrap_rpc_behavior( continuation(handler_call_details), telemetry_wrapper ) + + # Handle streaming responses separately - we have to do this + # to return a *new* generator or various upstream things + # get confused, or we'll lose the consistent trace + def _intercept_server_stream( + self, behavior, handler_call_details, request_or_iterator, context + ): + + with self._set_remote_context(context): + with self._start_span(handler_call_details, context) as span: + context = _OpenTelemetryServicerContext(context, span) + + try: + yield from behavior(request_or_iterator, context) + + except Exception as error: + # pylint:disable=unidiomatic-typecheck + if type(error) != Exception: + span.record_exception(error) + raise error