Skip to content

Commit

Permalink
Fix AttributeError: __aexit__ for aiopg.connect and aiopg.create_pool (
Browse files Browse the repository at this point in the history
  • Loading branch information
srikanthccv authored Jan 8, 2021
1 parent cb01a6b commit 57b8106
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#259](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/259))
- `opentelemetry-exporter-datadog` Fix unintentional type change of span trace flags
([#261](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/261))
- `opentelemetry-instrumentation-aiopg` Fix AttributeError `__aexit__` when `aiopg.connect` and `aio[g].create_pool` used with async context manager
([#235](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/235))

## [0.16b1](https://github.com/open-telemetry/opentelemetry-python-contrib/releases/tag/v0.16b1) - 2020-11-26

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import aiopg
import wrapt
from aiopg.utils import _ContextManager, _PoolContextManager

from opentelemetry.instrumentation.aiopg.aiopg_integration import (
AiopgIntegration,
Expand Down Expand Up @@ -99,7 +100,7 @@ def wrap_connect(
"""

# pylint: disable=unused-argument
async def wrap_connect_(
def wrap_connect_(
wrapped: typing.Callable[..., typing.Any],
instance: typing.Any,
args: typing.Tuple[typing.Any, typing.Any],
Expand All @@ -113,7 +114,9 @@ async def wrap_connect_(
version=version,
tracer_provider=tracer_provider,
)
return await db_integration.wrapped_connection(wrapped, args, kwargs)
return _ContextManager(
db_integration.wrapped_connection(wrapped, args, kwargs)
)

try:
wrapt.wrap_function_wrapper(aiopg, "connect", wrap_connect_)
Expand Down Expand Up @@ -191,7 +194,7 @@ def wrap_create_pool(
tracer_provider: typing.Optional[TracerProvider] = None,
):
# pylint: disable=unused-argument
async def wrap_create_pool_(
def wrap_create_pool_(
wrapped: typing.Callable[..., typing.Any],
instance: typing.Any,
args: typing.Tuple[typing.Any, typing.Any],
Expand All @@ -205,7 +208,9 @@ async def wrap_create_pool_(
version=version,
tracer_provider=tracer_provider,
)
return await db_integration.wrapped_pool(wrapped, args, kwargs)
return _PoolContextManager(
db_integration.wrapped_pool(wrapped, args, kwargs)
)

try:
wrapt.wrap_function_wrapper(aiopg, "create_pool", wrap_create_pool_)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,26 @@ def test_instrumentor_connect(self):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

def test_instrumentor_connect_ctx_manager(self):
async def _ctx_manager_connect():
AiopgInstrumentor().instrument()

async with aiopg.connect(database="test") as cnx:
async with cnx.cursor() as cursor:
query = "SELECT * FROM test"
await cursor.execute(query)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

# Check version and name in span's instrumentation info
self.check_span_instrumentation_info(
span, opentelemetry.instrumentation.aiopg
)

async_call(_ctx_manager_connect())

def test_instrumentor_create_pool(self):
AiopgInstrumentor().instrument()

Expand Down Expand Up @@ -110,6 +130,27 @@ def test_instrumentor_create_pool(self):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

def test_instrumentor_create_pool_ctx_manager(self):
async def _ctx_manager_pool():
AiopgInstrumentor().instrument()

async with aiopg.create_pool(database="test") as pool:
async with pool.acquire() as cnx:
async with cnx.cursor() as cursor:
query = "SELECT * FROM test"
await cursor.execute(query)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

# Check version and name in span's instrumentation info
self.check_span_instrumentation_info(
span, opentelemetry.instrumentation.aiopg
)

async_call(_ctx_manager_pool())

def test_custom_tracer_provider_connect(self):
resource = resources.Resource.create({})
result = self.create_tracer_provider(resource=resource)
Expand Down Expand Up @@ -428,6 +469,12 @@ async def _acquire(self):
)
return connect

def close(self):
pass

async def wait_closed(self):
pass


class MockPsycopg2Connection:
def __init__(self, database, server_port, server_host, user):
Expand Down Expand Up @@ -471,6 +518,9 @@ async def callproc(self, query, params=None, throw_exception=False):
if throw_exception:
raise Exception("Test Exception")

def close(self):
pass


class AiopgConnectionMock:
_conn = MagicMock()
Expand Down

0 comments on commit 57b8106

Please sign in to comment.