Skip to content

Commit

Permalink
Bugfix: Pika basicConsume context propagation (#766)
Browse files Browse the repository at this point in the history
  • Loading branch information
oxeye-yuval authored Oct 21, 2021
1 parent ae7a415 commit 3ff06da
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 79 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- `opentelemetry-instrumentation-asgi` now explicitly depends on asgiref as it uses the package instead of instrumenting it.
([#765](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/765))
- `opentelemetry-instrumentation-pika` now propagates context to basic_consume callback
([#766](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/766))

## [1.6.2-0.25b2](https://github.com/open-telemetry/opentelemetry-python/releases/tag/v1.6.2-0.25b2) - 2021-10-19

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from logging import getLogger
from typing import Any, Callable, Collection, Dict, Optional
from typing import Any, Collection, Dict, Optional

import wrapt
from pika.adapters import BlockingConnection
from pika.channel import Channel
from pika.adapters.blocking_connection import BlockingChannel

from opentelemetry import trace
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
Expand All @@ -35,18 +35,25 @@
class PikaInstrumentor(BaseInstrumentor): # type: ignore
# pylint: disable=attribute-defined-outside-init
@staticmethod
def _instrument_consumers(
consumers_dict: Dict[str, Callable[..., Any]], tracer: Tracer
def _instrument_blocking_channel_consumers(
channel: BlockingChannel, tracer: Tracer
) -> Any:
for key, callback in consumers_dict.items():
for consumer_tag, consumer_info in channel._consumer_infos.items():
decorated_callback = utils._decorate_callback(
callback, tracer, key
consumer_info.on_message_callback, tracer, consumer_tag
)
setattr(decorated_callback, "_original_callback", callback)
consumers_dict[key] = decorated_callback

setattr(
decorated_callback,
"_original_callback",
consumer_info.on_message_callback,
)
consumer_info.on_message_callback = decorated_callback

@staticmethod
def _instrument_basic_publish(channel: Channel, tracer: Tracer) -> None:
def _instrument_basic_publish(
channel: BlockingChannel, tracer: Tracer
) -> None:
original_function = getattr(channel, "basic_publish")
decorated_function = utils._decorate_basic_publish(
original_function, channel, tracer
Expand All @@ -57,13 +64,13 @@ def _instrument_basic_publish(channel: Channel, tracer: Tracer) -> None:

@staticmethod
def _instrument_channel_functions(
channel: Channel, tracer: Tracer
channel: BlockingChannel, tracer: Tracer
) -> None:
if hasattr(channel, "basic_publish"):
PikaInstrumentor._instrument_basic_publish(channel, tracer)

@staticmethod
def _uninstrument_channel_functions(channel: Channel) -> None:
def _uninstrument_channel_functions(channel: BlockingChannel) -> None:
for function_name in _FUNCTIONS_TO_UNINSTRUMENT:
if not hasattr(channel, function_name):
continue
Expand All @@ -73,8 +80,10 @@ def _uninstrument_channel_functions(channel: Channel) -> None:
unwrap(channel, "basic_consume")

@staticmethod
# Make sure that the spans are created inside hash them set as parent and not as brothers
def instrument_channel(
channel: Channel, tracer_provider: Optional[TracerProvider] = None,
channel: BlockingChannel,
tracer_provider: Optional[TracerProvider] = None,
) -> None:
if not hasattr(channel, "_is_instrumented_by_opentelemetry"):
channel._is_instrumented_by_opentelemetry = False
Expand All @@ -84,18 +93,14 @@ def instrument_channel(
)
return
tracer = trace.get_tracer(__name__, __version__, tracer_provider)
if not hasattr(channel, "_impl"):
_LOG.error("Could not find implementation for provided channel!")
return
if channel._impl._consumers:
PikaInstrumentor._instrument_consumers(
channel._impl._consumers, tracer
)
PikaInstrumentor._instrument_blocking_channel_consumers(
channel, tracer
)
PikaInstrumentor._decorate_basic_consume(channel, tracer)
PikaInstrumentor._instrument_channel_functions(channel, tracer)

@staticmethod
def uninstrument_channel(channel: Channel) -> None:
def uninstrument_channel(channel: BlockingChannel) -> None:
if (
not hasattr(channel, "_is_instrumented_by_opentelemetry")
or not channel._is_instrumented_by_opentelemetry
Expand All @@ -104,12 +109,12 @@ def uninstrument_channel(channel: Channel) -> None:
"Attempting to uninstrument Pika channel while already uninstrumented!"
)
return
if not hasattr(channel, "_impl"):
_LOG.error("Could not find implementation for provided channel!")
return
for key, callback in channel._impl._consumers.items():
if hasattr(callback, "_original_callback"):
channel._impl._consumers[key] = callback._original_callback

for consumers_tag, client_info in channel._consumer_infos.items():
if hasattr(client_info.on_message_callback, "_original_callback"):
channel._consumer_infos[
consumers_tag
] = client_info.on_message_callback._original_callback
PikaInstrumentor._uninstrument_channel_functions(channel)

def _decorate_channel_function(
Expand All @@ -123,28 +128,15 @@ def wrapper(wrapped, instance, args, kwargs):
wrapt.wrap_function_wrapper(BlockingConnection, "channel", wrapper)

@staticmethod
def _decorate_basic_consume(channel, tracer: Optional[Tracer]) -> None:
def _decorate_basic_consume(
channel: BlockingChannel, tracer: Optional[Tracer]
) -> None:
def wrapper(wrapped, instance, args, kwargs):
if not hasattr(channel, "_impl"):
_LOG.error(
"Could not find implementation for provided channel!"
)
return wrapped(*args, **kwargs)
current_keys = set(channel._impl._consumers.keys())
return_value = wrapped(*args, **kwargs)
new_key_list = list(
set(channel._impl._consumers.keys()) - current_keys
)
if not new_key_list:
_LOG.error("Could not find added callback")
return return_value
new_key = new_key_list[0]
callback = channel._impl._consumers[new_key]
decorated_callback = utils._decorate_callback(
callback, tracer, new_key

PikaInstrumentor._instrument_blocking_channel_consumers(
channel, tracer
)
setattr(decorated_callback, "_original_callback", callback)
channel._impl._consumers[new_key] = decorated_callback
return return_value

wrapt.wrap_function_wrapper(channel, "basic_consume", wrapper)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,23 @@ def decorated_callback(
ctx = propagate.extract(properties.headers, getter=_pika_getter)
if not ctx:
ctx = context.get_current()
token = context.attach(ctx)
span = _get_span(
tracer,
channel,
properties,
destination=method.exchange
if method.exchange
else method.routing_key,
span_kind=SpanKind.CONSUMER,
task_name=task_name,
ctx=ctx,
operation=MessagingOperationValues.RECEIVE,
)
with trace.use_span(span, end_on_exit=True):
retval = callback(channel, method, properties, body)
try:
with trace.use_span(span, end_on_exit=True):
retval = callback(channel, method, properties, body)
finally:
context.detach(token)
return retval

return decorated_callback
Expand All @@ -78,14 +84,13 @@ def decorated_function(
properties = BasicProperties(headers={})
if properties.headers is None:
properties.headers = {}
ctx = context.get_current()
span = _get_span(
tracer,
channel,
properties,
destination=exchange if exchange else routing_key,
span_kind=SpanKind.PRODUCER,
task_name="(temporary)",
ctx=ctx,
operation=None,
)
if not span:
Expand All @@ -108,8 +113,8 @@ def _get_span(
channel: Channel,
properties: BasicProperties,
task_name: str,
destination: str,
span_kind: SpanKind,
ctx: context.Context,
operation: Optional[MessagingOperationValues] = None,
) -> Optional[Span]:
if context.get_value("suppress_instrumentation") or context.get_value(
Expand All @@ -118,9 +123,7 @@ def _get_span(
return None
task_name = properties.type if properties.type else task_name
span = tracer.start_span(
context=ctx,
name=_generate_span_name(task_name, operation),
kind=span_kind,
name=_generate_span_name(destination, operation), kind=span_kind,
)
if span.is_recording():
_enrich_span(span, channel, properties, task_name, operation)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
from unittest import TestCase, mock

from pika.adapters import BaseConnection, BlockingConnection
from pika.adapters import BlockingConnection
from pika.channel import Channel
from wrapt import BoundFunctionWrapper

Expand All @@ -24,9 +24,10 @@
class TestPika(TestCase):
def setUp(self) -> None:
self.channel = mock.MagicMock(spec=Channel)
self.channel._impl = mock.MagicMock(spec=BaseConnection)
consumer_info = mock.MagicMock()
consumer_info.on_message_callback = mock.MagicMock()
self.channel._consumer_infos = {"consumer-tag": consumer_info}
self.mock_callback = mock.MagicMock()
self.channel._impl._consumers = {"mock_key": self.mock_callback}

def test_instrument_api(self) -> None:
instrumentation = PikaInstrumentor()
Expand All @@ -49,19 +50,19 @@ def test_instrument_api(self) -> None:
"opentelemetry.instrumentation.pika.PikaInstrumentor._decorate_basic_consume"
)
@mock.patch(
"opentelemetry.instrumentation.pika.PikaInstrumentor._instrument_consumers"
"opentelemetry.instrumentation.pika.PikaInstrumentor._instrument_blocking_channel_consumers"
)
def test_instrument(
self,
instrument_consumers: mock.MagicMock,
instrument_blocking_channel_consumers: mock.MagicMock,
instrument_basic_consume: mock.MagicMock,
instrument_channel_functions: mock.MagicMock,
):
PikaInstrumentor.instrument_channel(channel=self.channel)
assert hasattr(
self.channel, "_is_instrumented_by_opentelemetry"
), "channel is not marked as instrumented!"
instrument_consumers.assert_called_once()
instrument_blocking_channel_consumers.assert_called_once()
instrument_basic_consume.assert_called_once()
instrument_channel_functions.assert_called_once()

Expand All @@ -71,18 +72,18 @@ def test_instrument_consumers(
) -> None:
tracer = mock.MagicMock(spec=Tracer)
expected_decoration_calls = [
mock.call(value, tracer, key)
for key, value in self.channel._impl._consumers.items()
mock.call(value.on_message_callback, tracer, key)
for key, value in self.channel._consumer_infos.items()
]
PikaInstrumentor._instrument_consumers(
self.channel._impl._consumers, tracer
PikaInstrumentor._instrument_blocking_channel_consumers(
self.channel, tracer
)
decorate_callback.assert_has_calls(
calls=expected_decoration_calls, any_order=True
)
assert all(
hasattr(callback, "_original_callback")
for callback in self.channel._impl._consumers.values()
for callback in self.channel._consumer_infos.values()
)

@mock.patch(
Expand Down
Loading

0 comments on commit 3ff06da

Please sign in to comment.