Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pymongo instrumentation hooks #793

Merged
merged 4 commits into from
Nov 9, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#781](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/781))
- `opentelemetry-instrumentation-aws-lambda` Add instrumentation for AWS Lambda Service - Implementation (Part 2/2)
([#777](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/777))
- `opentelemetry-instrumentation-pymongo` Add `request_hook`, `response_hook` and `failed_hook` callbacks passed as arguments to the instrument method
([#793](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/793))
- `opentelemetry-instrumentation-pymysql` Add support for PyMySQL 1.x series
([#792](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/792))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from pymongo import MongoClient
from opentelemetry.instrumentation.pymongo import PymongoInstrumentor


PymongoInstrumentor().instrument()
client = MongoClient()
db = client["MongoDB_Database"]
Expand All @@ -35,9 +34,47 @@

API
---
"""
The `instrument` method accepts the following keyword args:

tracer_provider (TracerProvider) - an optional tracer provider
request_hook (Callable) -
a function with extra user-defined logic to be performed before querying mongodb
this function signature is: def request_hook(span: Span, event: CommandStartedEvent) -> None
response_hook (Callable) -
a function with extra user-defined logic to be performed after the query returns with a successful response
this function signature is: def response_hook(span: Span, event: CommandSucceededEvent) -> None
failed_hook (Callable) -
a function with extra user-defined logic to be performed after the query returns with a failed response
this function signature is: def failed_hook(span: Span, event: CommandFailedEvent) -> None

for example:

.. code: python

from opentelemetry.instrumentation.pymongo import PymongoInstrumentor
from pymongo import MongoClient

def request_hook(span, event):
# request hook logic

from typing import Collection
def response_hook(span, event):
# response hook logic

def failed_hook(span, event):
# failed hook logic

# Instrument pymongo with hooks
PymongoInstrumentor().instrument(request_hook=request_hook, response_hooks=response_hook, failed_hook=failed_hook)

# This will create a span with pymongo specific attributes, including custom attributes added from the hooks
client = MongoClient()
db = client["MongoDB_Database"]
collection = db["MongoDB_Collection"]
collection.find_one()

"""
from logging import getLogger
from typing import Callable, Collection

from pymongo import monitoring

Expand All @@ -48,14 +85,34 @@
from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY
from opentelemetry.semconv.trace import DbSystemValues, SpanAttributes
from opentelemetry.trace import SpanKind, get_tracer
from opentelemetry.trace.span import Span
from opentelemetry.trace.status import Status, StatusCode

_LOG = getLogger(__name__)

RequestHookT = Callable[[Span, monitoring.CommandStartedEvent], None]
ResponseHookT = Callable[[Span, monitoring.CommandSucceededEvent], None]
FailedHookT = Callable[[Span, monitoring.CommandFailedEvent], None]


def dummy_callback(span, event):
...


class CommandTracer(monitoring.CommandListener):
def __init__(self, tracer):
def __init__(
self,
tracer,
request_hook: RequestHookT = dummy_callback,
response_hook: ResponseHookT = dummy_callback,
failed_hook: FailedHookT = dummy_callback,
):
self._tracer = tracer
self._span_dict = {}
self.is_enabled = True
self.start_hook = request_hook
self.success_hook = response_hook
self.failed_hook = failed_hook

def started(self, event: monitoring.CommandStartedEvent):
""" Method to handle a pymongo CommandStartedEvent """
Expand Down Expand Up @@ -85,6 +142,10 @@ def started(self, event: monitoring.CommandStartedEvent):
span.set_attribute(
SpanAttributes.NET_PEER_PORT, event.connection_id[1]
)
try:
self.start_hook(span, event)
except Exception as hook_exception: # noqa pylint: disable=broad-except
_LOG.exception(hook_exception)

# Add Span to dictionary
self._span_dict[_get_span_dict_key(event)] = span
Expand All @@ -103,6 +164,11 @@ def succeeded(self, event: monitoring.CommandSucceededEvent):
span = self._pop_span(event)
if span is None:
return
if span.is_recording():
try:
self.success_hook(span, event)
except Exception as hook_exception: # noqa pylint: disable=broad-except
_LOG.exception(hook_exception)
span.end()

def failed(self, event: monitoring.CommandFailedEvent):
Expand All @@ -116,6 +182,10 @@ def failed(self, event: monitoring.CommandFailedEvent):
return
if span.is_recording():
span.set_status(Status(StatusCode.ERROR, event.failure))
try:
self.failed_hook(span, event)
except Exception as hook_exception: # noqa pylint: disable=broad-except
_LOG.exception(hook_exception)
span.end()

def _pop_span(self, event):
Expand Down Expand Up @@ -150,12 +220,20 @@ def _instrument(self, **kwargs):
"""

tracer_provider = kwargs.get("tracer_provider")
request_hook = kwargs.get("request_hook", dummy_callback)
response_hook = kwargs.get("response_hook", dummy_callback)
failed_hook = kwargs.get("failed_hook", dummy_callback)

# Create and register a CommandTracer only the first time
if self._commandtracer_instance is None:
tracer = get_tracer(__name__, __version__, tracer_provider)

self._commandtracer_instance = CommandTracer(tracer)
self._commandtracer_instance = CommandTracer(
tracer,
request_hook=request_hook,
response_hook=response_hook,
failed_hook=failed_hook,
)
monitoring.register(self._commandtracer_instance)

# If already created, just enable it
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class TestPymongo(TestBase):
def setUp(self):
super().setUp()
self.tracer = self.tracer_provider.get_tracer(__name__)
self.start_callback = mock.MagicMock()
self.success_callback = mock.MagicMock()
self.failed_callback = mock.MagicMock()

def test_pymongo_instrumentor(self):
mock_register = mock.Mock()
Expand All @@ -44,7 +47,9 @@ def test_started(self):
command_attrs = {
"command_name": "find",
}
command_tracer = CommandTracer(self.tracer)
command_tracer = CommandTracer(
self.tracer, request_hook=self.start_callback
)
mock_event = MockEvent(
command_attrs, ("test.com", "1234"), "test_request_id"
)
Expand All @@ -66,17 +71,24 @@ def test_started(self):
span.attributes[SpanAttributes.NET_PEER_NAME], "test.com"
)
self.assertEqual(span.attributes[SpanAttributes.NET_PEER_PORT], "1234")
self.start_callback.assert_called_once_with(span, mock_event)

def test_succeeded(self):
mock_event = MockEvent({})
command_tracer = CommandTracer(self.tracer)
command_tracer = CommandTracer(
self.tracer,
request_hook=self.start_callback,
response_hook=self.success_callback,
)
command_tracer.started(event=mock_event)
command_tracer.succeeded(event=mock_event)
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]
self.assertIs(span.status.status_code, trace_api.StatusCode.UNSET)
self.assertIsNotNone(span.end_time)
self.start_callback.assert_called_once()
self.success_callback.assert_called_once()

def test_not_recording(self):
mock_tracer = mock.Mock()
Expand Down Expand Up @@ -119,7 +131,11 @@ def test_suppression_key(self):

def test_failed(self):
mock_event = MockEvent({})
command_tracer = CommandTracer(self.tracer)
command_tracer = CommandTracer(
self.tracer,
request_hook=self.start_callback,
failed_hook=self.failed_callback,
)
command_tracer.started(event=mock_event)
command_tracer.failed(event=mock_event)

Expand All @@ -132,6 +148,8 @@ def test_failed(self):
)
self.assertEqual(span.status.description, "failure")
self.assertIsNotNone(span.end_time)
self.start_callback.assert_called_once()
self.failed_callback.assert_called_once()

def test_multiple_commands(self):
first_mock_event = MockEvent({}, ("firstUrl", "123"), "first")
Expand Down