Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding multiple db connections support for django-instrumentation's sqlcommenter #1187

Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased](https://github.com/open-telemetry/opentelemetry-python/compare/v1.12.0rc2-0.32b0...HEAD)
- Adding multiple db connections support for django-instrumentation's sqlcommenter
([#1187](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1187))

### Added
- `opentelemetry-instrumentation-redis` add support to instrument RedisCluster clients
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import ExitStack
from logging import getLogger
from typing import Any, Type, TypeVar
from urllib.parse import quote as urllib_quote

# pylint: disable=no-name-in-module
from django import conf, get_version
from django.db import connection
from django.db import connections
from django.db.backends.utils import CursorDebugWrapper

from opentelemetry.instrumentation.utils import (
_generate_sql_comment,
_get_opentelemetry_values,
)
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator,
)
Expand All @@ -44,7 +48,13 @@ def __init__(self, get_response) -> None:
self.get_response = get_response

def __call__(self, request) -> Any:
with connection.execute_wrapper(_QueryWrapper(request)):
with ExitStack() as stack:
for db_alias in connections:
stack.enter_context(
connections[db_alias].execute_wrapper(
_QueryWrapper(request)
)
)
return self.get_response(request)


Expand Down Expand Up @@ -105,49 +115,7 @@ def __call__(self, execute: Type[T], sql, params, many, context) -> T:
sql += sql_comment

# Add the query to the query log if debugging.
if context["cursor"].__class__ is CursorDebugWrapper:
if isinstance(context["cursor"], CursorDebugWrapper):
context["connection"].queries_log.append(sql)

return execute(sql, params, many, context)


def _generate_sql_comment(**meta) -> str:
"""
Return a SQL comment with comma delimited key=value pairs created from
**meta kwargs.
"""
key_value_delimiter = ","

if not meta: # No entries added.
return ""

# Sort the keywords to ensure that caching works and that testing is
# deterministic. It eases visual inspection as well.
return (
" /*"
+ key_value_delimiter.join(
f"{_url_quote(key)}={_url_quote(value)!r}"
for key, value in sorted(meta.items())
if value is not None
)
+ "*/"
)


def _url_quote(value) -> str:
if not isinstance(value, (str, bytes)):
return value
_quoted = urllib_quote(value)
# Since SQL uses '%' as a keyword, '%' is a by-product of url quoting
# e.g. foo,bar --> foo%2Cbar
# thus in our quoting, we need to escape it too to finally give
# foo,bar --> foo%%2Cbar
return _quoted.replace("%", "%%")


def _get_opentelemetry_values() -> dict or None:
"""
Return the OpenTelemetry Trace and Span IDs if Span ID is set in the
OpenTelemetry execution context.
"""
return _propagator.inject({})
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,13 @@
class TestMiddleware(WsgiTestBase):
@classmethod
def setUpClass(cls):
conf.settings.configure(ROOT_URLCONF=modules[__name__])
conf.settings.configure(
ROOT_URLCONF=modules[__name__],
DATABASES={
"default": {},
"other": {},
}, # db.connections gets populated only at first test execution
)
super().setUpClass()

def setUp(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
# limitations under the License.

# pylint: disable=no-name-in-module

from unittest.mock import MagicMock, patch

import pytest
from django import VERSION, conf
from django.http import HttpResponse
from django.test.utils import setup_test_environment, teardown_test_environment

from opentelemetry.instrumentation.django import DjangoInstrumentor
from opentelemetry.instrumentation.django.middleware.sqlcommenter_middleware import (
SqlCommenter,
_QueryWrapper,
)
from opentelemetry.test.wsgitestutil import WsgiTestBase
Expand Down Expand Up @@ -98,3 +99,19 @@ def test_query_wrapper(self, trace_capture):
"Select 1 /*app_name='app',controller='view',route='route',traceparent='%%2Atraceparent%%3D%%2700-0000000"
"00000000000000000deadbeef-000000000000beef-00'*/",
)

@patch(
"opentelemetry.instrumentation.django.middleware.sqlcommenter_middleware._QueryWrapper"
)
def test_multiple_connection_support(self, query_wrapper):
if not DJANGO_2_0:
pytest.skip()

requests_mock = MagicMock()
get_response = MagicMock()

sql_instance = SqlCommenter(get_response)
sql_instance(requests_mock)

# check if query_wrapper is added to the context for 2 databases
self.assertEqual(query_wrapper.call_count, 2)
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
from opentelemetry.context import _SUPPRESS_INSTRUMENTATION_KEY # noqa: F401
from opentelemetry.propagate import extract
from opentelemetry.trace import Span, StatusCode
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator,
)

propagator = TraceContextTextMapPropagator()


def extract_attributes_from_object(
Expand Down Expand Up @@ -119,24 +124,22 @@ def _start_internal_or_server_span(
return span, token


_KEY_VALUE_DELIMITER = ","


def _generate_sql_comment(**meta):
def _generate_sql_comment(**meta) -> str:
"""
Return a SQL comment with comma delimited key=value pairs created from
**meta kwargs.
"""
key_value_delimiter = ","

if not meta: # No entries added.
return ""

# Sort the keywords to ensure that caching works and that testing is
# deterministic. It eases visual inspection as well.
# pylint: disable=consider-using-f-string
return (
" /*"
+ _KEY_VALUE_DELIMITER.join(
"{}={!r}".format(_url_quote(key), _url_quote(value))
+ key_value_delimiter.join(
f"{_url_quote(key)}={_url_quote(value)!r}"
for key, value in sorted(meta.items())
if value is not None
)
Expand All @@ -155,6 +158,17 @@ def _url_quote(s): # pylint: disable=invalid-name
return quoted.replace("%", "%%")


def _get_opentelemetry_values():
"""
Return the OpenTelemetry Trace and Span IDs if Span ID is set in the
OpenTelemetry execution context.
"""
# Insert the W3C TraceContext generated
_headers = {}
propagator.inject(_headers)
return _headers


def _generate_opentelemetry_traceparent(span: Span) -> str:
meta = {}
_version = "00"
Expand Down